From 427551eabf32e0c9fa4428dcfad5fed86f99bbdf Mon Sep 17 00:00:00 2001 From: Saba Sturua <45267439+jupyterjazz@users.noreply.github.com> Date: Sat, 17 Jun 2023 18:09:33 +0200 Subject: [PATCH] DocArray as a Retriever (#6031) ## DocArray as a Retriever [DocArray](https://github.com/docarray/docarray) is an open-source tool for managing your multi-modal data. It offers flexibility to store and search through your data using various document index backends. This PR introduces `DocArrayRetriever` - which works with any available backend and serves as a retriever for Langchain apps. Also, I added 2 notebooks: DocArray Backends - intro to all 5 currently supported backends, how to initialize, index, and use them as a retriever DocArray Usage - showcasing what additional search parameters you can pass to create versatile retrievers Example: ```python from docarray.index import InMemoryExactNNIndex from docarray import BaseDoc, DocList from docarray.typing import NdArray from langchain.embeddings.openai import OpenAIEmbeddings from langchain.retrievers import DocArrayRetriever # define document schema class MyDoc(BaseDoc): description: str description_embedding: NdArray[1536] embeddings = OpenAIEmbeddings() # create documents descriptions = ["description 1", "description 2"] desc_embeddings = embeddings.embed_documents(texts=descriptions) docs = DocList[MyDoc]( [ MyDoc(description=desc, description_embedding=embedding) for desc, embedding in zip(descriptions, desc_embeddings) ] ) # initialize document index with data db = InMemoryExactNNIndex[MyDoc](docs) # create a retriever retriever = DocArrayRetriever( index=db, embeddings=embeddings, search_field="description_embedding", content_field="description", ) # find the relevant document doc = retriever.get_relevant_documents("action movies") print(doc) ``` #### Who can review? @dev2049 --------- Signed-off-by: jupyterjazz --- .../integrations/docarray_retriever.ipynb | 792 ++++++++++++++++++ langchain/retrievers/__init__.py | 2 + langchain/retrievers/docarray.py | 203 +++++ .../retrievers/docarray/__init__.py | 0 .../retrievers/docarray/fixtures.py | 195 +++++ .../retrievers/docarray/test_backends.py | 71 ++ 6 files changed, 1263 insertions(+) create mode 100644 docs/extras/modules/data_connection/retrievers/integrations/docarray_retriever.ipynb create mode 100644 langchain/retrievers/docarray.py create mode 100644 tests/integration_tests/retrievers/docarray/__init__.py create mode 100644 tests/integration_tests/retrievers/docarray/fixtures.py create mode 100644 tests/integration_tests/retrievers/docarray/test_backends.py diff --git a/docs/extras/modules/data_connection/retrievers/integrations/docarray_retriever.ipynb b/docs/extras/modules/data_connection/retrievers/integrations/docarray_retriever.ipynb new file mode 100644 index 00000000..a3462753 --- /dev/null +++ b/docs/extras/modules/data_connection/retrievers/integrations/docarray_retriever.ipynb @@ -0,0 +1,792 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a0eb506a-f52e-4a92-9204-63233c3eb5bd", + "metadata": {}, + "source": [ + "# DocArray Retriever\n", + "\n", + "[DocArray](https://github.com/docarray/docarray) is a versatile, open-source tool for managing your multi-modal data. It lets you shape your data however you want, and offers the flexibility to store and search it using various document index backends. Plus, it gets even better - you can utilize your DocArray document index to create a DocArrayRetriever, and build awesome Langchain apps!\n", + "\n", + "This notebook is split into two sections. The first section offers an introduction to all five supported document index backends. It provides guidance on setting up and indexing each backend, and also instructs you on how to build a DocArrayRetriever for finding relevant documents. In the second section, we'll select one of these backends and illustrate how to use it through a basic example.\n", + "\n", + "\n", + "[Document Index Backends](#Document-Index-Backends)\n", + "1. [InMemoryExactNNIndex](#inmemoryexactnnindex)\n", + "2. [HnswDocumentIndex](#hnswdocumentindex)\n", + "3. [WeaviateDocumentIndex](#weaviatedocumentindex)\n", + "4. [ElasticDocIndex](#elasticdocindex)\n", + "5. [QdrantDocumentIndex](#qdrantdocumentindex)\n", + "\n", + "[Movie Retrieval using HnswDocumentIndex](#Movie-Retrieval-using-HnswDocumentIndex)\n", + "\n", + "- [Normal Retriever](#normal-retriever)\n", + "- [Retriever with Filters](#retriever-with-filters)\n", + "- [Retriever with MMR Search](#Retriever-with-MMR-search)\n" + ] + }, + { + "cell_type": "markdown", + "id": "51db6285-58db-481d-8d24-b13d1888056b", + "metadata": {}, + "source": [ + "# Document Index Backends" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "b72a4512-6318-4572-adf2-12b06b2d2e72", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.retrievers import DocArrayRetriever\n", + "from docarray import BaseDoc\n", + "from docarray.typing import NdArray\n", + "import numpy as np\n", + "from langchain.embeddings import FakeEmbeddings\n", + "import random\n", + "\n", + "embeddings = FakeEmbeddings(size=32)" + ] + }, + { + "cell_type": "markdown", + "id": "bdac41b4-67a1-483f-b3d6-fe662b7bdacd", + "metadata": {}, + "source": [ + "Before you start building the index, it's important to define your document schema. This determines what fields your documents will have and what type of data each field will hold.\n", + "\n", + "For this demonstration, we'll create a somewhat random schema containing 'title' (str), 'title_embedding' (numpy array), 'year' (int), and 'color' (str)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "8a97c56a-63a0-405c-929f-35e1ded79489", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class MyDoc(BaseDoc):\n", + " title: str\n", + " title_embedding: NdArray[32]\n", + " year: int\n", + " color: str" + ] + }, + { + "cell_type": "markdown", + "id": "297bfdb5-6bfe-47ce-90e7-feefc4c160b7", + "metadata": { + "tags": [] + }, + "source": [ + "## InMemoryExactNNIndex\n", + "\n", + "InMemoryExactNNIndex stores all Documentsin memory. It is a great starting point for small datasets, where you may not want to launch a database server.\n", + "\n", + "Learn more here: https://docs.docarray.org/user_guide/storing/index_in_memory/" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8b6e6343-88c2-4206-92fd-5a634d39da09", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from docarray.index import InMemoryExactNNIndex\n", + "\n", + "\n", + "# initialize the index\n", + "db = InMemoryExactNNIndex[MyDoc]()\n", + "# index data\n", + "db.index(\n", + " [\n", + " MyDoc(\n", + " title=f'My document {i}',\n", + " title_embedding=embeddings.embed_query(f'query {i}'),\n", + " year=i,\n", + " color=random.choice(['red', 'green', 'blue']),\n", + " )\n", + " for i in range(100)\n", + " ]\n", + ")\n", + "# optionally, you can create a filter query\n", + "filter_query = {\"year\": {\"$lte\": 90}}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "142060e5-3e0c-4fa2-9f69-8c91f53617f4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='My document 56', metadata={'id': '1f33e58b6468ab722f3786b96b20afe6', 'year': 56, 'color': 'red'})]\n" + ] + } + ], + "source": [ + "# create a retriever\n", + "retriever = DocArrayRetriever(\n", + " index=db, \n", + " embeddings=embeddings, \n", + " search_field='title_embedding', \n", + " content_field='title',\n", + " filters=filter_query,\n", + ")\n", + "\n", + "# find the relevant document\n", + "doc = retriever.get_relevant_documents('some query')\n", + "print(doc)" + ] + }, + { + "cell_type": "markdown", + "id": "a9daf2c4-6568-4a49-ba6e-21687962d2c1", + "metadata": {}, + "source": [ + "## HnswDocumentIndex\n", + "\n", + "HnswDocumentIndex is a lightweight Document Index implementation that runs fully locally and is best suited for small- to medium-sized datasets. It stores vectors on disk in [hnswlib](https://github.com/nmslib/hnswlib), and stores all other data in [SQLite](https://www.sqlite.org/index.html).\n", + "\n", + "Learn more here: https://docs.docarray.org/user_guide/storing/index_hnswlib/" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e0be3c00-470f-4448-92cc-3985f5b05809", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from docarray.index import HnswDocumentIndex\n", + "\n", + "\n", + "# initialize the index\n", + "db = HnswDocumentIndex[MyDoc](work_dir='hnsw_index')\n", + "\n", + "# index data\n", + "db.index(\n", + " [\n", + " MyDoc(\n", + " title=f'My document {i}',\n", + " title_embedding=embeddings.embed_query(f'query {i}'),\n", + " year=i,\n", + " color=random.choice(['red', 'green', 'blue']),\n", + " )\n", + " for i in range(100)\n", + " ]\n", + ")\n", + "# optionally, you can create a filter query\n", + "filter_query = {\"year\": {\"$lte\": 90}}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ea9eb5a0-a8f2-465b-81e2-52fb773466cf", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='My document 28', metadata={'id': 'ca9f3f4268eec7c97a7d6e77f541cb82', 'year': 28, 'color': 'red'})]\n" + ] + } + ], + "source": [ + "# create a retriever\n", + "retriever = DocArrayRetriever(\n", + " index=db, \n", + " embeddings=embeddings, \n", + " search_field='title_embedding', \n", + " content_field='title',\n", + " filters=filter_query,\n", + ")\n", + "\n", + "# find the relevant document\n", + "doc = retriever.get_relevant_documents('some query')\n", + "print(doc)" + ] + }, + { + "cell_type": "markdown", + "id": "7177442e-3fd3-4f3d-ab22-cd8265b35112", + "metadata": {}, + "source": [ + "## WeaviateDocumentIndex\n", + "\n", + "WeaviateDocumentIndex is a document index that is built upon [Weaviate](https://weaviate.io/) vector database.\n", + "\n", + "Learn more here: https://docs.docarray.org/user_guide/storing/index_weaviate/" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8bcf17ba-8dce-4413-ab4e-61d9baee50e7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# There's a small difference with the Weaviate backend compared to the others. \n", + "# Here, you need to 'mark' the field used for vector search with 'is_embedding=True'. \n", + "# So, let's create a new schema for Weaviate that takes care of this requirement.\n", + "\n", + "from pydantic import Field \n", + "\n", + "class WeaviateDoc(BaseDoc):\n", + " title: str\n", + " title_embedding: NdArray[32] = Field(is_embedding=True)\n", + " year: int\n", + " color: str" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4065dced-3e7e-43d3-8518-b31df1e74383", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from docarray.index import WeaviateDocumentIndex\n", + "\n", + "\n", + "# initialize the index\n", + "dbconfig = WeaviateDocumentIndex.DBConfig(\n", + " host=\"http://localhost:8080\"\n", + ")\n", + "db = WeaviateDocumentIndex[WeaviateDoc](db_config=dbconfig)\n", + "\n", + "# index data\n", + "db.index(\n", + " [\n", + " MyDoc(\n", + " title=f'My document {i}',\n", + " title_embedding=embeddings.embed_query(f'query {i}'),\n", + " year=i,\n", + " color=random.choice(['red', 'green', 'blue']),\n", + " )\n", + " for i in range(100)\n", + " ]\n", + ")\n", + "# optionally, you can create a filter query\n", + "filter_query = {\"path\": [\"year\"], \"operator\": \"LessThanEqual\", \"valueInt\": \"90\"}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4e21d124-0f3c-445b-b9fc-dc7c8d6b3d2b", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='My document 17', metadata={'id': '3a5b76e85f0d0a01785dc8f9d965ce40', 'year': 17, 'color': 'red'})]\n" + ] + } + ], + "source": [ + "# create a retriever\n", + "retriever = DocArrayRetriever(\n", + " index=db, \n", + " embeddings=embeddings, \n", + " search_field='title_embedding', \n", + " content_field='title',\n", + " filters=filter_query,\n", + ")\n", + "\n", + "# find the relevant document\n", + "doc = retriever.get_relevant_documents('some query')\n", + "print(doc)" + ] + }, + { + "cell_type": "markdown", + "id": "6ee8f920-9297-4b0a-a353-053a86947d10", + "metadata": {}, + "source": [ + "## ElasticDocIndex\n", + "\n", + "ElasticDocIndex is a document index that is built upon [ElasticSearch](https://github.com/elastic/elasticsearch)\n", + "\n", + "Learn more here: https://docs.docarray.org/user_guide/storing/index_elastic/" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "92980ead-e4dc-4eef-8618-1c0583f76d7a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from docarray.index import ElasticDocIndex\n", + "\n", + "\n", + "# initialize the index\n", + "db = ElasticDocIndex[MyDoc](\n", + " hosts=\"http://localhost:9200\", \n", + " index_name=\"docarray_retriever\"\n", + ")\n", + "\n", + "# index data\n", + "db.index(\n", + " [\n", + " MyDoc(\n", + " title=f'My document {i}',\n", + " title_embedding=embeddings.embed_query(f'query {i}'),\n", + " year=i,\n", + " color=random.choice(['red', 'green', 'blue']),\n", + " )\n", + " for i in range(100)\n", + " ]\n", + ")\n", + "# optionally, you can create a filter query\n", + "filter_query = {\"range\": {\"year\": {\"lte\": 90}}}" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "8a8e97f3-c3a1-4c7f-b776-363c5e7dd69d", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='My document 46', metadata={'id': 'edbc721bac1c2ad323414ad1301528a4', 'year': 46, 'color': 'green'})]\n" + ] + } + ], + "source": [ + "# create a retriever\n", + "retriever = DocArrayRetriever(\n", + " index=db, \n", + " embeddings=embeddings, \n", + " search_field='title_embedding', \n", + " content_field='title',\n", + " filters=filter_query,\n", + ")\n", + "\n", + "# find the relevant document\n", + "doc = retriever.get_relevant_documents('some query')\n", + "print(doc)" + ] + }, + { + "cell_type": "markdown", + "id": "281432f8-87a5-4f22-a582-9d5dac33d158", + "metadata": {}, + "source": [ + "## QdrantDocumentIndex\n", + "\n", + "QdrantDocumentIndex is a document index that is build upon [Qdrant](https://qdrant.tech/) vector database\n", + "\n", + "Learn more here: https://docs.docarray.org/user_guide/storing/index_qdrant/" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "b6fd91d0-630a-4974-bdf1-6dfa4d1a68f5", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Payload indexes have no effect in the local Qdrant. Please use server Qdrant if you need payload indexes.\n" + ] + } + ], + "source": [ + "from docarray.index import QdrantDocumentIndex\n", + "from qdrant_client.http import models as rest\n", + "\n", + "\n", + "# initialize the index\n", + "qdrant_config = QdrantDocumentIndex.DBConfig(path=\":memory:\")\n", + "db = QdrantDocumentIndex[MyDoc](qdrant_config)\n", + "\n", + "# index data\n", + "db.index(\n", + " [\n", + " MyDoc(\n", + " title=f'My document {i}',\n", + " title_embedding=embeddings.embed_query(f'query {i}'),\n", + " year=i,\n", + " color=random.choice(['red', 'green', 'blue']),\n", + " )\n", + " for i in range(100)\n", + " ]\n", + ")\n", + "# optionally, you can create a filter query\n", + "filter_query = rest.Filter(\n", + " must=[\n", + " rest.FieldCondition(\n", + " key=\"year\",\n", + " range=rest.Range(\n", + " gte=10,\n", + " lt=90,\n", + " ),\n", + " )\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a6dd6460-7175-48ee-8cfb-9a0abf35ec13", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='My document 80', metadata={'id': '97465f98d0810f1f330e4ecc29b13d20', 'year': 80, 'color': 'blue'})]\n" + ] + } + ], + "source": [ + "# create a retriever\n", + "retriever = DocArrayRetriever(\n", + " index=db, \n", + " embeddings=embeddings, \n", + " search_field='title_embedding', \n", + " content_field='title',\n", + " filters=filter_query,\n", + ")\n", + "\n", + "# find the relevant document\n", + "doc = retriever.get_relevant_documents('some query')\n", + "print(doc)" + ] + }, + { + "cell_type": "markdown", + "id": "3afb65b0-c620-411a-855f-1aa81481bdbb", + "metadata": {}, + "source": [ + "# Movie Retrieval using HnswDocumentIndex" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "07b71d96-381e-4965-b525-af9f7cc5f86c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "movies = [\n", + " {\n", + " \"title\": \"Inception\",\n", + " \"description\": \"A thief who steals corporate secrets through the use of dream-sharing technology is given the task of planting an idea into the mind of a CEO.\",\n", + " \"director\": \"Christopher Nolan\",\n", + " \"rating\": 8.8,\n", + " },\n", + " {\n", + " \"title\": \"The Dark Knight\",\n", + " \"description\": \"When the menace known as the Joker wreaks havoc and chaos on the people of Gotham, Batman must accept one of the greatest psychological and physical tests of his ability to fight injustice.\",\n", + " \"director\": \"Christopher Nolan\",\n", + " \"rating\": 9.0,\n", + " },\n", + " {\n", + " \"title\": \"Interstellar\",\n", + " \"description\": \"Interstellar explores the boundaries of human exploration as a group of astronauts venture through a wormhole in space. In their quest to ensure the survival of humanity, they confront the vastness of space-time and grapple with love and sacrifice.\",\n", + " \"director\": \"Christopher Nolan\",\n", + " \"rating\": 8.6,\n", + " },\n", + " {\n", + " \"title\": \"Pulp Fiction\",\n", + " \"description\": \"The lives of two mob hitmen, a boxer, a gangster's wife, and a pair of diner bandits intertwine in four tales of violence and redemption.\",\n", + " \"director\": \"Quentin Tarantino\",\n", + " \"rating\": 8.9,\n", + " },\n", + " {\n", + " \"title\": \"Reservoir Dogs\",\n", + " \"description\": \"When a simple jewelry heist goes horribly wrong, the surviving criminals begin to suspect that one of them is a police informant.\",\n", + " \"director\": \"Quentin Tarantino\",\n", + " \"rating\": 8.3,\n", + " },\n", + " {\n", + " \"title\": \"The Godfather\",\n", + " \"description\": \"An aging patriarch of an organized crime dynasty transfers control of his empire to his reluctant son.\",\n", + " \"director\": \"Francis Ford Coppola\",\n", + " \"rating\": 9.2,\n", + " },\n", + "]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "1860edfb-936d-4cd8-a167-e8f9c4617709", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdin", + "output_type": "stream", + "text": [ + "OpenAI API Key: ········\n" + ] + } + ], + "source": [ + "import getpass\n", + "import os \n", + "\n", + "os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API Key:')" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "0538541d-26ea-4323-96b9-47768c75dcd8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from docarray import BaseDoc, DocList\n", + "from docarray.typing import NdArray\n", + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "\n", + "# define schema for your movie documents\n", + "class MyDoc(BaseDoc):\n", + " title: str\n", + " description: str\n", + " description_embedding: NdArray[1536]\n", + " rating: float\n", + " director: str\n", + " \n", + "\n", + "embeddings = OpenAIEmbeddings()\n", + "\n", + "\n", + "# get \"description\" embeddings, and create documents\n", + "docs = DocList[MyDoc](\n", + " [\n", + " MyDoc(\n", + " description_embedding=embeddings.embed_query(movie[\"description\"]), **movie\n", + " )\n", + " for movie in movies\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "f5ae1b41-0372-47ea-89bb-c6ad968a2919", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from docarray.index import HnswDocumentIndex\n", + "\n", + "# initialize the index\n", + "db = HnswDocumentIndex[MyDoc](work_dir='movie_search')\n", + "\n", + "# add data\n", + "db.index(docs)" + ] + }, + { + "cell_type": "markdown", + "id": "9ca3f91b-ed11-490b-b60a-0d1d9b50a5b2", + "metadata": { + "tags": [] + }, + "source": [ + "## Normal Retriever" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "efdb5cbf-218e-48a6-af0f-25b7a510e343", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='A thief who steals corporate secrets through the use of dream-sharing technology is given the task of planting an idea into the mind of a CEO.', metadata={'id': 'f1649d5b6776db04fec9a116bbb6bbe5', 'title': 'Inception', 'rating': 8.8, 'director': 'Christopher Nolan'})]\n" + ] + } + ], + "source": [ + "from langchain.retrievers import DocArrayRetriever\n", + "\n", + "# create a retriever\n", + "retriever = DocArrayRetriever(\n", + " index=db, \n", + " embeddings=embeddings, \n", + " search_field='description_embedding', \n", + " content_field='description'\n", + ")\n", + "\n", + "# find the relevant document\n", + "doc = retriever.get_relevant_documents('movie about dreams')\n", + "print(doc)" + ] + }, + { + "cell_type": "markdown", + "id": "3defa711-51df-4b48-b02a-306706cfacd0", + "metadata": {}, + "source": [ + "## Retriever with Filters" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "205a9fe8-13bb-4280-9485-f6973bbc6943", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content='Interstellar explores the boundaries of human exploration as a group of astronauts venture through a wormhole in space. In their quest to ensure the survival of humanity, they confront the vastness of space-time and grapple with love and sacrifice.', metadata={'id': 'ab704cc7ae8573dc617f9a5e25df022a', 'title': 'Interstellar', 'rating': 8.6, 'director': 'Christopher Nolan'}), Document(page_content='A thief who steals corporate secrets through the use of dream-sharing technology is given the task of planting an idea into the mind of a CEO.', metadata={'id': 'f1649d5b6776db04fec9a116bbb6bbe5', 'title': 'Inception', 'rating': 8.8, 'director': 'Christopher Nolan'})]\n" + ] + } + ], + "source": [ + "from langchain.retrievers import DocArrayRetriever\n", + "\n", + "# create a retriever\n", + "retriever = DocArrayRetriever(\n", + " index=db, \n", + " embeddings=embeddings, \n", + " search_field='description_embedding', \n", + " content_field='description',\n", + " filters={'director': {'$eq': 'Christopher Nolan'}},\n", + " top_k=2,\n", + ")\n", + "\n", + "# find relevant documents\n", + "docs = retriever.get_relevant_documents('space travel')\n", + "print(docs)" + ] + }, + { + "cell_type": "markdown", + "id": "fa10afa6-1554-4c2b-8afc-cff44e32d2f8", + "metadata": {}, + "source": [ + "## Retriever with MMR search" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "b7305599-b166-419c-8e1e-8ff7c247cce6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Document(page_content=\"The lives of two mob hitmen, a boxer, a gangster's wife, and a pair of diner bandits intertwine in four tales of violence and redemption.\", metadata={'id': 'e6aa313bbde514e23fbc80ab34511afd', 'title': 'Pulp Fiction', 'rating': 8.9, 'director': 'Quentin Tarantino'}), Document(page_content='A thief who steals corporate secrets through the use of dream-sharing technology is given the task of planting an idea into the mind of a CEO.', metadata={'id': 'f1649d5b6776db04fec9a116bbb6bbe5', 'title': 'Inception', 'rating': 8.8, 'director': 'Christopher Nolan'}), Document(page_content='When the menace known as the Joker wreaks havoc and chaos on the people of Gotham, Batman must accept one of the greatest psychological and physical tests of his ability to fight injustice.', metadata={'id': '91dec17d4272041b669fd113333a65f7', 'title': 'The Dark Knight', 'rating': 9.0, 'director': 'Christopher Nolan'})]\n" + ] + } + ], + "source": [ + "from langchain.retrievers import DocArrayRetriever\n", + "\n", + "# create a retriever\n", + "retriever = DocArrayRetriever(\n", + " index=db, \n", + " embeddings=embeddings, \n", + " search_field='description_embedding', \n", + " content_field='description',\n", + " filters={'rating': {'$gte': 8.7}},\n", + " search_type='mmr',\n", + " top_k=3,\n", + ")\n", + "\n", + "# find relevant documents\n", + "docs = retriever.get_relevant_documents('action movies')\n", + "print(docs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4865cf25-48af-4d60-9337-9528b9b30f28", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.17" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index 19b67f92..1b2847ec 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -4,6 +4,7 @@ from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetr from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.databerry import DataberryRetriever +from langchain.retrievers.docarray import DocArrayRetriever from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever from langchain.retrievers.knn import KNNRetriever from langchain.retrievers.merger_retriever import MergerRetriever @@ -44,4 +45,5 @@ __all__ = [ "WeaviateHybridSearchRetriever", "WikipediaRetriever", "ZepRetriever", + "DocArrayRetriever", ] diff --git a/langchain/retrievers/docarray.py b/langchain/retrievers/docarray.py new file mode 100644 index 00000000..57b5eaaf --- /dev/null +++ b/langchain/retrievers/docarray.py @@ -0,0 +1,203 @@ +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +import numpy as np +from pydantic import BaseModel + +from langchain.embeddings.base import Embeddings +from langchain.schema import BaseRetriever, Document +from langchain.vectorstores.utils import maximal_marginal_relevance + + +class SearchType(str, Enum): + similarity = "similarity" + mmr = "mmr" + + +class DocArrayRetriever(BaseRetriever, BaseModel): + """ + Retriever class for DocArray Document Indices. + + Currently, supports 5 backends: + InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex, + ElasticDocIndex, and WeaviateDocumentIndex. + + Attributes: + index: One of the above-mentioned index instances + embeddings: Embedding model to represent text as vectors + search_field: Field to consider for searching in the documents. + Should be an embedding/vector/tensor. + content_field: Field that represents the main content in your document schema. + Will be used as a `page_content`. Everything else will go into `metadata`. + search_type: Type of search to perform (similarity / mmr) + filters: Filters applied for document retrieval. + top_k: Number of documents to return + """ + + index: Any + embeddings: Embeddings + search_field: str + content_field: str + search_type: SearchType = SearchType.similarity + top_k: int = 1 + filters: Optional[Any] = None + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def get_relevant_documents(self, query: str) -> List[Document]: + """Get documents relevant for a query. + + Args: + query: string to find relevant documents for + + Returns: + List of relevant documents + """ + query_emb = np.array(self.embeddings.embed_query(query)) + + if self.search_type == SearchType.similarity: + results = self._similarity_search(query_emb) + elif self.search_type == SearchType.mmr: + results = self._mmr_search(query_emb) + else: + raise ValueError( + f"Search type {self.search_type} does not exist. " + f"Choose either 'similarity' or 'mmr'." + ) + + return results + + def _search( + self, query_emb: np.ndarray, top_k: int + ) -> List[Union[Dict[str, Any], Any]]: + """ + Perform a search using the query embedding and return top_k documents. + + Args: + query_emb: Query represented as an embedding + top_k: Number of documents to return + + Returns: + A list of top_k documents matching the query + """ + + from docarray.index import ElasticDocIndex, WeaviateDocumentIndex + + filter_args = {} + search_field = self.search_field + if isinstance(self.index, WeaviateDocumentIndex): + filter_args["where_filter"] = self.filters + search_field = "" + elif isinstance(self.index, ElasticDocIndex): + filter_args["query"] = self.filters + else: + filter_args["filter_query"] = self.filters + + if self.filters: + query = ( + self.index.build_query() # get empty query object + .find( + query=query_emb, search_field=search_field + ) # add vector similarity search + .filter(**filter_args) # add filter search + .build(limit=top_k) # build the query + ) + # execute the combined query and return the results + docs = self.index.execute_query(query) + if hasattr(docs, "documents"): + docs = docs.documents + docs = docs[:top_k] + else: + docs = self.index.find( + query=query_emb, search_field=search_field, limit=top_k + ).documents + return docs + + def _similarity_search(self, query_emb: np.ndarray) -> List[Document]: + """ + Perform a similarity search. + + Args: + query_emb: Query represented as an embedding + + Returns: + A list of documents most similar to the query + """ + docs = self._search(query_emb=query_emb, top_k=self.top_k) + results = [self._docarray_to_langchain_doc(doc) for doc in docs] + return results + + def _mmr_search(self, query_emb: np.ndarray) -> List[Document]: + """ + Perform a maximal marginal relevance (mmr) search. + + Args: + query_emb: Query represented as an embedding + + Returns: + A list of diverse documents related to the query + """ + docs = self._search(query_emb=query_emb, top_k=20) + + mmr_selected = maximal_marginal_relevance( + query_emb, + [ + doc[self.search_field] + if isinstance(doc, dict) + else getattr(doc, self.search_field) + for doc in docs + ], + k=self.top_k, + ) + results = [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected] + return results + + def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document: + """ + Convert a DocArray document (which also might be a dict) + to a langchain document format. + + DocArray document can contain arbitrary fields, so the mapping is done + in the following way: + + page_content <-> content_field + metadata <-> all other fields excluding + tensors and embeddings (so float, int, string) + + Args: + doc: DocArray document + + Returns: + Document in langchain format + + Raises: + ValueError: If the document doesn't contain the content field + """ + + fields = doc.keys() if isinstance(doc, dict) else doc.__fields__ + + if self.content_field not in fields: + raise ValueError( + f"Document does not contain the content field - {self.content_field}." + ) + lc_doc = Document( + page_content=doc[self.content_field] + if isinstance(doc, dict) + else getattr(doc, self.content_field) + ) + + for name in fields: + value = doc[name] if isinstance(doc, dict) else getattr(doc, name) + if ( + isinstance(value, (str, int, float, bool)) + and name != self.content_field + ): + lc_doc.metadata[name] = value + + return lc_doc + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError diff --git a/tests/integration_tests/retrievers/docarray/__init__.py b/tests/integration_tests/retrievers/docarray/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/retrievers/docarray/fixtures.py b/tests/integration_tests/retrievers/docarray/fixtures.py new file mode 100644 index 00000000..28e0b61c --- /dev/null +++ b/tests/integration_tests/retrievers/docarray/fixtures.py @@ -0,0 +1,195 @@ +from pathlib import Path +from typing import Any, Dict, Generator, Tuple + +import numpy as np +import pytest +from docarray import BaseDoc +from docarray.index import ( + ElasticDocIndex, + HnswDocumentIndex, + InMemoryExactNNIndex, + QdrantDocumentIndex, + WeaviateDocumentIndex, +) +from docarray.typing import NdArray +from pydantic import Field +from qdrant_client.http import models as rest + +from langchain.embeddings import FakeEmbeddings + + +class MyDoc(BaseDoc): + title: str + title_embedding: NdArray[32] # type: ignore + other_emb: NdArray[32] # type: ignore + year: int + + +class WeaviateDoc(BaseDoc): + # When initializing the Weaviate index, denote the field + # you want to search on with `is_embedding=True` + title: str + title_embedding: NdArray[32] = Field(is_embedding=True) # type: ignore + other_emb: NdArray[32] # type: ignore + year: int + + +@pytest.fixture +def init_weaviate() -> ( + Generator[ + Tuple[WeaviateDocumentIndex[WeaviateDoc], Dict[str, Any], FakeEmbeddings], + None, + None, + ] +): + """ + cd tests/integration_tests/vectorstores/docker-compose + docker compose -f weaviate.yml up + """ + embeddings = FakeEmbeddings(size=32) + + # initialize WeaviateDocumentIndex + dbconfig = WeaviateDocumentIndex.DBConfig(host="http://localhost:8080") + weaviate_db = WeaviateDocumentIndex[WeaviateDoc]( + db_config=dbconfig, index_name="docarray_retriever" + ) + + # index data + weaviate_db.index( + [ + WeaviateDoc( + title=f"My document {i}", + title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")), + other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")), + year=i, + ) + for i in range(100) + ] + ) + # build a filter query + filter_query = {"path": ["year"], "operator": "LessThanEqual", "valueInt": "90"} + + yield weaviate_db, filter_query, embeddings + + weaviate_db._client.schema.delete_all() + + +@pytest.fixture +def init_elastic() -> ( + Generator[Tuple[ElasticDocIndex[MyDoc], Dict[str, Any], FakeEmbeddings], None, None] +): + """ + cd tests/integration_tests/vectorstores/docker-compose + docker-compose -f elasticsearch.yml up + """ + embeddings = FakeEmbeddings(size=32) + + # initialize ElasticDocIndex + elastic_db = ElasticDocIndex[MyDoc]( + hosts="http://localhost:9200", index_name="docarray_retriever" + ) + # index data + elastic_db.index( + [ + MyDoc( + title=f"My document {i}", + title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")), + other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")), + year=i, + ) + for i in range(100) + ] + ) + # build a filter query + filter_query = {"range": {"year": {"lte": 90}}} + + yield elastic_db, filter_query, embeddings + + elastic_db._client.indices.delete(index="docarray_retriever") + + +@pytest.fixture +def init_qdrant() -> Tuple[QdrantDocumentIndex[MyDoc], rest.Filter, FakeEmbeddings]: + embeddings = FakeEmbeddings(size=32) + + # initialize QdrantDocumentIndex + qdrant_config = QdrantDocumentIndex.DBConfig(path=":memory:") + qdrant_db = QdrantDocumentIndex[MyDoc](qdrant_config) + # index data + qdrant_db.index( + [ + MyDoc( + title=f"My document {i}", + title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")), + other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")), + year=i, + ) + for i in range(100) + ] + ) + # build a filter query + filter_query = rest.Filter( + must=[ + rest.FieldCondition( + key="year", + range=rest.Range( + gte=10, + lt=90, + ), + ) + ] + ) + + return qdrant_db, filter_query, embeddings + + +@pytest.fixture +def init_in_memory() -> ( + Tuple[InMemoryExactNNIndex[MyDoc], Dict[str, Any], FakeEmbeddings] +): + embeddings = FakeEmbeddings(size=32) + + # initialize InMemoryExactNNIndex + in_memory_db = InMemoryExactNNIndex[MyDoc]() + # index data + in_memory_db.index( + [ + MyDoc( + title=f"My document {i}", + title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")), + other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")), + year=i, + ) + for i in range(100) + ] + ) + # build a filter query + filter_query = {"year": {"$lte": 90}} + + return in_memory_db, filter_query, embeddings + + +@pytest.fixture +def init_hnsw( + tmp_path: Path, +) -> Tuple[HnswDocumentIndex[MyDoc], Dict[str, Any], FakeEmbeddings]: + embeddings = FakeEmbeddings(size=32) + + # initialize InMemoryExactNNIndex + hnsw_db = HnswDocumentIndex[MyDoc](work_dir=tmp_path) + # index data + hnsw_db.index( + [ + MyDoc( + title=f"My document {i}", + title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")), + other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")), + year=i, + ) + for i in range(100) + ] + ) + # build a filter query + filter_query = {"year": {"$lte": 90}} + + return hnsw_db, filter_query, embeddings diff --git a/tests/integration_tests/retrievers/docarray/test_backends.py b/tests/integration_tests/retrievers/docarray/test_backends.py new file mode 100644 index 00000000..b1f34b25 --- /dev/null +++ b/tests/integration_tests/retrievers/docarray/test_backends.py @@ -0,0 +1,71 @@ +from typing import Any + +import pytest +from vcr.request import Request + +from langchain.retrievers import DocArrayRetriever +from tests.integration_tests.retrievers.docarray.fixtures import ( # noqa: F401 + init_elastic, + init_hnsw, + init_in_memory, + init_qdrant, + init_weaviate, +) + + +@pytest.mark.parametrize( + "backend", + ["init_hnsw", "init_in_memory", "init_qdrant", "init_elastic", "init_weaviate"], +) +def test_backends(request: Request, backend: Any) -> None: + index, filter_query, embeddings = request.getfixturevalue(backend) + + # create a retriever + retriever = DocArrayRetriever( + index=index, + embeddings=embeddings, + search_field="title_embedding", + content_field="title", + ) + + docs = retriever.get_relevant_documents("my docs") + + assert len(docs) == 1 + assert "My document" in docs[0].page_content + assert "id" in docs[0].metadata and "year" in docs[0].metadata + assert "other_emb" not in docs[0].metadata + + # create a retriever with filters + retriever = DocArrayRetriever( + index=index, + embeddings=embeddings, + search_field="title_embedding", + content_field="title", + filters=filter_query, + ) + + docs = retriever.get_relevant_documents("my docs") + + assert len(docs) == 1 + assert "My document" in docs[0].page_content + assert "id" in docs[0].metadata and "year" in docs[0].metadata + assert "other_emb" not in docs[0].metadata + assert docs[0].metadata["year"] <= 90 + + # create a retriever with MMR search + retriever = DocArrayRetriever( + index=index, + embeddings=embeddings, + search_field="title_embedding", + search_type="mmr", + content_field="title", + filters=filter_query, + ) + + docs = retriever.get_relevant_documents("my docs") + + assert len(docs) == 1 + assert "My document" in docs[0].page_content + assert "id" in docs[0].metadata and "year" in docs[0].metadata + assert "other_emb" not in docs[0].metadata + assert docs[0].metadata["year"] <= 90