mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
community[minor]: Add Dria retriever (#17098)
[Dria](https://dria.co/) is a hub of public RAG models for developers to both contribute and utilize a shared embedding lake. This PR adds a retriever that can retrieve documents from Dria.
This commit is contained in:
parent
0b0a55192f
commit
4384fa8e49
191
docs/docs/integrations/retrievers/dria_index.ipynb
Normal file
191
docs/docs/integrations/retrievers/dria_index.ipynb
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "UYyFIEKEkmHb"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Dria\n",
|
||||||
|
"\n",
|
||||||
|
"Dria is a hub of public RAG models for developers to both contribute and utilize a shared embedding lake. This notebook demonstrates how to use the Dria API for data retrieval tasks."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "VNTFUgK9kmHd"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Installation\n",
|
||||||
|
"\n",
|
||||||
|
"Ensure you have the `dria` package installed. You can install it using pip:"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "X--1A8EEkmHd"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%pip install --upgrade --quiet dria"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "xRbRL0SgkmHe"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Configure API Key\n",
|
||||||
|
"\n",
|
||||||
|
"Set up your Dria API key for access."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {
|
||||||
|
"id": "hGqOByNMkmHe"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"\n",
|
||||||
|
"os.environ[\"DRIA_API_KEY\"] = \"DRIA_API_KEY\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "nDfAEqQtkmHe"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Initialize Dria Retriever\n",
|
||||||
|
"\n",
|
||||||
|
"Create an instance of `DriaRetriever`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "vlyorgCckmHe"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.retrievers import DriaRetriever\n",
|
||||||
|
"\n",
|
||||||
|
"api_key = os.getenv(\"DRIA_API_KEY\")\n",
|
||||||
|
"retriever = DriaRetriever(api_key=api_key)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "j7WUY5jBOLQd"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# **Create Knowledge Base**\n",
|
||||||
|
"\n",
|
||||||
|
"Create a knowledge on [Dria's Knowledge Hub](https://dria.co/knowledge)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "L5ER81eWOKnt"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"contract_id = retriever.create_knowledge_base(\n",
|
||||||
|
" name=\"France's AI Development\",\n",
|
||||||
|
" embedding=DriaRetriever.models.jina_embeddings_v2_base_en.value,\n",
|
||||||
|
" category=\"Artificial Intelligence\",\n",
|
||||||
|
" description=\"Explore the growth and contributions of France in the field of Artificial Intelligence.\",\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "9VCTzSFpkmHe"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Add Data\n",
|
||||||
|
"\n",
|
||||||
|
"Load data into your Dria knowledge base."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "xeTMafIekmHf"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"texts = [\n",
|
||||||
|
" \"The first text to add to Dria.\",\n",
|
||||||
|
" \"Another piece of information to store.\",\n",
|
||||||
|
" \"More data to include in the Dria knowledge base.\",\n",
|
||||||
|
"]\n",
|
||||||
|
"\n",
|
||||||
|
"ids = retriever.add_texts(texts)\n",
|
||||||
|
"print(\"Data added with IDs:\", ids)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {
|
||||||
|
"id": "dy1UlvLCkmHf"
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"# Retrieve Data\n",
|
||||||
|
"\n",
|
||||||
|
"Use the retriever to find relevant documents given a query."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"id": "9y3msv9tkmHf"
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"query = \"Find information about Dria.\"\n",
|
||||||
|
"result = retriever.get_relevant_documents(query)\n",
|
||||||
|
"for doc in result:\n",
|
||||||
|
" print(doc)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"colab": {
|
||||||
|
"provenance": []
|
||||||
|
},
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.x"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 0
|
||||||
|
}
|
@ -33,6 +33,7 @@ _module_lookup = {
|
|||||||
"ChatGPTPluginRetriever": "langchain_community.retrievers.chatgpt_plugin_retriever",
|
"ChatGPTPluginRetriever": "langchain_community.retrievers.chatgpt_plugin_retriever",
|
||||||
"CohereRagRetriever": "langchain_community.retrievers.cohere_rag_retriever",
|
"CohereRagRetriever": "langchain_community.retrievers.cohere_rag_retriever",
|
||||||
"DocArrayRetriever": "langchain_community.retrievers.docarray",
|
"DocArrayRetriever": "langchain_community.retrievers.docarray",
|
||||||
|
"DriaRetriever": "langchain_community.retrievers.dria_index",
|
||||||
"ElasticSearchBM25Retriever": "langchain_community.retrievers.elastic_search_bm25",
|
"ElasticSearchBM25Retriever": "langchain_community.retrievers.elastic_search_bm25",
|
||||||
"EmbedchainRetriever": "langchain_community.retrievers.embedchain",
|
"EmbedchainRetriever": "langchain_community.retrievers.embedchain",
|
||||||
"GoogleCloudEnterpriseSearchRetriever": "langchain_community.retrievers.google_vertex_ai_search", # noqa: E501
|
"GoogleCloudEnterpriseSearchRetriever": "langchain_community.retrievers.google_vertex_ai_search", # noqa: E501
|
||||||
|
87
libs/community/langchain_community/retrievers/dria_index.py
Normal file
87
libs/community/langchain_community/retrievers/dria_index.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
"""Wrapper around Dria Retriever."""
|
||||||
|
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
|
||||||
|
from langchain_community.utilities import DriaAPIWrapper
|
||||||
|
|
||||||
|
|
||||||
|
class DriaRetriever(BaseRetriever):
    """`Dria` retriever using the DriaAPIWrapper.

    Every Dria operation (creating a knowledge base, inserting texts and
    searching) is delegated to the wrapped ``DriaAPIWrapper`` instance.
    """

    # Wrapper through which all Dria calls made by this retriever go.
    api_wrapper: DriaAPIWrapper

    def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any):
        """Initialize the DriaRetriever with a DriaAPIWrapper instance.

        Args:
            api_key: The API key for Dria.
            contract_id: The contract ID of the knowledge base to interact with.
            **kwargs: Extra keyword arguments forwarded to the base class
                constructor.
        """
        api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id)
        super().__init__(api_wrapper=api_wrapper, **kwargs)

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str = "Unspecified",
        embedding: str = "jina",
    ) -> str:
        """Create a new knowledge base in Dria.

        Args:
            name: The name of the knowledge base.
            description: The description of the knowledge base.
            category: The category of the knowledge base.
            embedding: The embedding model to use for the knowledge base.

        Returns:
            The ID of the created knowledge base, as returned by the wrapper.
        """
        # Keyword arguments keep this call independent of the wrapper's
        # positional parameter order.
        return self.api_wrapper.create_knowledge_base(
            name=name,
            description=description,
            category=category,
            embedding=embedding,
        )

    def add_texts(
        self,
        texts: List,
    ) -> str:
        """Add texts to the Dria knowledge base.

        Args:
            texts: Items to insert. Each item must be a mapping with a
                ``"text"`` key and a ``"metadata"`` key; plain strings are
                not accepted.

        Returns:
            The response of ``DriaAPIWrapper.insert_data`` for the insert
            request.
        """
        # Only the two expected keys are forwarded; any extra keys on an item
        # are dropped.
        data = [{"text": text["text"], "metadata": text["metadata"]} for text in texts]
        # The insert response used to be discarded; returning it lets callers
        # inspect the outcome of the insert.
        return self.api_wrapper.insert_data(data)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Retrieve relevant documents from Dria based on a query.

        Args:
            query: The query string to search for in the knowledge base.
            run_manager: Callback manager for the retriever run (not used here).

        Returns:
            A list of Documents, one per search result, carrying the result's
            ``id`` and ``score`` in the document metadata.
        """
        results = self.api_wrapper.search(query)
        # NOTE(review): the page content is taken from the result's "metadata"
        # field; presumably Dria returns the stored text there -- confirm
        # against the Dria search response schema.
        return [
            Document(
                page_content=result["metadata"],
                metadata={"id": result["id"], "score": result["score"]},
            )
            for result in results
        ]
|
@ -15,6 +15,7 @@ _module_lookup = {
|
|||||||
"BibtexparserWrapper": "langchain_community.utilities.bibtex",
|
"BibtexparserWrapper": "langchain_community.utilities.bibtex",
|
||||||
"BingSearchAPIWrapper": "langchain_community.utilities.bing_search",
|
"BingSearchAPIWrapper": "langchain_community.utilities.bing_search",
|
||||||
"BraveSearchWrapper": "langchain_community.utilities.brave_search",
|
"BraveSearchWrapper": "langchain_community.utilities.brave_search",
|
||||||
|
"DriaAPIWrapper": "langchain_community.utilities.dria_index",
|
||||||
"DuckDuckGoSearchAPIWrapper": "langchain_community.utilities.duckduckgo_search",
|
"DuckDuckGoSearchAPIWrapper": "langchain_community.utilities.duckduckgo_search",
|
||||||
"GoldenQueryAPIWrapper": "langchain_community.utilities.golden_query",
|
"GoldenQueryAPIWrapper": "langchain_community.utilities.golden_query",
|
||||||
"GoogleFinanceAPIWrapper": "langchain_community.utilities.google_finance",
|
"GoogleFinanceAPIWrapper": "langchain_community.utilities.google_finance",
|
||||||
|
95
libs/community/langchain_community/utilities/dria_index.py
Normal file
95
libs/community/langchain_community/utilities/dria_index.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DriaAPIWrapper:
    """Wrapper around Dria API.

    This wrapper facilitates interactions with Dria's vector search
    and retrieval services, including creating knowledge bases, inserting data,
    and fetching search results.

    Attributes:
        api_key: Your API key for accessing Dria.
        contract_id: The contract ID of the knowledge base to interact with.
        top_n: Number of top results to fetch for a search.
        models: The ``Models`` object imported from the ``dria`` package.
        dria_client: The ``Dria`` client used for every API call.
    """

    def __init__(
        self, api_key: str, contract_id: Optional[str] = None, top_n: int = 10
    ):
        """Create the Dria client and optionally bind it to a knowledge base.

        Args:
            api_key: Your API key for accessing Dria.
            contract_id: Contract ID of an existing knowledge base, if any.
            top_n: Number of top results to fetch for a search.

        Raises:
            ImportError: If the ``dria`` package is not installed.
        """
        try:
            from dria import Dria, Models
        except ImportError as exc:
            # Fail loudly: returning here would leave an instance with no
            # attributes, which only breaks later with an AttributeError.
            raise ImportError(
                "Dria is not installed. Please install Dria to use this wrapper. "
                "You can install Dria using the following command: "
                "pip install dria"
            ) from exc

        self.api_key = api_key
        self.models = Models
        self.contract_id = contract_id
        self.top_n = top_n
        self.dria_client = Dria(api_key=self.api_key)
        if self.contract_id:
            self.dria_client.set_contract(self.contract_id)

    def create_knowledge_base(
        self,
        name: str,
        description: str,
        category: str,
        embedding: str,
    ) -> str:
        """Create a new knowledge base and remember its contract ID.

        Args:
            name: The name of the knowledge base.
            description: The description of the knowledge base.
            category: The category of the knowledge base.
            embedding: The embedding model to use for the knowledge base.

        Returns:
            The contract ID returned by the Dria client.
        """
        contract_id = self.dria_client.create(
            name=name, embedding=embedding, category=category, description=description
        )
        logger.info("Knowledge base created with ID: %s", contract_id)
        # NOTE(review): set_contract is not called for the new ID here;
        # presumably the client's create() selects it -- confirm.
        self.contract_id = contract_id
        return contract_id

    def insert_data(self, data: List[Dict[str, Any]]) -> str:
        """Insert data into the knowledge base.

        Args:
            data: Records passed unchanged to the client's ``insert_text``.

        Returns:
            The response returned by the Dria client.
        """
        response = self.dria_client.insert_text(data)
        logger.info("Data inserted: %s", response)
        return response

    def search(self, query: str) -> List[Dict[str, Any]]:
        """Perform a text-based search limited to ``top_n`` results."""
        results = self.dria_client.search(query, top_n=self.top_n)
        logger.info("Search results: %s", results)
        return results

    def query_with_vector(self, vector: List[float]) -> List[Dict[str, Any]]:
        """Perform a vector-based query limited to ``top_n`` results."""
        vector_query_results = self.dria_client.query(vector, top_n=self.top_n)
        logger.info("Vector query results: %s", vector_query_results)
        return vector_query_results

    def run(self, query: Union[str, List[float]]) -> Optional[List[Dict[str, Any]]]:
        """Method to handle both text-based searches and vector-based queries.

        Args:
            query: A string for text-based search or a list of floats for
                vector-based query.

        Returns:
            The search or query results from Dria, or ``None`` (after logging
            an error) when ``query`` is neither a string nor a list of floats.
        """
        if isinstance(query, str):
            return self.search(query)
        # Only lists made up entirely of floats are treated as vectors.
        if isinstance(query, list) and all(isinstance(item, float) for item in query):
            return self.query_with_vector(query)
        logger.error(
            """Invalid query type. Please provide a string for text search or a
                list of floats for vector query."""
        )
        return None
|
@ -0,0 +1,41 @@
|
|||||||
|
import pytest
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
from langchain_community.retrievers import DriaRetriever
|
||||||
|
|
||||||
|
|
||||||
|
# Set a fixture for DriaRetriever
|
||||||
|
@pytest.fixture
def dria_retriever() -> DriaRetriever:
    """Build a DriaRetriever bound to a fixed knowledge-base contract ID."""
    return DriaRetriever(
        api_key="<YOUR_API_KEY>",
        contract_id="B16z9i3rRi0KEeibrzzMU33YTB4WDtos1vdiMBTmKgs",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_dria_retriever(dria_retriever: DriaRetriever) -> None:
    """Insert one text, query for it, and check the shape of the first hit."""
    record = {
        "text": "Langchain",
        "metadata": {
            "source": "source#1",
            "document_id": "doc123",
            "content": "Langchain",
        },
    }
    dria_retriever.add_texts([record])

    # get_relevant_documents is expected to return a list of Document objects.
    hits = dria_retriever.get_relevant_documents("Langchain")

    assert len(hits) > 0, "Expected at least one document"
    first_hit = hits[0]
    assert isinstance(first_hit, Document), "Expected a Document instance"
    assert isinstance(
        first_hit.page_content, str
    ), "Expected document content type to be string"
    assert isinstance(
        first_hit.metadata, dict
    ), "Expected document metadata content to be a dictionary"
|
@ -10,6 +10,7 @@ EXPECTED_ALL = [
|
|||||||
"ChatGPTPluginRetriever",
|
"ChatGPTPluginRetriever",
|
||||||
"ChaindeskRetriever",
|
"ChaindeskRetriever",
|
||||||
"CohereRagRetriever",
|
"CohereRagRetriever",
|
||||||
|
"DriaRetriever",
|
||||||
"ElasticSearchBM25Retriever",
|
"ElasticSearchBM25Retriever",
|
||||||
"EmbedchainRetriever",
|
"EmbedchainRetriever",
|
||||||
"GoogleDocumentAIWarehouseRetriever",
|
"GoogleDocumentAIWarehouseRetriever",
|
||||||
|
@ -9,6 +9,7 @@ EXPECTED_ALL = [
|
|||||||
"BingSearchAPIWrapper",
|
"BingSearchAPIWrapper",
|
||||||
"BraveSearchWrapper",
|
"BraveSearchWrapper",
|
||||||
"DuckDuckGoSearchAPIWrapper",
|
"DuckDuckGoSearchAPIWrapper",
|
||||||
|
"DriaAPIWrapper",
|
||||||
"GoldenQueryAPIWrapper",
|
"GoldenQueryAPIWrapper",
|
||||||
"GoogleFinanceAPIWrapper",
|
"GoogleFinanceAPIWrapper",
|
||||||
"GoogleJobsAPIWrapper",
|
"GoogleJobsAPIWrapper",
|
||||||
|
Loading…
Reference in New Issue
Block a user