From 582950291c07392762ecb5c6eecc30c898590a45 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 1 Apr 2023 08:59:04 -0700 Subject: [PATCH] remote retriever (#2232) --- langchain/retrievers/__init__.py | 3 ++- langchain/retrievers/remote_retriever.py | 29 ++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 langchain/retrievers/remote_retriever.py diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index 5007aef3af..6742f5bcc8 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -1,3 +1,4 @@ from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever +from langchain.retrievers.remote_retriever import RemoteLangChainRetriever -__all__ = ["ChatGPTPluginRetriever"] +__all__ = ["ChatGPTPluginRetriever", "RemoteLangChainRetriever"] diff --git a/langchain/retrievers/remote_retriever.py b/langchain/retrievers/remote_retriever.py new file mode 100644 index 0000000000..958e6d74f9 --- /dev/null +++ b/langchain/retrievers/remote_retriever.py @@ -0,0 +1,29 @@ +from typing import List, Optional + +import aiohttp +import requests +from pydantic import BaseModel + +from langchain.schema import BaseRetriever, Document + + +class RemoteLangChainRetriever(BaseRetriever, BaseModel): + url: str + headers: Optional[dict] = None + input_key: str = "message" + response_key: str = "response" + + def get_relevant_documents(self, query: str) -> List[Document]: + response = requests.post( + self.url, json={self.input_key: query}, headers=self.headers + ) + result = response.json() + return [Document(**r) for r in result[self.response_key]] + + async def aget_relevant_documents(self, query: str) -> List[Document]: + async with aiohttp.ClientSession() as session: + async with session.request( + "POST", self.url, headers=self.headers, json={self.input_key: query} + ) as response: + result = await response.json() + return [Document(**r) for r in result[self.response_key]]