requests wrapper (#2367)

This commit is contained in:
Harrison Chase 2023-04-03 21:57:19 -07:00 committed by GitHub
parent 10dab053b4
commit fe1eb8ca5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 128 additions and 47 deletions

View File

@ -41,7 +41,7 @@
"from langchain.agents.agent_toolkits import JsonToolkit\n",
"from langchain.chains import LLMChain\n",
"from langchain.llms.openai import OpenAI\n",
"from langchain.requests import RequestsWrapper\n",
"from langchain.requests import TextRequestsWrapper\n",
"from langchain.tools.json.tool import JsonSpec"
]
},

View File

@ -35,7 +35,7 @@
"from langchain.agents import create_openapi_agent\n",
"from langchain.agents.agent_toolkits import OpenAPIToolkit\n",
"from langchain.llms.openai import OpenAI\n",
"from langchain.requests import RequestsWrapper\n",
"from langchain.requests import TextRequestsWrapper\n",
"from langchain.tools.json.tool import JsonSpec"
]
},
@ -54,7 +54,7 @@
"headers = {\n",
" \"Authorization\": f\"Bearer {os.getenv('OPENAI_API_KEY')}\"\n",
"}\n",
"requests_wrapper=RequestsWrapper(headers=headers)\n",
"requests_wrapper=TextRequestsWrapper(headers=headers)\n",
"openapi_toolkit = OpenAPIToolkit.from_llm(OpenAI(temperature=0), json_spec, requests_wrapper, verbose=True)\n",
"openapi_agent_executor = create_openapi_agent(\n",
" llm=OpenAI(temperature=0),\n",

View File

