mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
4384fa8e49
[Dria](https://dria.co/) is a hub of public RAG models for developers to both contribute and utilize a shared embedding lake. This PR adds a retriever that can retrieve documents from Dria.
88 lines
2.7 KiB
Python
88 lines
2.7 KiB
Python
"""Wrapper around Dria Retriever."""
|
|
|
|
from typing import Any, List, Optional
|
|
|
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
|
from langchain_core.documents import Document
|
|
from langchain_core.retrievers import BaseRetriever
|
|
|
|
from langchain_community.utilities import DriaAPIWrapper
|
|
|
|
|
|
class DriaRetriever(BaseRetriever):
    """`Dria` retriever using the DriaAPIWrapper.

    Every operation is delegated to ``api_wrapper``; this class only adapts
    the wrapper's search results into ``Document`` objects and exposes a few
    convenience methods for managing the knowledge base.
    """

    # Wrapper instance that all methods on this class delegate to.
    api_wrapper: DriaAPIWrapper

    def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any):
        """Initialize the DriaRetriever with a DriaAPIWrapper instance.

        Args:
            api_key: The API key for Dria.
            contract_id: The contract ID of the knowledge base to interact with.
            **kwargs: Additional keyword arguments forwarded unchanged to the
                base class initializer.
        """
        # Build the wrapper first so it can be handed to the base initializer
        # as the value of the declared ``api_wrapper`` field.
        api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id)
        super().__init__(api_wrapper=api_wrapper, **kwargs)

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str = "Unspecified",
        embedding: str = "jina",
    ) -> str:
        """Create a new knowledge base in Dria.

        Args:
            name: The name of the knowledge base.
            description: The description of the knowledge base.
            category: The category of the knowledge base.
            embedding: The embedding model to use for the knowledge base.

        Returns:
            The wrapper's response, documented as the ID of the created
            knowledge base.
        """
        # Arguments are passed positionally, in this exact order, to the
        # wrapper method of the same name.
        response = self.api_wrapper.create_knowledge_base(
            name, description, category, embedding
        )
        return response

    def add_texts(
        self,
        texts: List,
    ) -> None:
        """Add texts to the Dria knowledge base.

        Args:
            texts: A list of items, each indexed with the keys ``"text"`` and
                ``"metadata"``. Only those two entries are forwarded to the
                wrapper; any other keys on an item are dropped.

        Returns:
            None. Whatever the wrapper's ``insert_data`` call returns is
            discarded.
        """
        # Re-pack each item so the wrapper receives exactly the two expected
        # keys and nothing else.
        data = [{"text": text["text"], "metadata": text["metadata"]} for text in texts]
        self.api_wrapper.insert_data(data)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve relevant documents from Dria based on a query.

        Args:
            query: The query string to search for in the knowledge base.
            run_manager: Callback manager for the retriever run. It is not
                used by this implementation.

        Returns:
            A list of Documents, one per search result, in the order the
            wrapper returned them. Each Document's metadata holds the
            result's ``"id"`` and ``"score"``.
        """
        results = self.api_wrapper.search(query)
        # NOTE(review): page_content is taken from each result's "metadata"
        # entry -- presumably Dria returns the stored text there; confirm
        # against the wrapper's search response.
        docs = [
            Document(
                page_content=result["metadata"],
                metadata={"id": result["id"], "score": result["score"]},
            )
            for result in results
        ]
        return docs
|