diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index 5007aef3..6742f5bc 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -1,3 +1,4 @@ from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever +from langchain.retrievers.remote_retriever import RemoteLangChainRetriever -__all__ = ["ChatGPTPluginRetriever"] +__all__ = ["ChatGPTPluginRetriever", "RemoteLangChainRetriever"] diff --git a/langchain/retrievers/remote_retriever.py b/langchain/retrievers/remote_retriever.py new file mode 100644 index 00000000..958e6d74 --- /dev/null +++ b/langchain/retrievers/remote_retriever.py @@ -0,0 +1,29 @@ +from typing import List, Optional + +import aiohttp +import requests +from pydantic import BaseModel + +from langchain.schema import BaseRetriever, Document + + +class RemoteLangChainRetriever(BaseRetriever, BaseModel): + url: str + headers: Optional[dict] = None + input_key: str = "message" + response_key: str = "response" + + def get_relevant_documents(self, query: str) -> List[Document]: + response = requests.post( + self.url, json={self.input_key: query}, headers=self.headers + ) + result = response.json() + return [Document(**r) for r in result[self.response_key]] + + async def aget_relevant_documents(self, query: str) -> List[Document]: + async with aiohttp.ClientSession() as session: + async with session.request( + "POST", self.url, headers=self.headers, json={self.input_key: query} + ) as response: + result = await response.json() + return [Document(**r) for r in result[self.response_key]]