@ -17,7 +17,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain.utilities import RequestsWrapper"
"from langchain.utilities import TextRequestsWrapper"
]
},
{
@ -27,7 +27,7 @@
"metadata": {},
"outputs": [],
"source": [
"requests = RequestsWrapper()"
"requests = TextRequestsWrapper()"
]
},
{

View File

@ -10,7 +10,7 @@ from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.openapi.prompt import DESCRIPTION
from langchain.agents.tools import Tool
from langchain.llms.base import BaseLLM
from langchain.requests import RequestsWrapper
from langchain.requests import TextRequestsWrapper
from langchain.tools import BaseTool
from langchain.tools.json.tool import JsonSpec
from langchain.tools.requests.tool import (
@ -25,7 +25,7 @@ from langchain.tools.requests.tool import (
class RequestsToolkit(BaseToolkit):
"""Toolkit for making requests."""
requests_wrapper: RequestsWrapper
requests_wrapper: TextRequestsWrapper
def get_tools(self) -> List[BaseTool]:
"""Return a list of tools."""
@ -42,7 +42,7 @@ class OpenAPIToolkit(BaseToolkit):
"""Toolkit for interacting with a OpenAPI api."""
json_agent: AgentExecutor
requests_wrapper: RequestsWrapper
requests_wrapper: TextRequestsWrapper
def get_tools(self) -> List[BaseTool]:
"""Get the tools in the toolkit."""
@ -59,7 +59,7 @@ class OpenAPIToolkit(BaseToolkit):
cls,
llm: BaseLLM,
json_spec: JsonSpec,
requests_wrapper: RequestsWrapper,
requests_wrapper: TextRequestsWrapper,
**kwargs: Any,
) -> OpenAPIToolkit:
"""Create json agent from llm, then initialize."""

View File

@ -10,7 +10,7 @@ from langchain.chains.api.base import APIChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.pal.base import PALChain
from langchain.llms.base import BaseLLM
from langchain.requests import RequestsWrapper
from langchain.requests import TextRequestsWrapper
from langchain.tools.base import BaseTool
from langchain.tools.bing_search.tool import BingSearchRun
from langchain.tools.google_search.tool import GoogleSearchResults, GoogleSearchRun
@ -42,23 +42,23 @@ def _get_python_repl() -> BaseTool:
def _get_tools_requests_get() -> BaseTool:
return RequestsGetTool(requests_wrapper=RequestsWrapper())
return RequestsGetTool(requests_wrapper=TextRequestsWrapper())
def _get_tools_requests_post() -> BaseTool:
return RequestsPostTool(requests_wrapper=RequestsWrapper())
return RequestsPostTool(requests_wrapper=TextRequestsWrapper())
def _get_tools_requests_patch() -> BaseTool:
return RequestsPatchTool(requests_wrapper=RequestsWrapper())
return RequestsPatchTool(requests_wrapper=TextRequestsWrapper())
def _get_tools_requests_put() -> BaseTool:
return RequestsPutTool(requests_wrapper=RequestsWrapper())
return RequestsPutTool(requests_wrapper=TextRequestsWrapper())
def _get_tools_requests_delete() -> BaseTool:
return RequestsDeleteTool(requests_wrapper=RequestsWrapper())
return RequestsDeleteTool(requests_wrapper=TextRequestsWrapper())
def _get_terminal() -> BaseTool:

View File

@ -9,7 +9,7 @@ from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.requests import RequestsWrapper
from langchain.requests import TextRequestsWrapper
from langchain.schema import BaseLanguageModel
@ -18,7 +18,7 @@ class APIChain(Chain, BaseModel):
api_request_chain: LLMChain
api_answer_chain: LLMChain
requests_wrapper: RequestsWrapper = Field(exclude=True)
requests_wrapper: TextRequestsWrapper = Field(exclude=True)
api_docs: str
question_key: str = "question" #: :meta private:
output_key: str = "output" #: :meta private:
@ -93,7 +93,7 @@ class APIChain(Chain, BaseModel):
) -> APIChain:
"""Load chain from just an LLM and the api docs."""
get_request_chain = LLMChain(llm=llm, prompt=api_url_prompt)
requests_wrapper = RequestsWrapper(headers=headers)
requests_wrapper = TextRequestsWrapper(headers=headers)
get_answer_chain = LLMChain(llm=llm, prompt=api_response_prompt)
return cls(
api_request_chain=get_request_chain,

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains import LLMChain
from langchain.chains.base import Chain
from langchain.requests import RequestsWrapper
from langchain.requests import TextRequestsWrapper
DEFAULT_HEADERS = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" # noqa: E501
@ -18,8 +18,8 @@ class LLMRequestsChain(Chain, BaseModel):
"""Chain that hits a URL and then uses an LLM to parse results."""
llm_chain: LLMChain
requests_wrapper: RequestsWrapper = Field(
default_factory=RequestsWrapper, exclude=True
requests_wrapper: TextRequestsWrapper = Field(
default_factory=TextRequestsWrapper, exclude=True
)
text_length: int = 8000
requests_key: str = "requests_result" #: :meta private:

View File

@ -6,8 +6,12 @@ import requests
from pydantic import BaseModel, Extra
class RequestsWrapper(BaseModel):
"""Lightweight wrapper around requests library."""
class Requests(BaseModel):
"""Wrapper around requests to handle auth and async.
The main purpose of this wrapper is to handle authentication (by saving
headers) and enable easy async methods on the same base object.
"""
headers: Optional[Dict[str, str]] = None
aiosession: Optional[aiohttp.ClientSession] = None
@ -18,56 +22,133 @@ class RequestsWrapper(BaseModel):
extra = Extra.forbid
arbitrary_types_allowed = True
def get(self, url: str, **kwargs: Any) -> str:
def get(self, url: str, **kwargs: Any) -> requests.Response:
"""GET the URL and return the text."""
return requests.get(url, headers=self.headers, **kwargs).text
return requests.get(url, headers=self.headers, **kwargs)
def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
"""POST to the URL and return the text."""
return requests.post(url, json=data, headers=self.headers, **kwargs).text
return requests.post(url, json=data, headers=self.headers, **kwargs)
def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
"""PATCH the URL and return the text."""
return requests.patch(url, json=data, headers=self.headers, **kwargs).text
return requests.patch(url, json=data, headers=self.headers, **kwargs)
def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response:
"""PUT the URL and return the text."""
return requests.put(url, json=data, headers=self.headers, **kwargs).text
return requests.put(url, json=data, headers=self.headers, **kwargs)
def delete(self, url: str, **kwargs: Any) -> str:
def delete(self, url: str, **kwargs: Any) -> requests.Response:
"""DELETE the URL and return the text."""
return requests.delete(url, headers=self.headers, **kwargs).text
return requests.delete(url, headers=self.headers, **kwargs)
async def _arequest(self, method: str, url: str, **kwargs: Any) -> str:
async def _arequest(
self, method: str, url: str, **kwargs: Any
) -> aiohttp.ClientResponse:
"""Make an async request."""
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.request(
method, url, headers=self.headers, **kwargs
) as response:
return await response.text()
return response
else:
async with self.aiosession.request(
method, url, headers=self.headers, **kwargs
) as response:
return await response.text()
return response
async def aget(self, url: str, **kwargs: Any) -> str:
async def aget(self, url: str, **kwargs: Any) -> aiohttp.ClientResponse:
"""GET the URL and return the text asynchronously."""
return await self._arequest("GET", url, **kwargs)
async def apost(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
async def apost(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> aiohttp.ClientResponse:
"""POST to the URL and return the text asynchronously."""
return await self._arequest("POST", url, json=data, **kwargs)
async def apatch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
async def apatch(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> aiohttp.ClientResponse:
"""PATCH the URL and return the text asynchronously."""
return await self._arequest("PATCH", url, json=data, **kwargs)
async def aput(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
async def aput(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> aiohttp.ClientResponse:
"""PUT the URL and return the text asynchronously."""
return await self._arequest("PUT", url, json=data, **kwargs)
async def adelete(self, url: str, **kwargs: Any) -> str:
async def adelete(self, url: str, **kwargs: Any) -> aiohttp.ClientResponse:
"""DELETE the URL and return the text asynchronously."""
return await self._arequest("DELETE", url, **kwargs)
class TextRequestsWrapper(BaseModel):
"""Lightweight wrapper around requests library.
The main purpose of this wrapper is to always return a text output.
"""
headers: Optional[Dict[str, str]] = None
aiosession: Optional[aiohttp.ClientSession] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def requests(self) -> Requests:
return Requests(headers=self.headers, aiosession=self.aiosession)
def get(self, url: str, **kwargs: Any) -> str:
"""GET the URL and return the text."""
return self.requests.get(url, **kwargs).text
def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""POST to the URL and return the text."""
return self.requests.post(url, json=data, headers=self.headers, **kwargs).text
def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""PATCH the URL and return the text."""
return self.requests.patch(url, json=data, headers=self.headers, **kwargs).text
def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""PUT the URL and return the text."""
return self.requests.put(url, json=data, headers=self.headers, **kwargs).text
def delete(self, url: str, **kwargs: Any) -> str:
"""DELETE the URL and return the text."""
return self.requests.delete(url, headers=self.headers, **kwargs).text
async def aget(self, url: str, **kwargs: Any) -> str:
"""GET the URL and return the text asynchronously."""
response = await self.requests.aget(url, **kwargs)
return await response.text()
async def apost(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""POST to the URL and return the text asynchronously."""
response = await self.requests.apost(url, data, **kwargs)
return await response.text()
async def apatch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""PATCH the URL and return the text asynchronously."""
response = await self.requests.apatch(url, data, **kwargs)
return await response.text()
async def aput(self, url: str, data: Dict[str, Any], **kwargs: Any) -> str:
"""PUT the URL and return the text asynchronously."""
response = await self.requests.aput(url, data, **kwargs)
return await response.text()
async def adelete(self, url: str, **kwargs: Any) -> str:
"""DELETE the URL and return the text asynchronously."""
response = await self.requests.adelete(url, **kwargs)
return await response.text()
# For backwards compatibility
RequestsWrapper = TextRequestsWrapper

View File

@ -5,7 +5,7 @@ from typing import Any, Dict
from pydantic import BaseModel
from langchain.requests import RequestsWrapper
from langchain.requests import TextRequestsWrapper
from langchain.tools.base import BaseTool
@ -17,7 +17,7 @@ def _parse_input(text: str) -> Dict[str, Any]:
class BaseRequestsTool(BaseModel):
"""Base class for requests tools."""
requests_wrapper: RequestsWrapper
requests_wrapper: TextRequestsWrapper
class RequestsGetTool(BaseRequestsTool, BaseTool):

View File

@ -1,6 +1,6 @@
"""General utilities."""
from langchain.python import PythonREPL
from langchain.requests import RequestsWrapper
from langchain.requests import TextRequestsWrapper
from langchain.utilities.apify import ApifyWrapper
from langchain.utilities.bash import BashProcess
from langchain.utilities.bing_search import BingSearchAPIWrapper
@ -15,7 +15,7 @@ from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
__all__ = [
"ApifyWrapper",
"BashProcess",
"RequestsWrapper",
"TextRequestsWrapper",
"PythonREPL",
"GoogleSearchAPIWrapper",
"GoogleSerperAPIWrapper",

View File

@ -8,11 +8,11 @@ import pytest
from langchain import LLMChain
from langchain.chains.api.base import APIChain
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.requests import RequestsWrapper
from langchain.requests import TextRequestsWrapper
from tests.unit_tests.llms.fake_llm import FakeLLM
class FakeRequestsChain(RequestsWrapper):
class FakeRequestsChain(TextRequestsWrapper):
"""Fake requests chain just for testing purposes."""
output: str