You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/libs/community/langchain_community/retrievers/databerry.py

75 lines
2.3 KiB
Python

from typing import List, Optional
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class DataberryRetriever(BaseRetriever):
    """`Databerry API` retriever.

    POSTs the query to ``datastore_url`` as JSON and turns every entry of the
    response's ``results`` list into a ``Document`` whose ``page_content`` is
    the entry's ``text`` and whose metadata carries ``source`` and ``score``.
    """

    # Endpoint the query is POSTed to.
    datastore_url: str
    # Sent as ``topK`` in the request body; left out of the body when None.
    top_k: Optional[int] = None
    # Sent as an ``Authorization: Bearer ...`` header; header omitted when None.
    api_key: Optional[str] = None

    def _request_payload(self, query: str) -> dict:
        """Build the JSON body shared by the sync and async requests."""
        payload = {"query": query}
        if self.top_k is not None:
            payload["topK"] = self.top_k
        return payload

    def _request_headers(self) -> dict:
        """Build the HTTP headers shared by the sync and async requests."""
        headers = {"Content-Type": "application/json"}
        if self.api_key is not None:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @staticmethod
    def _to_documents(data: dict) -> List[Document]:
        """Convert a decoded response body into a list of ``Document`` objects.

        Raises:
            KeyError: if ``results`` or one of ``text``/``source``/``score``
                is missing from the response.
        """
        return [
            Document(
                page_content=r["text"],
                metadata={"source": r["source"], "score": r["score"]},
            )
            for r in data["results"]
        ]

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the datastore synchronously and return the matching documents.

        Args:
            query: Text to search the datastore for.
            run_manager: Callback manager for this run (not used here).

        Returns:
            One ``Document`` per entry in the response's ``results`` list.

        Raises:
            requests.HTTPError: if the endpoint answers with a 4xx/5xx status.
        """
        response = requests.post(
            self.datastore_url,
            json=self._request_payload(query),
            headers=self._request_headers(),
        )
        # Fail on HTTP errors here rather than with an opaque KeyError or
        # JSON decode error when the error body is parsed below.
        response.raise_for_status()
        return self._to_documents(response.json())

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query the datastore asynchronously and return the matching documents.

        Args:
            query: Text to search the datastore for.
            run_manager: Async callback manager for this run (not used here).

        Returns:
            One ``Document`` per entry in the response's ``results`` list.

        Raises:
            aiohttp.ClientResponseError: if the endpoint answers with a
                4xx/5xx status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.request(
                "POST",
                self.datastore_url,
                json=self._request_payload(query),
                headers=self._request_headers(),
            ) as response:
                # Same error handling as the sync path.
                response.raise_for_status()
                data = await response.json()
        return self._to_documents(data)