From 663b0933e488383e6a9bc2a04b4b1cf866a8ea94 Mon Sep 17 00:00:00 2001 From: Eric Speidel <12648651+EricSpeidel@users.noreply.github.com> Date: Fri, 14 Jul 2023 14:38:24 +0200 Subject: [PATCH] Allow passing auth objects in TextRequestsWrapper (#7701) - Description: This allows passing auth objects in request wrappers. Currently, we can handle auth by editing headers in the RequestsWrappers, but more complex auth methods, such as Kerberos, could be handled better by using existing functionality within the requests library. There are many authentication options supported both natively and by extensions, such as requests-kerberos or requests-ntlm. - Issue: Fixes #7542 - Dependencies: none Co-authored-by: eric.speidel@de.bosch.com --- langchain/requests.py | 42 +++++++++++++++++++++++++++++------------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/langchain/requests.py b/langchain/requests.py index 536687278b..2891701c49 100644 --- a/langchain/requests.py +++ b/langchain/requests.py @@ -16,6 +16,7 @@ class Requests(BaseModel): headers: Optional[Dict[str, str]] = None aiosession: Optional[aiohttp.ClientSession] = None + auth: Optional[Any] = None class Config: """Configuration for this pydantic object.""" @@ -25,23 +26,29 @@ class Requests(BaseModel): def get(self, url: str, **kwargs: Any) -> requests.Response: """GET the URL and return the text.""" - return requests.get(url, headers=self.headers, **kwargs) + return requests.get(url, headers=self.headers, auth=self.auth, **kwargs) def post(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response: """POST to the URL and return the text.""" - return requests.post(url, json=data, headers=self.headers, **kwargs) + return requests.post( + url, json=data, headers=self.headers, auth=self.auth, **kwargs + ) def patch(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response: """PATCH the URL and return the text.""" - return requests.patch(url, json=data, headers=self.headers, **kwargs) + return requests.patch( + url, json=data, headers=self.headers, auth=self.auth, **kwargs + ) def put(self, url: str, data: Dict[str, Any], **kwargs: Any) -> requests.Response: """PUT the URL and return the text.""" - return requests.put(url, json=data, headers=self.headers, **kwargs) + return requests.put( + url, json=data, headers=self.headers, auth=self.auth, **kwargs + ) def delete(self, url: str, **kwargs: Any) -> requests.Response: """DELETE the URL and return the text.""" - return requests.delete(url, headers=self.headers, **kwargs) + return requests.delete(url, headers=self.headers, auth=self.auth, **kwargs) @asynccontextmanager async def _arequest( @@ -51,12 +58,12 @@ class Requests(BaseModel): if not self.aiosession: async with aiohttp.ClientSession() as session: async with session.request( - method, url, headers=self.headers, **kwargs + method, url, headers=self.headers, auth=self.auth, **kwargs ) as response: yield response else: async with self.aiosession.request( - method, url, headers=self.headers, **kwargs + method, url, headers=self.headers, auth=self.auth, **kwargs ) as response: yield response @@ -65,7 +72,7 @@ class Requests(BaseModel): self, url: str, **kwargs: Any ) -> AsyncGenerator[aiohttp.ClientResponse, None]: """GET the URL and return the text asynchronously.""" - async with self._arequest("GET", url, **kwargs) as response: + async with self._arequest("GET", url, auth=self.auth, **kwargs) as response: yield response @asynccontextmanager @@ -73,7 +80,9 @@ class Requests(BaseModel): self, url: str, data: Dict[str, Any], **kwargs: Any ) -> AsyncGenerator[aiohttp.ClientResponse, None]: """POST to the URL and return the text asynchronously.""" - async with self._arequest("POST", url, json=data, **kwargs) as response: + async with self._arequest( + "POST", url, json=data, auth=self.auth, **kwargs + ) as response: yield response @asynccontextmanager @@ -81,7 +90,9 @@ class Requests(BaseModel): self, url: str, data: Dict[str, Any], **kwargs: Any ) -> AsyncGenerator[aiohttp.ClientResponse, None]: """PATCH the URL and return the text asynchronously.""" - async with self._arequest("PATCH", url, json=data, **kwargs) as response: + async with self._arequest( + "PATCH", url, json=data, auth=self.auth, **kwargs + ) as response: yield response @asynccontextmanager @@ -89,7 +100,9 @@ class Requests(BaseModel): self, url: str, data: Dict[str, Any], **kwargs: Any ) -> AsyncGenerator[aiohttp.ClientResponse, None]: """PUT the URL and return the text asynchronously.""" - async with self._arequest("PUT", url, json=data, **kwargs) as response: + async with self._arequest( + "PUT", url, json=data, auth=self.auth, **kwargs + ) as response: yield response @asynccontextmanager @@ -97,7 +110,7 @@ class Requests(BaseModel): self, url: str, **kwargs: Any ) -> AsyncGenerator[aiohttp.ClientResponse, None]: """DELETE the URL and return the text asynchronously.""" - async with self._arequest("DELETE", url, **kwargs) as response: + async with self._arequest("DELETE", url, auth=self.auth, **kwargs) as response: yield response @@ -109,6 +122,7 @@ class TextRequestsWrapper(BaseModel): headers: Optional[Dict[str, str]] = None aiosession: Optional[aiohttp.ClientSession] = None + auth: Optional[Any] = None class Config: """Configuration for this pydantic object.""" @@ -118,7 +132,9 @@ class TextRequestsWrapper(BaseModel): @property def requests(self) -> Requests: - return Requests(headers=self.headers, aiosession=self.aiosession) + return Requests( + headers=self.headers, aiosession=self.aiosession, auth=self.auth + ) def get(self, url: str, **kwargs: Any) -> str: """GET the URL and return the text."""