You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/community/langchain_community/retrievers/chaindesk.py

93 lines
2.6 KiB
Python

from typing import Any, List, Optional
import aiohttp
import requests
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class ChaindeskRetriever(BaseRetriever):
    """`Chaindesk API` retriever.

    POSTs the query as JSON to ``datastore_url`` and converts every entry of
    the response's ``"results"`` list into a ``Document`` whose
    ``page_content`` is the entry's ``"text"`` and whose metadata carries the
    entry's ``"source"`` and ``"score"``.
    """

    # Endpoint the query is POSTed to.
    datastore_url: str
    # Sent as ``"topK"`` in the request body; left out entirely when None.
    top_k: Optional[int] = None
    # Sent as ``Authorization: Bearer <api_key>``; header left out when None.
    api_key: Optional[str] = None

    def __init__(
        self,
        datastore_url: str,
        top_k: Optional[int] = None,
        api_key: Optional[str] = None,
    ):
        """Create a retriever bound to one datastore endpoint.

        Args:
            datastore_url: URL that queries are POSTed to.
            top_k: Optional value forwarded as ``"topK"`` in the request body.
            api_key: Optional bearer token for the ``Authorization`` header.
        """
        # The fields must be handed to the base-class initializer rather than
        # assigned directly: BaseRetriever is a pydantic model (langchain_core),
        # and its internal state only exists once ``super().__init__`` has run.
        super().__init__(
            datastore_url=datastore_url,
            top_k=top_k,
            api_key=api_key,
        )

    def _request_payload(self, query: str) -> dict:
        """Build the JSON body shared by the sync and async requests."""
        payload: dict = {"query": query}
        if self.top_k is not None:
            payload["topK"] = self.top_k
        return payload

    def _request_headers(self) -> dict:
        """Build the HTTP headers shared by the sync and async requests."""
        headers: dict = {"Content-Type": "application/json"}
        if self.api_key is not None:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _results_to_documents(self, data: Any) -> List[Document]:
        """Convert a decoded response body into a list of ``Document`` objects.

        Raises:
            KeyError: If ``data`` has no ``"results"`` key, or an entry lacks
                ``"text"``, ``"source"`` or ``"score"``.
        """
        return [
            Document(
                page_content=entry["text"],
                metadata={"source": entry["source"], "score": entry["score"]},
            )
            for entry in data["results"]
        ]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Query the datastore synchronously.

        Args:
            query: Text sent as ``"query"`` in the request body.
            run_manager: Callback manager for this run (not used here).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            One ``Document`` per entry in the response's ``"results"`` list.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
        """
        # NOTE(review): no timeout is passed, so a stalled server blocks the
        # caller indefinitely - consider exposing one.
        response = requests.post(
            self.datastore_url,
            json=self._request_payload(query),
            headers=self._request_headers(),
        )
        # Fail with the HTTP status instead of an opaque KeyError/JSON error
        # when the endpoint rejects the request.
        response.raise_for_status()
        return self._results_to_documents(response.json())

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Query the datastore asynchronously.

        Args:
            query: Text sent as ``"query"`` in the request body.
            run_manager: Async callback manager for this run (not used here).
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            One ``Document`` per entry in the response's ``"results"`` list.

        Raises:
            aiohttp.ClientResponseError: If the endpoint answers with an
                error status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.request(
                "POST",
                self.datastore_url,
                json=self._request_payload(query),
                headers=self._request_headers(),
            ) as response:
                # Same status check as the synchronous path.
                response.raise_for_status()
                data = await response.json()
        return self._results_to_documents(data)