Strip surrounding quotes from requests tool URLs. (#3563)

Often an LLM will output a requests tool input argument surrounded by
single quotes. This triggers an exception in the requests library. Here,
we add a simple clean url function that strips any leading and trailing
single and double quotes before passing the URL to the underlying
requests library.

Co-authored-by: James Brotchie <brotchie@google.com>
This commit is contained in:
James Brotchie 2023-04-25 21:20:26 -07:00 committed by GitHub
parent f4829025fe
commit 5fdaa95e06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,6 +14,11 @@ def _parse_input(text: str) -> Dict[str, Any]:
return json.loads(text) return json.loads(text)
def _clean_url(url: str) -> str:
"""Strips quotes from the url."""
return url.strip("\"'")
class BaseRequestsTool(BaseModel): class BaseRequestsTool(BaseModel):
"""Base class for requests tools.""" """Base class for requests tools."""
@ -28,11 +33,11 @@ class RequestsGetTool(BaseRequestsTool, BaseTool):
def _run(self, url: str) -> str: def _run(self, url: str) -> str:
"""Run the tool.""" """Run the tool."""
return self.requests_wrapper.get(url) return self.requests_wrapper.get(_clean_url(url))
async def _arun(self, url: str) -> str: async def _arun(self, url: str) -> str:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
return await self.requests_wrapper.aget(url) return await self.requests_wrapper.aget(_clean_url(url))
class RequestsPostTool(BaseRequestsTool, BaseTool): class RequestsPostTool(BaseRequestsTool, BaseTool):
@ -51,7 +56,7 @@ class RequestsPostTool(BaseRequestsTool, BaseTool):
"""Run the tool.""" """Run the tool."""
try: try:
data = _parse_input(text) data = _parse_input(text)
return self.requests_wrapper.post(data["url"], data["data"]) return self.requests_wrapper.post(_clean_url(data["url"]), data["data"])
except Exception as e: except Exception as e:
return repr(e) return repr(e)
@ -59,7 +64,9 @@ class RequestsPostTool(BaseRequestsTool, BaseTool):
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
try: try:
data = _parse_input(text) data = _parse_input(text)
return await self.requests_wrapper.apost(data["url"], data["data"]) return await self.requests_wrapper.apost(
_clean_url(data["url"]), data["data"]
)
except Exception as e: except Exception as e:
return repr(e) return repr(e)
@ -80,7 +87,7 @@ class RequestsPatchTool(BaseRequestsTool, BaseTool):
"""Run the tool.""" """Run the tool."""
try: try:
data = _parse_input(text) data = _parse_input(text)
return self.requests_wrapper.patch(data["url"], data["data"]) return self.requests_wrapper.patch(_clean_url(data["url"]), data["data"])
except Exception as e: except Exception as e:
return repr(e) return repr(e)
@ -88,7 +95,9 @@ class RequestsPatchTool(BaseRequestsTool, BaseTool):
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
try: try:
data = _parse_input(text) data = _parse_input(text)
return await self.requests_wrapper.apatch(data["url"], data["data"]) return await self.requests_wrapper.apatch(
_clean_url(data["url"]), data["data"]
)
except Exception as e: except Exception as e:
return repr(e) return repr(e)
@ -109,7 +118,7 @@ class RequestsPutTool(BaseRequestsTool, BaseTool):
"""Run the tool.""" """Run the tool."""
try: try:
data = _parse_input(text) data = _parse_input(text)
return self.requests_wrapper.put(data["url"], data["data"]) return self.requests_wrapper.put(_clean_url(data["url"]), data["data"])
except Exception as e: except Exception as e:
return repr(e) return repr(e)
@ -117,7 +126,9 @@ class RequestsPutTool(BaseRequestsTool, BaseTool):
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
try: try:
data = _parse_input(text) data = _parse_input(text)
return await self.requests_wrapper.aput(data["url"], data["data"]) return await self.requests_wrapper.aput(
_clean_url(data["url"]), data["data"]
)
except Exception as e: except Exception as e:
return repr(e) return repr(e)
@ -130,8 +141,8 @@ class RequestsDeleteTool(BaseRequestsTool, BaseTool):
def _run(self, url: str) -> str: def _run(self, url: str) -> str:
"""Run the tool.""" """Run the tool."""
return self.requests_wrapper.delete(url) return self.requests_wrapper.delete(_clean_url(url))
async def _arun(self, url: str) -> str: async def _arun(self, url: str) -> str:
"""Run the tool asynchronously.""" """Run the tool asynchronously."""
return await self.requests_wrapper.adelete(url) return await self.requests_wrapper.adelete(_clean_url(url))