From 4384fa8e49af77895365ba11ddf1d5aa5a0eb17c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?An=C4=B1l=20Berk=20Altuner?= <107621925+anilaltuner@users.noreply.github.com> Date: Mon, 1 Apr 2024 22:04:19 +0300 Subject: [PATCH] community[minor]: Add Dria retriever (#17098) [Dria](https://dria.co/) is a hub of public RAG models for developers to both contribute and utilize a shared embedding lake. This PR adds a retriever that can retrieve documents from Dria. --- .../integrations/retrievers/dria_index.ipynb | 191 ++++++++++++++++++ .../retrievers/__init__.py | 1 + .../retrievers/dria_index.py | 87 ++++++++ .../langchain_community/utilities/__init__.py | 1 + .../utilities/dria_index.py | 95 +++++++++ .../retrievers/test_dria_index.py | 41 ++++ .../unit_tests/retrievers/test_imports.py | 1 + .../unit_tests/utilities/test_imports.py | 1 + 8 files changed, 418 insertions(+) create mode 100644 docs/docs/integrations/retrievers/dria_index.ipynb create mode 100644 libs/community/langchain_community/retrievers/dria_index.py create mode 100644 libs/community/langchain_community/utilities/dria_index.py create mode 100644 libs/community/tests/integration_tests/retrievers/test_dria_index.py diff --git a/docs/docs/integrations/retrievers/dria_index.ipynb b/docs/docs/integrations/retrievers/dria_index.ipynb new file mode 100644 index 0000000000..ced1cb822c --- /dev/null +++ b/docs/docs/integrations/retrievers/dria_index.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "UYyFIEKEkmHb" + }, + "source": [ + "# Dria\n", + "\n", + "Dria is a hub of public RAG models for developers to both contribute and utilize a shared embedding lake. This notebook demonstrates how to use the Dria API for data retrieval tasks." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VNTFUgK9kmHd" + }, + "source": [ + "# Installation\n", + "\n", + "Ensure you have the `dria` package installed. You can install it using pip:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X--1A8EEkmHd" + }, + "outputs": [], + "source": [ + "%pip install --upgrade --quiet dria" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xRbRL0SgkmHe" + }, + "source": [ + "# Configure API Key\n", + "\n", + "Set up your Dria API key for access." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "hGqOByNMkmHe" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"DRIA_API_KEY\"] = \"DRIA_API_KEY\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nDfAEqQtkmHe" + }, + "source": [ + "# Initialize Dria Retriever\n", + "\n", + "Create an instance of `DriaRetriever`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vlyorgCckmHe" + }, + "outputs": [], + "source": [ + "from langchain.retrievers import DriaRetriever\n", + "\n", + "api_key = os.getenv(\"DRIA_API_KEY\")\n", + "retriever = DriaRetriever(api_key=api_key)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "j7WUY5jBOLQd" + }, + "source": [ + "# **Create Knowledge Base**\n", + "\n", + "Create a knowledge on [Dria's Knowledge Hub](https://dria.co/knowledge)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "L5ER81eWOKnt" + }, + "outputs": [], + "source": [ + "contract_id = retriever.create_knowledge_base(\n", + " name=\"France's AI Development\",\n", + " embedding=DriaRetriever.models.jina_embeddings_v2_base_en.value,\n", + " category=\"Artificial Intelligence\",\n", + " description=\"Explore the growth and contributions of France in the field of Artificial Intelligence.\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9VCTzSFpkmHe" + }, + "source": [ + "# Add Data\n", + "\n", + "Load data into your Dria knowledge base." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xeTMafIekmHf" + }, + "outputs": [], + "source": [ + "texts = [\n", + " \"The first text to add to Dria.\",\n", + " \"Another piece of information to store.\",\n", + " \"More data to include in the Dria knowledge base.\",\n", + "]\n", + "\n", + "ids = retriever.add_texts(texts)\n", + "print(\"Data added with IDs:\", ids)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dy1UlvLCkmHf" + }, + "source": [ + "# Retrieve Data\n", + "\n", + "Use the retriever to find relevant documents given a query." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9y3msv9tkmHf" + }, + "outputs": [], + "source": [ + "query = \"Find information about Dria.\"\n", + "result = retriever.get_relevant_documents(query)\n", + "for doc in result:\n", + " print(doc)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.x" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/libs/community/langchain_community/retrievers/__init__.py b/libs/community/langchain_community/retrievers/__init__.py index c126f26928..7785d6ed69 100644 --- a/libs/community/langchain_community/retrievers/__init__.py +++ b/libs/community/langchain_community/retrievers/__init__.py @@ -33,6 +33,7 @@ _module_lookup = { "ChatGPTPluginRetriever": "langchain_community.retrievers.chatgpt_plugin_retriever", "CohereRagRetriever": "langchain_community.retrievers.cohere_rag_retriever", "DocArrayRetriever": "langchain_community.retrievers.docarray", + "DriaRetriever": "langchain_community.retrievers.dria_index", "ElasticSearchBM25Retriever": "langchain_community.retrievers.elastic_search_bm25", "EmbedchainRetriever": "langchain_community.retrievers.embedchain", "GoogleCloudEnterpriseSearchRetriever": "langchain_community.retrievers.google_vertex_ai_search", # noqa: E501 diff --git a/libs/community/langchain_community/retrievers/dria_index.py b/libs/community/langchain_community/retrievers/dria_index.py new file mode 100644 index 0000000000..5da93a804e --- /dev/null +++ b/libs/community/langchain_community/retrievers/dria_index.py @@ -0,0 +1,87 @@ +"""Wrapper around Dria Retriever.""" + +from typing import Any, List, Optional + +from langchain_core.callbacks import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever + +from langchain_community.utilities import DriaAPIWrapper + + +class DriaRetriever(BaseRetriever): + """`Dria` retriever using the DriaAPIWrapper.""" + + api_wrapper: DriaAPIWrapper + + def __init__(self, api_key: str, contract_id: Optional[str] = None, **kwargs: Any): + """ + Initialize the DriaRetriever with a DriaAPIWrapper instance. + + Args: + api_key: The API key for Dria. + contract_id: The contract ID of the knowledge base to interact with. + """ + api_wrapper = DriaAPIWrapper(api_key=api_key, contract_id=contract_id) + super().__init__(api_wrapper=api_wrapper, **kwargs) + + def create_knowledge_base( + self, + name: str, + description: str, + category: str = "Unspecified", + embedding: str = "jina", + ) -> str: + """Create a new knowledge base in Dria. + + Args: + name: The name of the knowledge base. + description: The description of the knowledge base. + category: The category of the knowledge base. + embedding: The embedding model to use for the knowledge base. + + + Returns: + The ID of the created knowledge base. + """ + response = self.api_wrapper.create_knowledge_base( + name, description, category, embedding + ) + return response + + def add_texts( + self, + texts: List, + ) -> None: + """Add texts to the Dria knowledge base. + + Args: + texts: An iterable of texts and metadatas to add to the knowledge base. + + Returns: + List of IDs representing the added texts. + """ + data = [{"text": text["text"], "metadata": text["metadata"]} for text in texts] + self.api_wrapper.insert_data(data) + + def _get_relevant_documents( + self, query: str, *, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + """Retrieve relevant documents from Dria based on a query. + + Args: + query: The query string to search for in the knowledge base. + run_manager: Callback manager for the retriever run. + + Returns: + A list of Documents containing the search results. + """ + results = self.api_wrapper.search(query) + docs = [ + Document( + page_content=result["metadata"], + metadata={"id": result["id"], "score": result["score"]}, + ) + for result in results + ] + return docs diff --git a/libs/community/langchain_community/utilities/__init__.py b/libs/community/langchain_community/utilities/__init__.py index 41b47f3390..64353727c5 100644 --- a/libs/community/langchain_community/utilities/__init__.py +++ b/libs/community/langchain_community/utilities/__init__.py @@ -15,6 +15,7 @@ _module_lookup = { "BibtexparserWrapper": "langchain_community.utilities.bibtex", "BingSearchAPIWrapper": "langchain_community.utilities.bing_search", "BraveSearchWrapper": "langchain_community.utilities.brave_search", + "DriaAPIWrapper": "langchain_community.utilities.dria_index", "DuckDuckGoSearchAPIWrapper": "langchain_community.utilities.duckduckgo_search", "GoldenQueryAPIWrapper": "langchain_community.utilities.golden_query", "GoogleFinanceAPIWrapper": "langchain_community.utilities.google_finance", diff --git a/libs/community/langchain_community/utilities/dria_index.py b/libs/community/langchain_community/utilities/dria_index.py new file mode 100644 index 0000000000..5174751dfc --- /dev/null +++ b/libs/community/langchain_community/utilities/dria_index.py @@ -0,0 +1,95 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +logger = logging.getLogger(__name__) + + +class DriaAPIWrapper: + """Wrapper around Dria API. + + This wrapper facilitates interactions with Dria's vector search + and retrieval services, including creating knowledge bases, inserting data, + and fetching search results. + + Attributes: + api_key: Your API key for accessing Dria. + contract_id: The contract ID of the knowledge base to interact with. + top_n: Number of top results to fetch for a search. + """ + + def __init__( + self, api_key: str, contract_id: Optional[str] = None, top_n: int = 10 + ): + try: + from dria import Dria, Models + except ImportError: + logger.error( + """Dria is not installed. Please install Dria to use this wrapper. + + You can install Dria using the following command: + pip install dria + """ + ) + return + + self.api_key = api_key + self.models = Models + self.contract_id = contract_id + self.top_n = top_n + self.dria_client = Dria(api_key=self.api_key) + if self.contract_id: + self.dria_client.set_contract(self.contract_id) + + def create_knowledge_base( + self, + name: str, + description: str, + category: str, + embedding: str, + ) -> str: + """Create a new knowledge base.""" + contract_id = self.dria_client.create( + name=name, embedding=embedding, category=category, description=description + ) + logger.info(f"Knowledge base created with ID: {contract_id}") + self.contract_id = contract_id + return contract_id + + def insert_data(self, data: List[Dict[str, Any]]) -> str: + """Insert data into the knowledge base.""" + response = self.dria_client.insert_text(data) + logger.info(f"Data inserted: {response}") + return response + + def search(self, query: str) -> List[Dict[str, Any]]: + """Perform a text-based search.""" + results = self.dria_client.search(query, top_n=self.top_n) + logger.info(f"Search results: {results}") + return results + + def query_with_vector(self, vector: List[float]) -> List[Dict[str, Any]]: + """Perform a vector-based query.""" + vector_query_results = self.dria_client.query(vector, top_n=self.top_n) + logger.info(f"Vector query results: {vector_query_results}") + return vector_query_results + + def run(self, query: Union[str, List[float]]) -> Optional[List[Dict[str, Any]]]: + """Method to handle both text-based searches and vector-based queries. + + Args: + query: A string for text-based search or a list of floats for + vector-based query. + + Returns: + The search or query results from Dria. + """ + if isinstance(query, str): + return self.search(query) + elif isinstance(query, list) and all(isinstance(item, float) for item in query): + return self.query_with_vector(query) + else: + logger.error( + """Invalid query type. Please provide a string for text search or a + list of floats for vector query.""" + ) + return None diff --git a/libs/community/tests/integration_tests/retrievers/test_dria_index.py b/libs/community/tests/integration_tests/retrievers/test_dria_index.py new file mode 100644 index 0000000000..9dc683deb4 --- /dev/null +++ b/libs/community/tests/integration_tests/retrievers/test_dria_index.py @@ -0,0 +1,41 @@ +import pytest +from langchain_core.documents import Document + +from langchain_community.retrievers import DriaRetriever + + +# Set a fixture for DriaRetriever +@pytest.fixture +def dria_retriever() -> DriaRetriever: + api_key = "" + contract_id = "B16z9i3rRi0KEeibrzzMU33YTB4WDtos1vdiMBTmKgs" + retriever = DriaRetriever(api_key=api_key, contract_id=contract_id) + return retriever + + +def test_dria_retriever(dria_retriever: DriaRetriever) -> None: + texts = [ + { + "text": "Langchain", + "metadata": { + "source": "source#1", + "document_id": "doc123", + "content": "Langchain", + }, + } + ] + dria_retriever.add_texts(texts) + + # Assuming get_relevant_documents returns a list of Document instances + docs = dria_retriever.get_relevant_documents("Langchain") + + # Perform assertions + assert len(docs) > 0, "Expected at least one document" + doc = docs[0] + assert isinstance(doc, Document), "Expected a Document instance" + assert isinstance(doc.page_content, str), ( + "Expected document content type " "to be string" + ) + assert isinstance( + doc.metadata, dict + ), "Expected document metadata content to be a dictionary" diff --git a/libs/community/tests/unit_tests/retrievers/test_imports.py b/libs/community/tests/unit_tests/retrievers/test_imports.py index d13bfe2881..913d6856e1 100644 --- a/libs/community/tests/unit_tests/retrievers/test_imports.py +++ b/libs/community/tests/unit_tests/retrievers/test_imports.py @@ -10,6 +10,7 @@ EXPECTED_ALL = [ "ChatGPTPluginRetriever", "ChaindeskRetriever", "CohereRagRetriever", + "DriaRetriever", "ElasticSearchBM25Retriever", "EmbedchainRetriever", "GoogleDocumentAIWarehouseRetriever", diff --git a/libs/community/tests/unit_tests/utilities/test_imports.py b/libs/community/tests/unit_tests/utilities/test_imports.py index 3ff5a538b1..e6d4ea9183 100644 --- a/libs/community/tests/unit_tests/utilities/test_imports.py +++ b/libs/community/tests/unit_tests/utilities/test_imports.py @@ -9,6 +9,7 @@ EXPECTED_ALL = [ "BingSearchAPIWrapper", "BraveSearchWrapper", "DuckDuckGoSearchAPIWrapper", + "DriaAPIWrapper", "GoldenQueryAPIWrapper", "GoogleFinanceAPIWrapper", "GoogleJobsAPIWrapper",