From fe1eb8ca5f57fcd7c566adfc01fa1266349b72f3 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 3 Apr 2023 21:57:19 -0700 Subject: [PATCH] requests wrapper (#2367) --- .../agents/toolkits/examples/json.ipynb | 2 +- .../agents/toolkits/examples/openapi.ipynb | 4 +- .../agents/tools/examples/requests.ipynb | 4 +- .../agents/agent_toolkits/openapi/toolkit.py | 8 +- langchain/agents/load_tools.py | 12 +- langchain/chains/api/base.py | 6 +- langchain/chains/llm_requests.py | 6 +- langchain/requests.py | 121 +++++++++++++++--- langchain/tools/requests/tool.py | 4 +- langchain/utilities/__init__.py | 4 +- tests/unit_tests/chains/test_api.py | 4 +- 11 files changed, 128 insertions(+), 47 deletions(-) diff --git a/docs/modules/agents/toolkits/examples/json.ipynb b/docs/modules/agents/toolkits/examples/json.ipynb index 9bec32e7..361bccd7 100644 --- a/docs/modules/agents/toolkits/examples/json.ipynb +++ b/docs/modules/agents/toolkits/examples/json.ipynb @@ -41,7 +41,7 @@ "from langchain.agents.agent_toolkits import JsonToolkit\n", "from langchain.chains import LLMChain\n", "from langchain.llms.openai import OpenAI\n", - "from langchain.requests import RequestsWrapper\n", + "from langchain.requests import TextRequestsWrapper\n", "from langchain.tools.json.tool import JsonSpec" ] }, diff --git a/docs/modules/agents/toolkits/examples/openapi.ipynb b/docs/modules/agents/toolkits/examples/openapi.ipynb index f7e12df3..9fe24136 100644 --- a/docs/modules/agents/toolkits/examples/openapi.ipynb +++ b/docs/modules/agents/toolkits/examples/openapi.ipynb @@ -35,7 +35,7 @@ "from langchain.agents import create_openapi_agent\n", "from langchain.agents.agent_toolkits import OpenAPIToolkit\n", "from langchain.llms.openai import OpenAI\n", - "from langchain.requests import RequestsWrapper\n", + "from langchain.requests import TextRequestsWrapper\n", "from langchain.tools.json.tool import JsonSpec" ] }, @@ -54,7 +54,7 @@ "headers = {\n", " \"Authorization\": f\"Bearer {os.getenv('OPENAI_API_KEY')}\"\n", "}\n", - "requests_wrapper=RequestsWrapper(headers=headers)\n", + "requests_wrapper=TextRequestsWrapper(headers=headers)\n", "openapi_toolkit = OpenAPIToolkit.from_llm(OpenAI(temperature=0), json_spec, requests_wrapper, verbose=True)\n", "openapi_agent_executor = create_openapi_agent(\n", " llm=OpenAI(temperature=0),\n", diff --git a/docs/modules/agents/tools/examples/requests.ipynb b/docs/modules/agents/tools/examples/requests.ipynb index 7096f138..a2b04e62 100644 --- a/docs/modules/agents/tools/examples/requests.ipynb +++ b/docs/modules/agents/tools/examples/requests.ipynb @@ -17,7 +17,7 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.utilities import RequestsWrapper" + "from langchain.utilities import TextRequestsWrapper" ] }, { @@ -27,7 +27,7 @@ "metadata": {}, "outputs": [], "source": [ - "requests = RequestsWrapper()" + "requests = TextRequestsWrapper()" ] }, { diff --git a/langchain/agents/agent_toolkits/openapi/toolkit.py b/langchain/agents/agent_toolkits/openapi/toolkit.py index 25d48634..3ae16526 100644 --- a/langchain/agents/agent_toolkits/openapi/toolkit.py +++ b/langchain/agents/agent_toolkits/openapi/toolkit.py @@ -10,7 +10,7 @@ from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit from langchain.agents.agent_toolkits.openapi.prompt import DESCRIPTION from langchain.agents.tools import Tool from langchain.llms.base import BaseLLM -from langchain.requests import RequestsWrapper +from langchain.requests import TextRequestsWrapper from langchain.tools import BaseTool from langchain.tools.json.tool import JsonSpec from langchain.tools.requests.tool import ( @@ -25,7 +25,7 @@ from langchain.tools.requests.tool import ( class RequestsToolkit(BaseToolkit): """Toolkit for making requests.""" - requests_wrapper: RequestsWrapper + requests_wrapper: TextRequestsWrapper def get_tools(self) -> List[BaseTool]: """Return a list of tools.""" @@ -42,7 +42,7 @@ class OpenAPIToolkit(BaseToolkit): """Toolkit for interacting with a OpenAPI api.""" json_agent: AgentExecutor - requests_wrapper: RequestsWrapper + requests_wrapper: TextRequestsWrapper def get_tools(self) -> List[BaseTool]: """Get the tools in the toolkit.""" @@ -59,7 +59,7 @@ class OpenAPIToolkit(BaseToolkit): cls, llm: BaseLLM, json_spec: JsonSpec, - requests_wrapper: RequestsWrapper, + requests_wrapper: TextRequestsWrapper, **kwargs: Any, ) -> OpenAPIToolkit: """Create json agent from llm, then initialize.""" diff --git a/langchain/agents/load_tools.py b/langchain/agents/load_tools.py index c20322fb..bcfa945a 100644 --- a/langchain/agents/load_tools.py +++ b/langchain/agents/load_tools.py @@ -10,7 +10,7 @@ from langchain.chains.api.base import APIChain from langchain.chains.llm_math.base import LLMMathChain from langchain.chains.pal.base import PALChain from langchain.llms.base import BaseLLM -from langchain.requests import RequestsWrapper +from langchain.requests import TextRequestsWrapper from langchain.tools.base import BaseTool from langchain.tools.bing_search.tool import BingSearchRun from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun @@ -42,23 +42,23 @@ def _get_python_repl() -> BaseTool: def _get_tools_requests_get() -> BaseTool: - return RequestsGetTool(requests_wrapper=RequestsWrapper()) + return RequestsGetTool(requests_wrapper=TextRequestsWrapper()) def _get_tools_requests_post() -> BaseTool: - return RequestsPostTool(requests_wrapper=RequestsWrapper()) + return RequestsPostTool(requests_wrapper=TextRequestsWrapper()) def _get_tools_requests_patch() -> BaseTool: - return RequestsPatchTool(requests_wrapper=RequestsWrapper()) + return RequestsPatchTool(requests_wrapper=TextRequestsWrapper()) def _get_tools_requests_put() -> BaseTool: - return RequestsPutTool(requests_wrapper=RequestsWrapper()) + return RequestsPutTool(requests_wrapper=TextRequestsWrapper()) def _get_tools_requests_delete() -> BaseTool: - return RequestsDeleteTool(requests_wrapper=RequestsWrapper()) + return RequestsDeleteTool(requests_wrapper=TextRequestsWrapper()) def _get_terminal() -> BaseTool: diff --git a/langchain/chains/api/base.py b/langchain/chains/api/base.py index 5cbded4e..27e093b0 100644 --- a/langchain/chains/api/base.py +++ b/langchain/chains/api/base.py @@ -9,7 +9,7 @@ from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT from langchain.chains.base import Chain from langchain.chains.llm import LLMChain from langchain.prompts import BasePromptTemplate -from langchain.requests import RequestsWrapper +from langchain.requests import TextRequestsWrapper from langchain.schema import BaseLanguageModel @@ -18,7 +18,7 @@ class APIChain(Chain, BaseModel): api_request_chain: LLMChain api_answer_chain: LLMChain - requests_wrapper: RequestsWrapper = Field(exclude=True) + requests_wrapper: TextRequestsWrapper = Field(exclude=True) api_docs: str question_key: str = "question" #: :meta private: output_key: str = "output" #: :meta private: @@ -93,7 +93,7 @@ class APIChain(Chain, BaseModel): ) -> APIChain: """Load chain from just an LLM and the api docs.""" get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt) - requests_wrapper = RequestsWrapper(headers=headers) + requests_wrapper = TextRequestsWrapper(headers=headers) get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt) return cls( api_request_chain=get_request_chain, diff --git a/langchain/chains/llm_requests.py b/langchain/chains/llm_requests.py index 2374a42e..f3f7fb31 100644 --- a/langchain/chains/llm_requests.py +++ b/langchain/chains/llm_requests.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator from langchain.chains import LLMChain from langchain.chains.base import Chain -from langchain.requests import RequestsWrapper +from langchain.requests import TextRequestsWrapper DEFAULT_HEADERS = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" # noqa: E501 @@ -18,8 +18,8 @@ class LLMRequestsChain(Chain, BaseModel): """Chain that hits a URL and then uses an LLM to parse results.""" llm_chain: LLMChain - requests_wrapper: RequestsWrapper = Field( - default_factory=RequestsWrapper, exclude=True + requests_wrapper: TextRequestsWrapper = Field( + default_factory=TextRequestsWrapper, exclude=True ) text_length: int = 8000 requests_key: str = "requests_result" #: :meta private: diff --git a/langchain/requests.py b/langchain/requests.py index ac5091f0..97f29af6 100644 --- a/langchain/requests.py +++ b/langchain/requests.py @@ -6,8 +6,12 @@ import requests from pydantic import BaseModel, Extra -class RequestsWrapper(BaseModel): - """Lightweight wrapper around requests library.""" +class Requests(BaseModel): + """Wrapper around requests to handle auth and async. + + The main purpose of this wrapper is to handle authentication (by saving + headers) and enable easy async methods on the same base object. + """ headers: Optional[Dict[str, str]] = None aiosession: Optional[aiohttp.ClientSession] = None @@ -18,56 +22,133 @@ class RequestsWrapper(BaseModel): extra = Extra.forbid arbitrary_types_allowed = True - def get(self, url: str, **kwargs: Any) -> str: + def get(self, url: str, **kwargs: Any) -> requests.Response: """GET the URL and return the text.""" - return requests.get(url, headers=self.headers, **kwargs).text + return requests.get(url, headers=self.headers, **kwargs) - def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response: """POST to the URL and return the text.""" - return requests.post(url, json=data, headers=self.headers, **kwargs).text + return requests.post(url, json=data, headers=self.headers, **kwargs) - def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response: """PATCH the URL and return the text.""" - return requests.patch(url, json=data, headers=self.headers, **kwargs).text + return requests.patch(url, json=data, headers=self.headers, **kwargs) - def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response: """PUT the URL and return the text.""" - return requests.put(url, json=data, headers=self.headers, **kwargs).text + return requests.put(url, json=data, headers=self.headers, **kwargs) - def delete(self, url: str, **kwargs: Any) -> str: + def delete(self, url: str, **kwargs: Any) -> requests.Response: """DELETE the URL and return the text.""" - return requests.delete(url, headers=self.headers, **kwargs).text + return requests.delete(url, headers=self.headers, **kwargs) - async def _arequest(self, method: str, url: str, **kwargs: Any) -> str: + async def _arequest( + self, method: str, url: str, **kwargs: Any + ) -> aiohttp.ClientResponse: """Make an async request.""" if not self.aiosession: async with aiohttp.ClientSession() as session: async with session.request( method, url, headers=self.headers, **kwargs ) as response: - return await response.text() + return response else: async with self.aiosession.request( method, url, headers=self.headers, **kwargs ) as response: - return await response.text() + return response - async def aget(self, url: str, **kwargs: Any) -> str: + async def aget(self, url: str, **kwargs: Any) -> aiohttp.ClientResponse: """GET the URL and return the text asynchronously.""" return await self._arequest("GET", url, **kwargs) - async def apost(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + async def apost( + self, url: str, data: Dict[str, Any], **kwargs: Any + ) -> aiohttp.ClientResponse: """POST to the URL and return the text asynchronously.""" return await self._arequest("POST", url, json=data, **kwargs) - async def apatch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + async def apatch( + self, url: str, data: Dict[str, Any], **kwargs: Any + ) -> aiohttp.ClientResponse: """PATCH the URL and return the text asynchronously.""" return await self._arequest("PATCH", url, json=data, **kwargs) - async def aput(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + async def aput( + self, url: str, data: Dict[str, Any], **kwargs: Any + ) -> aiohttp.ClientResponse: """PUT the URL and return the text asynchronously.""" return await self._arequest("PUT", url, json=data, **kwargs) - async def adelete(self, url: str, **kwargs: Any) -> str: + async def adelete(self, url: str, **kwargs: Any) -> aiohttp.ClientResponse: """DELETE the URL and return the text asynchronously.""" return await self._arequest("DELETE", url, **kwargs) + + +class TextRequestsWrapper(BaseModel): + """Lightweight wrapper around requests library. + + The main purpose of this wrapper is to always return a text output. + """ + + headers: Optional[Dict[str, str]] = None + aiosession: Optional[aiohttp.ClientSession] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + arbitrary_types_allowed = True + + @property + def requests(self) -> Requests: + return Requests(headers=self.headers, aiosession=self.aiosession) + + def get(self, url: str, **kwargs: Any) -> str: + """GET the URL and return the text.""" + return self.requests.get(url, **kwargs).text + + def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + """POST to the URL and return the text.""" + return self.requests.post(url, json=data, headers=self.headers, **kwargs).text + + def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + """PATCH the URL and return the text.""" + return self.requests.patch(url, json=data, headers=self.headers, **kwargs).text + + def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + """PUT the URL and return the text.""" + return self.requests.put(url, json=data, headers=self.headers, **kwargs).text + + def delete(self, url: str, **kwargs: Any) -> str: + """DELETE the URL and return the text.""" + return self.requests.delete(url, headers=self.headers, **kwargs).text + + async def aget(self, url: str, **kwargs: Any) -> str: + """GET the URL and return the text asynchronously.""" + response = await self.requests.aget(url, **kwargs) + return await response.text() + + async def apost(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + """POST to the URL and return the text asynchronously.""" + response = await self.requests.apost(url, data, **kwargs) + return await response.text() + + async def apatch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + """PATCH the URL and return the text asynchronously.""" + response = await self.requests.apatch(url, data, **kwargs) + return await response.text() + + async def aput(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str: + """PUT the URL and return the text asynchronously.""" + response = await self.requests.aput(url, data, **kwargs) + return await response.text() + + async def adelete(self, url: str, **kwargs: Any) -> str: + """DELETE the URL and return the text asynchronously.""" + response = await self.requests.adelete(url, **kwargs) + return await response.text() + + +# For backwards compatibility +RequestsWrapper = TextRequestsWrapper diff --git a/langchain/tools/requests/tool.py b/langchain/tools/requests/tool.py index 7802f598..aca09b07 100644 --- a/langchain/tools/requests/tool.py +++ b/langchain/tools/requests/tool.py @@ -5,7 +5,7 @@ from typing import Any, Dict from pydantic import BaseModel -from langchain.requests import RequestsWrapper +from langchain.requests import TextRequestsWrapper from langchain.tools.base import BaseTool @@ -17,7 +17,7 @@ def _parse_input(text: str) -> Dict[str, Any]: class BaseRequestsTool(BaseModel): """Base class for requests tools.""" - requests_wrapper: RequestsWrapper + requests_wrapper: TextRequestsWrapper class RequestsGetTool(BaseRequestsTool, BaseTool): diff --git a/langchain/utilities/__init__.py b/langchain/utilities/__init__.py index c9822364..1d83373b 100644 --- a/langchain/utilities/__init__.py +++ b/langchain/utilities/__init__.py @@ -1,6 +1,6 @@ """General utilities.""" from langchain.python import PythonREPL -from langchain.requests import RequestsWrapper +from langchain.requests import TextRequestsWrapper from langchain.utilities.apify import ApifyWrapper from langchain.utilities.bash import BashProcess from langchain.utilities.bing_search import BingSearchAPIWrapper @@ -15,7 +15,7 @@ from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper __all__ = [ "ApifyWrapper", "BashProcess", - "RequestsWrapper", + "TextRequestsWrapper", "PythonREPL", "GoogleSearchAPIWrapper", "GoogleSerperAPIWrapper", diff --git a/tests/unit_tests/chains/test_api.py b/tests/unit_tests/chains/test_api.py index 757960fb..f5d276f6 100644 --- a/tests/unit_tests/chains/test_api.py +++ b/tests/unit_tests/chains/test_api.py @@ -8,11 +8,11 @@ import pytest from langchain import LLMChain from langchain.chains.api.base import APIChain from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT -from langchain.requests import RequestsWrapper +from langchain.requests import TextRequestsWrapper from tests.unit_tests.llms.fake_llm import FakeLLM -class FakeRequestsChain(RequestsWrapper): +class FakeRequestsChain(TextRequestsWrapper): """Fake requests chain just for testing purposes.""" output: str