forked from Archives/langchain
remote retriever (#2232)
parent
5a0844bae1
commit
582950291c
@ -1,3 +1,4 @@
|
|||||||
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
|
||||||
|
from langchain.retrievers.remote_retriever import RemoteLangChainRetriever
|
||||||
|
|
||||||
__all__ = ["ChatGPTPluginRetriever"]
|
__all__ = ["ChatGPTPluginRetriever", "RemoteLangChainRetriever"]
|
||||||
|
@ -0,0 +1,29 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from langchain.schema import BaseRetriever, Document
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteLangChainRetriever(BaseRetriever, BaseModel):
|
||||||
|
url: str
|
||||||
|
headers: Optional[dict] = None
|
||||||
|
input_key: str = "message"
|
||||||
|
response_key: str = "response"
|
||||||
|
|
||||||
|
def get_relevant_documents(self, query: str) -> List[Document]:
|
||||||
|
response = requests.post(
|
||||||
|
self.url, json={self.input_key: query}, headers=self.headers
|
||||||
|
)
|
||||||
|
result = response.json()
|
||||||
|
return [Document(**r) for r in result[self.response_key]]
|
||||||
|
|
||||||
|
async def aget_relevant_documents(self, query: str) -> List[Document]:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.request(
|
||||||
|
"POST", self.url, headers=self.headers, json={self.input_key: query}
|
||||||
|
) as response:
|
||||||
|
result = await response.json()
|
||||||
|
return [Document(**r) for r in result[self.response_key]]
|
Loading…
Reference in New Issue