langchain/libs/community/langchain_community/retrievers/dria_index.py

88 lines
2.7 KiB
Python
Raw Normal View History

"""Wrapper around Dria Retriever."""
from typing import Any, List, Optional
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.utilities import DriaAPIWrapper
class DriaRetriever(BaseRetriever):
    """`Dria` retriever using the DriaAPIWrapper.

    All calls are delegated to the ``DriaAPIWrapper`` stored in
    ``api_wrapper``, which is built in ``__init__`` from ``api_key`` and
    ``contract_id``.
    """

    # Wrapper that performs the actual Dria calls
    # (create_knowledge_base / insert_data / search).
    api_wrapper: DriaAPIWrapper

    def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any):
        """Initialize the DriaRetriever with a DriaAPIWrapper instance.

        Args:
            api_key: The API key for Dria.
            contract_id: The contract ID of the knowledge base to interact
                with. Defaults to ``None``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``BaseRetriever.__init__``.
        """
        api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id)
        # The wrapper is handed to the base initializer as a field value
        # rather than assigned on ``self`` directly.
        super().__init__(api_wrapper=api_wrapper, **kwargs)  # type: ignore[call-arg]

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str = "Unspecified",
        embedding: str = "jina",
    ) -> str:
        """Create a new knowledge base in Dria.

        Args:
            name: The name of the knowledge base.
            description: The description of the knowledge base.
            category: The category of the knowledge base.
            embedding: The embedding model to use for the knowledge base.

        Returns:
            The value returned by ``DriaAPIWrapper.create_knowledge_base``
            (documented upstream as the ID of the created knowledge base).
        """
        # Arguments are passed positionally, in this exact order.
        return self.api_wrapper.create_knowledge_base(
            name, description, category, embedding
        )

    def add_texts(
        self,
        texts: List,
    ) -> None:
        """Add texts to the Dria knowledge base.

        Args:
            texts: Items to insert. Each item is indexed with the keys
                ``"text"`` and ``"metadata"``; any other keys are dropped
                before the data is sent to the wrapper.

        Returns:
            None. The result of ``DriaAPIWrapper.insert_data`` is discarded.

        Raises:
            KeyError: If an item is a dict lacking ``"text"`` or
                ``"metadata"``.
        """
        # Normalize every item to exactly the two fields that are forwarded.
        data = [{"text": text["text"], "metadata": text["metadata"]} for text in texts]
        self.api_wrapper.insert_data(data)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve relevant documents from Dria based on a query.

        Args:
            query: The query string to search for in the knowledge base.
            run_manager: Callback manager for the retriever run. It is not
                used by this implementation.

        Returns:
            One ``Document`` per search result, in the order the wrapper
            returned them. Each document's ``metadata`` holds the result's
            ``"id"`` and ``"score"``.
        """
        results = self.api_wrapper.search(query)
        # NOTE(review): page_content is filled from the result's "metadata"
        # field; presumably Dria returns the stored text there -- confirm
        # against DriaAPIWrapper.search.
        return [
            Document(
                page_content=result["metadata"],
                metadata={"id": result["id"], "score": result["score"]},
            )
            for result in results
        ]