DocArray as a Retriever (#6031)

## DocArray as a Retriever

[DocArray](https://github.com/docarray/docarray) is an open-source tool
for managing your multi-modal data. It offers flexibility to store and
search through your data using various document index backends. This PR
introduces `DocArrayRetriever`, which works with any available backend
and serves as a retriever for LangChain apps.

Also, I added 2 notebooks:
- DocArray Backends: an intro to all 5 currently supported backends, showing how to
  initialize, index, and use them as a retriever
- DocArray Usage: showcases the additional search parameters you can
  pass to create versatile retrievers

Example:
```python
from docarray.index import InMemoryExactNNIndex
from docarray import BaseDoc, DocList
from docarray.typing import NdArray
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.retrievers import DocArrayRetriever


# define document schema
class MyDoc(BaseDoc):
    description: str
    description_embedding: NdArray[1536]


embeddings = OpenAIEmbeddings()
# create documents
descriptions = ["description 1", "description 2"]
desc_embeddings = embeddings.embed_documents(texts=descriptions)
docs = DocList[MyDoc](
    [
        MyDoc(description=desc, description_embedding=embedding)
        for desc, embedding in zip(descriptions, desc_embeddings)
    ]
)

# initialize document index with data
db = InMemoryExactNNIndex[MyDoc](docs)

# create a retriever
retriever = DocArrayRetriever(
    index=db,
    embeddings=embeddings,
    search_field="description_embedding",
    content_field="description",
)

# find the relevant document
doc = retriever.get_relevant_documents("action movies")
print(doc)
```
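
The retriever turns each DocArray document into a LangChain `Document`: the `content_field` becomes `page_content`, and the remaining scalar fields (str, int, float, bool) go into `metadata`, so embeddings are left out. Below is a minimal, dependency-free sketch of that mapping for dict-shaped documents. `to_page_and_metadata` and the sample values are hypothetical and only illustrate the idea; they are not part of this PR.

```python
from typing import Any, Dict, Tuple


def to_page_and_metadata(
    doc: Dict[str, Any], content_field: str
) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: split a dict-shaped document into content and metadata."""
    if content_field not in doc:
        raise ValueError(
            f"Document does not contain the content field - {content_field}."
        )
    metadata = {
        name: value
        for name, value in doc.items()
        # keep scalar fields only; lists/arrays such as embeddings are dropped
        if name != content_field and isinstance(value, (str, int, float, bool))
    }
    return doc[content_field], metadata


# illustrative document, not real data
page_content, metadata = to_page_and_metadata(
    {
        "id": "abc",
        "description": "description 1",
        "description_embedding": [0.1, 0.2],
        "rating": 8.8,
    },
    content_field="description",
)
print(page_content)  # description 1
print(metadata)  # {'id': 'abc', 'rating': 8.8}
```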

#### Who can review?

@dev2049

---------

Signed-off-by: jupyterjazz <saba.sturua@jina.ai>
This commit is contained in:
Saba Sturua 2023-06-17 18:09:33 +02:00 committed by GitHub
parent 7bb437146d
commit 427551eabf
6 changed files with 1263 additions and 0 deletions

View File

@ -0,0 +1,792 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a0eb506a-f52e-4a92-9204-63233c3eb5bd",
"metadata": {},
"source": [
"# DocArray Retriever\n",
"\n",
"[DocArray](https://github.com/docarray/docarray) is a versatile, open-source tool for managing your multi-modal data. It lets you shape your data however you want, and offers the flexibility to store and search it using various document index backends. Plus, it gets even better - you can utilize your DocArray document index to create a DocArrayRetriever, and build awesome Langchain apps!\n",
"\n",
"This notebook is split into two sections. The first section offers an introduction to all five supported document index backends. It provides guidance on setting up and indexing each backend, and also instructs you on how to build a DocArrayRetriever for finding relevant documents. In the second section, we'll select one of these backends and illustrate how to use it through a basic example.\n",
"\n",
"\n",
"[Document Index Backends](#Document-Index-Backends)\n",
"1. [InMemoryExactNNIndex](#inmemoryexactnnindex)\n",
"2. [HnswDocumentIndex](#hnswdocumentindex)\n",
"3. [WeaviateDocumentIndex](#weaviatedocumentindex)\n",
"4. [ElasticDocIndex](#elasticdocindex)\n",
"5. [QdrantDocumentIndex](#qdrantdocumentindex)\n",
"\n",
"[Movie Retrieval using HnswDocumentIndex](#Movie-Retrieval-using-HnswDocumentIndex)\n",
"\n",
"- [Normal Retriever](#normal-retriever)\n",
"- [Retriever with Filters](#retriever-with-filters)\n",
"- [Retriever with MMR Search](#Retriever-with-MMR-search)\n"
]
},
{
"cell_type": "markdown",
"id": "51db6285-58db-481d-8d24-b13d1888056b",
"metadata": {},
"source": [
"# Document Index Backends"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b72a4512-6318-4572-adf2-12b06b2d2e72",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.retrievers import DocArrayRetriever\n",
"from docarray import BaseDoc\n",
"from docarray.typing import NdArray\n",
"import numpy as np\n",
"from langchain.embeddings import FakeEmbeddings\n",
"import random\n",
"\n",
"embeddings = FakeEmbeddings(size=32)"
]
},
{
"cell_type": "markdown",
"id": "bdac41b4-67a1-483f-b3d6-fe662b7bdacd",
"metadata": {},
"source": [
"Before you start building the index, it's important to define your document schema. This determines what fields your documents will have and what type of data each field will hold.\n",
"\n",
"For this demonstration, we'll create a somewhat random schema containing 'title' (str), 'title_embedding' (numpy array), 'year' (int), and 'color' (str)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8a97c56a-63a0-405c-929f-35e1ded79489",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class MyDoc(BaseDoc):\n",
" title: str\n",
" title_embedding: NdArray[32]\n",
" year: int\n",
" color: str"
]
},
{
"cell_type": "markdown",
"id": "297bfdb5-6bfe-47ce-90e7-feefc4c160b7",
"metadata": {
"tags": []
},
"source": [
"## InMemoryExactNNIndex\n",
"\n",
"InMemoryExactNNIndex stores all Documents in memory. It is a great starting point for small datasets, where you may not want to launch a database server.\n",
"\n",
"Learn more here: https://docs.docarray.org/user_guide/storing/index_in_memory/"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8b6e6343-88c2-4206-92fd-5a634d39da09",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from docarray.index import InMemoryExactNNIndex\n",
"\n",
"\n",
"# initialize the index\n",
"db = InMemoryExactNNIndex[MyDoc]()\n",
"# index data\n",
"db.index(\n",
" [\n",
" MyDoc(\n",
" title=f'My document {i}',\n",
" title_embedding=embeddings.embed_query(f'query {i}'),\n",
" year=i,\n",
" color=random.choice(['red', 'green', 'blue']),\n",
" )\n",
" for i in range(100)\n",
" ]\n",
")\n",
"# optionally, you can create a filter query\n",
"filter_query = {\"year\": {\"$lte\": 90}}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "142060e5-3e0c-4fa2-9f69-8c91f53617f4",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Document(page_content='My document 56', metadata={'id': '1f33e58b6468ab722f3786b96b20afe6', 'year': 56, 'color': 'red'})]\n"
]
}
],
"source": [
"# create a retriever\n",
"retriever = DocArrayRetriever(\n",
" index=db, \n",
" embeddings=embeddings, \n",
" search_field='title_embedding', \n",
" content_field='title',\n",
" filters=filter_query,\n",
")\n",
"\n",
"# find the relevant document\n",
"doc = retriever.get_relevant_documents('some query')\n",
"print(doc)"
]
},
{
"cell_type": "markdown",
"id": "a9daf2c4-6568-4a49-ba6e-21687962d2c1",
"metadata": {},
"source": [
"## HnswDocumentIndex\n",
"\n",
"HnswDocumentIndex is a lightweight Document Index implementation that runs fully locally and is best suited for small- to medium-sized datasets. It stores vectors on disk in [hnswlib](https://github.com/nmslib/hnswlib), and stores all other data in [SQLite](https://www.sqlite.org/index.html).\n",
"\n",
"Learn more here: https://docs.docarray.org/user_guide/storing/index_hnswlib/"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e0be3c00-470f-4448-92cc-3985f5b05809",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from docarray.index import HnswDocumentIndex\n",
"\n",
"\n",
"# initialize the index\n",
"db = HnswDocumentIndex[MyDoc](work_dir='hnsw_index')\n",
"\n",
"# index data\n",
"db.index(\n",
" [\n",
" MyDoc(\n",
" title=f'My document {i}',\n",
" title_embedding=embeddings.embed_query(f'query {i}'),\n",
" year=i,\n",
" color=random.choice(['red', 'green', 'blue']),\n",
" )\n",
" for i in range(100)\n",
" ]\n",
")\n",
"# optionally, you can create a filter query\n",
"filter_query = {\"year\": {\"$lte\": 90}}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ea9eb5a0-a8f2-465b-81e2-52fb773466cf",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Document(page_content='My document 28', metadata={'id': 'ca9f3f4268eec7c97a7d6e77f541cb82', 'year': 28, 'color': 'red'})]\n"
]
}
],
"source": [
"# create a retriever\n",
"retriever = DocArrayRetriever(\n",
" index=db, \n",
" embeddings=embeddings, \n",
" search_field='title_embedding', \n",
" content_field='title',\n",
" filters=filter_query,\n",
")\n",
"\n",
"# find the relevant document\n",
"doc = retriever.get_relevant_documents('some query')\n",
"print(doc)"
]
},
{
"cell_type": "markdown",
"id": "7177442e-3fd3-4f3d-ab22-cd8265b35112",
"metadata": {},
"source": [
"## WeaviateDocumentIndex\n",
"\n",
"WeaviateDocumentIndex is a document index that is built upon [Weaviate](https://weaviate.io/) vector database.\n",
"\n",
"Learn more here: https://docs.docarray.org/user_guide/storing/index_weaviate/"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "8bcf17ba-8dce-4413-ab4e-61d9baee50e7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# There's a small difference with the Weaviate backend compared to the others. \n",
"# Here, you need to 'mark' the field used for vector search with 'is_embedding=True'. \n",
"# So, let's create a new schema for Weaviate that takes care of this requirement.\n",
"\n",
"from pydantic import Field \n",
"\n",
"class WeaviateDoc(BaseDoc):\n",
" title: str\n",
" title_embedding: NdArray[32] = Field(is_embedding=True)\n",
" year: int\n",
" color: str"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4065dced-3e7e-43d3-8518-b31df1e74383",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from docarray.index import WeaviateDocumentIndex\n",
"\n",
"\n",
"# initialize the index\n",
"dbconfig = WeaviateDocumentIndex.DBConfig(\n",
" host=\"http://localhost:8080\"\n",
")\n",
"db = WeaviateDocumentIndex[WeaviateDoc](db_config=dbconfig)\n",
"\n",
"# index data\n",
"db.index(\n",
" [\n",
" WeaviateDoc(\n",
" title=f'My document {i}',\n",
" title_embedding=embeddings.embed_query(f'query {i}'),\n",
" year=i,\n",
" color=random.choice(['red', 'green', 'blue']),\n",
" )\n",
" for i in range(100)\n",
" ]\n",
")\n",
"# optionally, you can create a filter query\n",
"filter_query = {\"path\": [\"year\"], \"operator\": \"LessThanEqual\", \"valueInt\": \"90\"}"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4e21d124-0f3c-445b-b9fc-dc7c8d6b3d2b",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Document(page_content='My document 17', metadata={'id': '3a5b76e85f0d0a01785dc8f9d965ce40', 'year': 17, 'color': 'red'})]\n"
]
}
],
"source": [
"# create a retriever\n",
"retriever = DocArrayRetriever(\n",
" index=db, \n",
" embeddings=embeddings, \n",
" search_field='title_embedding', \n",
" content_field='title',\n",
" filters=filter_query,\n",
")\n",
"\n",
"# find the relevant document\n",
"doc = retriever.get_relevant_documents('some query')\n",
"print(doc)"
]
},
{
"cell_type": "markdown",
"id": "6ee8f920-9297-4b0a-a353-053a86947d10",
"metadata": {},
"source": [
"## ElasticDocIndex\n",
"\n",
"ElasticDocIndex is a document index that is built upon [Elasticsearch](https://github.com/elastic/elasticsearch).\n",
"\n",
"Learn more here: https://docs.docarray.org/user_guide/storing/index_elastic/"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "92980ead-e4dc-4eef-8618-1c0583f76d7a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from docarray.index import ElasticDocIndex\n",
"\n",
"\n",
"# initialize the index\n",
"db = ElasticDocIndex[MyDoc](\n",
" hosts=\"http://localhost:9200\", \n",
" index_name=\"docarray_retriever\"\n",
")\n",
"\n",
"# index data\n",
"db.index(\n",
" [\n",
" MyDoc(\n",
" title=f'My document {i}',\n",
" title_embedding=embeddings.embed_query(f'query {i}'),\n",
" year=i,\n",
" color=random.choice(['red', 'green', 'blue']),\n",
" )\n",
" for i in range(100)\n",
" ]\n",
")\n",
"# optionally, you can create a filter query\n",
"filter_query = {\"range\": {\"year\": {\"lte\": 90}}}"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8a8e97f3-c3a1-4c7f-b776-363c5e7dd69d",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Document(page_content='My document 46', metadata={'id': 'edbc721bac1c2ad323414ad1301528a4', 'year': 46, 'color': 'green'})]\n"
]
}
],
"source": [
"# create a retriever\n",
"retriever = DocArrayRetriever(\n",
" index=db, \n",
" embeddings=embeddings, \n",
" search_field='title_embedding', \n",
" content_field='title',\n",
" filters=filter_query,\n",
")\n",
"\n",
"# find the relevant document\n",
"doc = retriever.get_relevant_documents('some query')\n",
"print(doc)"
]
},
{
"cell_type": "markdown",
"id": "281432f8-87a5-4f22-a582-9d5dac33d158",
"metadata": {},
"source": [
"## QdrantDocumentIndex\n",
"\n",
"QdrantDocumentIndex is a document index that is built upon the [Qdrant](https://qdrant.tech/) vector database.\n",
"\n",
"Learn more here: https://docs.docarray.org/user_guide/storing/index_qdrant/"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "b6fd91d0-630a-4974-bdf1-6dfa4d1a68f5",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:root:Payload indexes have no effect in the local Qdrant. Please use server Qdrant if you need payload indexes.\n"
]
}
],
"source": [
"from docarray.index import QdrantDocumentIndex\n",
"from qdrant_client.http import models as rest\n",
"\n",
"\n",
"# initialize the index\n",
"qdrant_config = QdrantDocumentIndex.DBConfig(path=\":memory:\")\n",
"db = QdrantDocumentIndex[MyDoc](qdrant_config)\n",
"\n",
"# index data\n",
"db.index(\n",
" [\n",
" MyDoc(\n",
" title=f'My document {i}',\n",
" title_embedding=embeddings.embed_query(f'query {i}'),\n",
" year=i,\n",
" color=random.choice(['red', 'green', 'blue']),\n",
" )\n",
" for i in range(100)\n",
" ]\n",
")\n",
"# optionally, you can create a filter query\n",
"filter_query = rest.Filter(\n",
" must=[\n",
" rest.FieldCondition(\n",
" key=\"year\",\n",
" range=rest.Range(\n",
" gte=10,\n",
" lt=90,\n",
" ),\n",
" )\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a6dd6460-7175-48ee-8cfb-9a0abf35ec13",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Document(page_content='My document 80', metadata={'id': '97465f98d0810f1f330e4ecc29b13d20', 'year': 80, 'color': 'blue'})]\n"
]
}
],
"source": [
"# create a retriever\n",
"retriever = DocArrayRetriever(\n",
" index=db, \n",
" embeddings=embeddings, \n",
" search_field='title_embedding', \n",
" content_field='title',\n",
" filters=filter_query,\n",
")\n",
"\n",
"# find the relevant document\n",
"doc = retriever.get_relevant_documents('some query')\n",
"print(doc)"
]
},
{
"cell_type": "markdown",
"id": "3afb65b0-c620-411a-855f-1aa81481bdbb",
"metadata": {},
"source": [
"# Movie Retrieval using HnswDocumentIndex"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "07b71d96-381e-4965-b525-af9f7cc5f86c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"movies = [\n",
" {\n",
" \"title\": \"Inception\",\n",
" \"description\": \"A thief who steals corporate secrets through the use of dream-sharing technology is given the task of planting an idea into the mind of a CEO.\",\n",
" \"director\": \"Christopher Nolan\",\n",
" \"rating\": 8.8,\n",
" },\n",
" {\n",
" \"title\": \"The Dark Knight\",\n",
" \"description\": \"When the menace known as the Joker wreaks havoc and chaos on the people of Gotham, Batman must accept one of the greatest psychological and physical tests of his ability to fight injustice.\",\n",
" \"director\": \"Christopher Nolan\",\n",
" \"rating\": 9.0,\n",
" },\n",
" {\n",
" \"title\": \"Interstellar\",\n",
" \"description\": \"Interstellar explores the boundaries of human exploration as a group of astronauts venture through a wormhole in space. In their quest to ensure the survival of humanity, they confront the vastness of space-time and grapple with love and sacrifice.\",\n",
" \"director\": \"Christopher Nolan\",\n",
" \"rating\": 8.6,\n",
" },\n",
" {\n",
" \"title\": \"Pulp Fiction\",\n",
" \"description\": \"The lives of two mob hitmen, a boxer, a gangster's wife, and a pair of diner bandits intertwine in four tales of violence and redemption.\",\n",
" \"director\": \"Quentin Tarantino\",\n",
" \"rating\": 8.9,\n",
" },\n",
" {\n",
" \"title\": \"Reservoir Dogs\",\n",
" \"description\": \"When a simple jewelry heist goes horribly wrong, the surviving criminals begin to suspect that one of them is a police informant.\",\n",
" \"director\": \"Quentin Tarantino\",\n",
" \"rating\": 8.3,\n",
" },\n",
" {\n",
" \"title\": \"The Godfather\",\n",
" \"description\": \"An aging patriarch of an organized crime dynasty transfers control of his empire to his reluctant son.\",\n",
" \"director\": \"Francis Ford Coppola\",\n",
" \"rating\": 9.2,\n",
" },\n",
"]\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1860edfb-936d-4cd8-a167-e8f9c4617709",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
"OpenAI API Key: ········\n"
]
}
],
"source": [
"import getpass\n",
"import os \n",
"\n",
"os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API Key:')"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0538541d-26ea-4323-96b9-47768c75dcd8",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from docarray import BaseDoc, DocList\n",
"from docarray.typing import NdArray\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"\n",
"# define schema for your movie documents\n",
"class MyDoc(BaseDoc):\n",
" title: str\n",
" description: str\n",
" description_embedding: NdArray[1536]\n",
" rating: float\n",
" director: str\n",
" \n",
"\n",
"embeddings = OpenAIEmbeddings()\n",
"\n",
"\n",
"# get \"description\" embeddings, and create documents\n",
"docs = DocList[MyDoc](\n",
" [\n",
" MyDoc(\n",
" description_embedding=embeddings.embed_query(movie[\"description\"]), **movie\n",
" )\n",
" for movie in movies\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f5ae1b41-0372-47ea-89bb-c6ad968a2919",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from docarray.index import HnswDocumentIndex\n",
"\n",
"# initialize the index\n",
"db = HnswDocumentIndex[MyDoc](work_dir='movie_search')\n",
"\n",
"# add data\n",
"db.index(docs)"
]
},
{
"cell_type": "markdown",
"id": "9ca3f91b-ed11-490b-b60a-0d1d9b50a5b2",
"metadata": {
"tags": []
},
"source": [
"## Normal Retriever"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "efdb5cbf-218e-48a6-af0f-25b7a510e343",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Document(page_content='A thief who steals corporate secrets through the use of dream-sharing technology is given the task of planting an idea into the mind of a CEO.', metadata={'id': 'f1649d5b6776db04fec9a116bbb6bbe5', 'title': 'Inception', 'rating': 8.8, 'director': 'Christopher Nolan'})]\n"
]
}
],
"source": [
"from langchain.retrievers import DocArrayRetriever\n",
"\n",
"# create a retriever\n",
"retriever = DocArrayRetriever(\n",
" index=db, \n",
" embeddings=embeddings, \n",
" search_field='description_embedding', \n",
" content_field='description'\n",
")\n",
"\n",
"# find the relevant document\n",
"doc = retriever.get_relevant_documents('movie about dreams')\n",
"print(doc)"
]
},
{
"cell_type": "markdown",
"id": "3defa711-51df-4b48-b02a-306706cfacd0",
"metadata": {},
"source": [
"## Retriever with Filters"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "205a9fe8-13bb-4280-9485-f6973bbc6943",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Document(page_content='Interstellar explores the boundaries of human exploration as a group of astronauts venture through a wormhole in space. In their quest to ensure the survival of humanity, they confront the vastness of space-time and grapple with love and sacrifice.', metadata={'id': 'ab704cc7ae8573dc617f9a5e25df022a', 'title': 'Interstellar', 'rating': 8.6, 'director': 'Christopher Nolan'}), Document(page_content='A thief who steals corporate secrets through the use of dream-sharing technology is given the task of planting an idea into the mind of a CEO.', metadata={'id': 'f1649d5b6776db04fec9a116bbb6bbe5', 'title': 'Inception', 'rating': 8.8, 'director': 'Christopher Nolan'})]\n"
]
}
],
"source": [
"from langchain.retrievers import DocArrayRetriever\n",
"\n",
"# create a retriever\n",
"retriever = DocArrayRetriever(\n",
" index=db, \n",
" embeddings=embeddings, \n",
" search_field='description_embedding', \n",
" content_field='description',\n",
" filters={'director': {'$eq': 'Christopher Nolan'}},\n",
" top_k=2,\n",
")\n",
"\n",
"# find relevant documents\n",
"docs = retriever.get_relevant_documents('space travel')\n",
"print(docs)"
]
},
{
"cell_type": "markdown",
"id": "fa10afa6-1554-4c2b-8afc-cff44e32d2f8",
"metadata": {},
"source": [
"## Retriever with MMR search"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "b7305599-b166-419c-8e1e-8ff7c247cce6",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Document(page_content=\"The lives of two mob hitmen, a boxer, a gangster's wife, and a pair of diner bandits intertwine in four tales of violence and redemption.\", metadata={'id': 'e6aa313bbde514e23fbc80ab34511afd', 'title': 'Pulp Fiction', 'rating': 8.9, 'director': 'Quentin Tarantino'}), Document(page_content='A thief who steals corporate secrets through the use of dream-sharing technology is given the task of planting an idea into the mind of a CEO.', metadata={'id': 'f1649d5b6776db04fec9a116bbb6bbe5', 'title': 'Inception', 'rating': 8.8, 'director': 'Christopher Nolan'}), Document(page_content='When the menace known as the Joker wreaks havoc and chaos on the people of Gotham, Batman must accept one of the greatest psychological and physical tests of his ability to fight injustice.', metadata={'id': '91dec17d4272041b669fd113333a65f7', 'title': 'The Dark Knight', 'rating': 9.0, 'director': 'Christopher Nolan'})]\n"
]
}
],
"source": [
"from langchain.retrievers import DocArrayRetriever\n",
"\n",
"# create a retriever\n",
"retriever = DocArrayRetriever(\n",
" index=db, \n",
" embeddings=embeddings, \n",
" search_field='description_embedding', \n",
" content_field='description',\n",
" filters={'rating': {'$gte': 8.7}},\n",
" search_type='mmr',\n",
" top_k=3,\n",
")\n",
"\n",
"# find relevant documents\n",
"docs = retriever.get_relevant_documents('action movies')\n",
"print(docs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4865cf25-48af-4d60-9337-9528b9b30f28",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -4,6 +4,7 @@ from langchain.retrievers.azure_cognitive_search import AzureCognitiveSearchRetr
from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.retrievers.databerry import DataberryRetriever
from langchain.retrievers.docarray import DocArrayRetriever
from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever
from langchain.retrievers.knn import KNNRetriever
from langchain.retrievers.merger_retriever import MergerRetriever
@ -44,4 +45,5 @@ __all__ = [
"WeaviateHybridSearchRetriever",
"WikipediaRetriever",
"ZepRetriever",
"DocArrayRetriever",
]

View File

@ -0,0 +1,203 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import numpy as np
from pydantic import BaseModel

from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.utils import maximal_marginal_relevance


class SearchType(str, Enum):
    similarity = "similarity"
    mmr = "mmr"


class DocArrayRetriever(BaseRetriever, BaseModel):
    """
    Retriever class for DocArray Document Indices.

    Currently, supports 5 backends:
    InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
    ElasticDocIndex, and WeaviateDocumentIndex.

    Attributes:
        index: One of the above-mentioned index instances
        embeddings: Embedding model to represent text as vectors
        search_field: Field to consider for searching in the documents.
            Should be an embedding/vector/tensor.
        content_field: Field that represents the main content in your document schema.
            Will be used as a `page_content`. Everything else will go into `metadata`.
        search_type: Type of search to perform (similarity / mmr)
        filters: Filters applied for document retrieval.
        top_k: Number of documents to return
    """

    index: Any
    embeddings: Embeddings
    search_field: str
    content_field: str
    search_type: SearchType = SearchType.similarity
    top_k: int = 1
    filters: Optional[Any] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents
        """
        query_emb = np.array(self.embeddings.embed_query(query))

        if self.search_type == SearchType.similarity:
            results = self._similarity_search(query_emb)
        elif self.search_type == SearchType.mmr:
            results = self._mmr_search(query_emb)
        else:
            raise ValueError(
                f"Search type {self.search_type} does not exist. "
                f"Choose either 'similarity' or 'mmr'."
            )

        return results

    def _search(
        self, query_emb: np.ndarray, top_k: int
    ) -> List[Union[Dict[str, Any], Any]]:
        """
        Perform a search using the query embedding and return top_k documents.

        Args:
            query_emb: Query represented as an embedding
            top_k: Number of documents to return

        Returns:
            A list of top_k documents matching the query
        """
        from docarray.index import ElasticDocIndex, WeaviateDocumentIndex

        filter_args = {}
        search_field = self.search_field
        if isinstance(self.index, WeaviateDocumentIndex):
            filter_args["where_filter"] = self.filters
            search_field = ""
        elif isinstance(self.index, ElasticDocIndex):
            filter_args["query"] = self.filters
        else:
            filter_args["filter_query"] = self.filters

        if self.filters:
            query = (
                self.index.build_query()  # get empty query object
                .find(
                    query=query_emb, search_field=search_field
                )  # add vector similarity search
                .filter(**filter_args)  # add filter search
                .build(limit=top_k)  # build the query
            )
            # execute the combined query and return the results
            docs = self.index.execute_query(query)
            if hasattr(docs, "documents"):
                docs = docs.documents
            docs = docs[:top_k]
        else:
            docs = self.index.find(
                query=query_emb, search_field=search_field, limit=top_k
            ).documents
        return docs

    def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a similarity search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of documents most similar to the query
        """
        docs = self._search(query_emb=query_emb, top_k=self.top_k)
        results = [self._docarray_to_langchain_doc(doc) for doc in docs]
        return results

    def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a maximal marginal relevance (mmr) search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of diverse documents related to the query
        """
        docs = self._search(query_emb=query_emb, top_k=20)

        mmr_selected = maximal_marginal_relevance(
            query_emb,
            [
                doc[self.search_field]
                if isinstance(doc, dict)
                else getattr(doc, self.search_field)
                for doc in docs
            ],
            k=self.top_k,
        )
        results = [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected]
        return results

    def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
        """
        Convert a DocArray document (which also might be a dict)
        to a langchain document format.

        DocArray document can contain arbitrary fields, so the mapping is done
        in the following way:

        page_content <-> content_field
        metadata <-> all other fields excluding
            tensors and embeddings (so float, int, string)

        Args:
            doc: DocArray document

        Returns:
            Document in langchain format

        Raises:
            ValueError: If the document doesn't contain the content field
        """
        fields = doc.keys() if isinstance(doc, dict) else doc.__fields__

        if self.content_field not in fields:
            raise ValueError(
                f"Document does not contain the content field - {self.content_field}."
            )
        lc_doc = Document(
            page_content=doc[self.content_field]
            if isinstance(doc, dict)
            else getattr(doc, self.content_field)
        )

        for name in fields:
            value = doc[name] if isinstance(doc, dict) else getattr(doc, name)
            if (
                isinstance(value, (str, int, float, bool))
                and name != self.content_field
            ):
                lc_doc.metadata[name] = value

        return lc_doc

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        raise NotImplementedError
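
`_mmr_search` above fetches 20 candidates and lets `maximal_marginal_relevance` pick `top_k` of them. As a rough illustration of the general MMR technique (not LangChain's implementation), here is a pure-Python sketch. The helper names `cosine` and `mmr_select`, the use of cosine similarity, and the `lambda_mult=0.5` default are assumptions made for this example.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def mmr_select(
    query: Sequence[float],
    candidates: Sequence[Sequence[float]],
    k: int,
    lambda_mult: float = 0.5,  # assumed trade-off: 1.0 = relevance only
) -> List[int]:
    """Greedily pick up to k candidate indices, trading relevance for diversity."""
    selected: List[int] = []
    remaining = list(range(len(candidates)))
    while remaining and len(selected) < k:

        def score(i: int) -> float:
            relevance = cosine(query, candidates[i])
            # similarity to the closest already-selected candidate
            redundancy = max(
                (cosine(candidates[i], candidates[j]) for j in selected),
                default=0.0,
            )
            return lambda_mult * relevance - (1 - lambda_mult) * redundancy

        best = max(remaining, key=score)
        selected.append(best)
        remaining.remove(best)
    return selected


# candidates 0 and 1 are near-duplicates; candidate 2 points elsewhere
vectors = [[0.9, 0.1], [0.9, 0.12], [0.5, -0.5]]
print(mmr_select([1.0, 0.0], vectors, k=2))  # [0, 2]
print(mmr_select([1.0, 0.0], vectors, k=2, lambda_mult=1.0))  # [0, 1]
```

With the default trade-off the near-duplicate is skipped in favour of the more distinct vector, which is why the retriever over-fetches before selecting.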

View File

@ -0,0 +1,195 @@
from pathlib import Path
from typing import Any, Dict, Generator, Tuple

import numpy as np
import pytest
from docarray import BaseDoc
from docarray.index import (
    ElasticDocIndex,
    HnswDocumentIndex,
    InMemoryExactNNIndex,
    QdrantDocumentIndex,
    WeaviateDocumentIndex,
)
from docarray.typing import NdArray
from pydantic import Field
from qdrant_client.http import models as rest

from langchain.embeddings import FakeEmbeddings


class MyDoc(BaseDoc):
    title: str
    title_embedding: NdArray[32]  # type: ignore
    other_emb: NdArray[32]  # type: ignore
    year: int


class WeaviateDoc(BaseDoc):
    # When initializing the Weaviate index, denote the field
    # you want to search on with `is_embedding=True`
    title: str
    title_embedding: NdArray[32] = Field(is_embedding=True)  # type: ignore
    other_emb: NdArray[32]  # type: ignore
    year: int


@pytest.fixture
def init_weaviate() -> (
    Generator[
        Tuple[WeaviateDocumentIndex[WeaviateDoc], Dict[str, Any], FakeEmbeddings],
        None,
        None,
    ]
):
    """
    cd tests/integration_tests/vectorstores/docker-compose
    docker compose -f weaviate.yml up
    """
    embeddings = FakeEmbeddings(size=32)

    # initialize WeaviateDocumentIndex
    dbconfig = WeaviateDocumentIndex.DBConfig(host="http://localhost:8080")
    weaviate_db = WeaviateDocumentIndex[WeaviateDoc](
        db_config=dbconfig, index_name="docarray_retriever"
    )

    # index data
    weaviate_db.index(
        [
            WeaviateDoc(
                title=f"My document {i}",
                title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")),
                year=i,
            )
            for i in range(100)
        ]
    )
    # build a filter query
    filter_query = {"path": ["year"], "operator": "LessThanEqual", "valueInt": "90"}

    yield weaviate_db, filter_query, embeddings

    weaviate_db._client.schema.delete_all()


@pytest.fixture
def init_elastic() -> (
    Generator[Tuple[ElasticDocIndex[MyDoc], Dict[str, Any], FakeEmbeddings], None, None]
):
    """
    cd tests/integration_tests/vectorstores/docker-compose
    docker-compose -f elasticsearch.yml up
    """
    embeddings = FakeEmbeddings(size=32)

    # initialize ElasticDocIndex
    elastic_db = ElasticDocIndex[MyDoc](
        hosts="http://localhost:9200", index_name="docarray_retriever"
    )
    # index data
    elastic_db.index(
        [
            MyDoc(
                title=f"My document {i}",
                title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")),
                year=i,
            )
            for i in range(100)
        ]
    )
    # build a filter query
    filter_query = {"range": {"year": {"lte": 90}}}

    yield elastic_db, filter_query, embeddings

    elastic_db._client.indices.delete(index="docarray_retriever")


@pytest.fixture
def init_qdrant() -> Tuple[QdrantDocumentIndex[MyDoc], rest.Filter, FakeEmbeddings]:
    embeddings = FakeEmbeddings(size=32)

    # initialize QdrantDocumentIndex
    qdrant_config = QdrantDocumentIndex.DBConfig(path=":memory:")
    qdrant_db = QdrantDocumentIndex[MyDoc](qdrant_config)
    # index data
    qdrant_db.index(
        [
            MyDoc(
                title=f"My document {i}",
                title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")),
                year=i,
            )
            for i in range(100)
        ]
    )
    # build a filter query
    filter_query = rest.Filter(
        must=[
            rest.FieldCondition(
                key="year",
                range=rest.Range(
                    gte=10,
                    lt=90,
                ),
            )
        ]
    )

    return qdrant_db, filter_query, embeddings


@pytest.fixture
def init_in_memory() -> (
    Tuple[InMemoryExactNNIndex[MyDoc], Dict[str, Any], FakeEmbeddings]
):
    embeddings = FakeEmbeddings(size=32)

    # initialize InMemoryExactNNIndex
    in_memory_db = InMemoryExactNNIndex[MyDoc]()
    # index data
    in_memory_db.index(
        [
            MyDoc(
                title=f"My document {i}",
                title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")),
                year=i,
            )
            for i in range(100)
        ]
    )
    # build a filter query
    filter_query = {"year": {"$lte": 90}}

    return in_memory_db, filter_query, embeddings


@pytest.fixture
def init_hnsw(
    tmp_path: Path,
) -> Tuple[HnswDocumentIndex[MyDoc], Dict[str, Any], FakeEmbeddings]:
    embeddings = FakeEmbeddings(size=32)

    # initialize HnswDocumentIndex
    hnsw_db = HnswDocumentIndex[MyDoc](work_dir=tmp_path)
    # index data
    hnsw_db.index(
        [
            MyDoc(
                title=f"My document {i}",
                title_embedding=np.array(embeddings.embed_query(f"fake emb {i}")),
                other_emb=np.array(embeddings.embed_query(f"other fake emb {i}")),
                year=i,
            )
            for i in range(100)
        ]
    )
    # build a filter query
    filter_query = {"year": {"$lte": 90}}

    return hnsw_db, filter_query, embeddings
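
The fixtures above build each filter in its backend's native format, and `_search` in the retriever forwards it to the query builder under a backend-specific keyword (`where_filter` for Weaviate, `query` for Elastic, `filter_query` otherwise). The sketch below mirrors that dispatch with plain strings in place of the `isinstance` checks. The `filter_kwargs` helper and the backend labels are hypothetical and only serve this illustration; the filter values are copied from the fixtures.

```python
from typing import Any, Dict


def filter_kwargs(backend: str, filters: Any) -> Dict[str, Any]:
    """Hypothetical helper: choose the keyword under which a filter is passed."""
    if backend == "weaviate":
        return {"where_filter": filters}
    if backend == "elastic":
        return {"query": filters}
    # in-memory, HNSW and Qdrant all fall through to `filter_query`
    return {"filter_query": filters}


# filters as used in the fixtures; Qdrant is omitted because its filter
# is a `rest.Filter` object rather than a plain dict
example_filters = {
    "in_memory": {"year": {"$lte": 90}},
    "hnsw": {"year": {"$lte": 90}},
    "elastic": {"range": {"year": {"lte": 90}}},
    "weaviate": {"path": ["year"], "operator": "LessThanEqual", "valueInt": "90"},
}

for backend, filters in example_filters.items():
    print(backend, filter_kwargs(backend, filters))
```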

View File

@ -0,0 +1,71 @@
from typing import Any

import pytest
from vcr.request import Request

from langchain.retrievers import DocArrayRetriever
from tests.integration_tests.retrievers.docarray.fixtures import (  # noqa: F401
    init_elastic,
    init_hnsw,
    init_in_memory,
    init_qdrant,
    init_weaviate,
)


@pytest.mark.parametrize(
    "backend",
    ["init_hnsw", "init_in_memory", "init_qdrant", "init_elastic", "init_weaviate"],
)
def test_backends(request: Request, backend: Any) -> None:
    index, filter_query, embeddings = request.getfixturevalue(backend)

    # create a retriever
    retriever = DocArrayRetriever(
        index=index,
        embeddings=embeddings,
        search_field="title_embedding",
        content_field="title",
    )

    docs = retriever.get_relevant_documents("my docs")

    assert len(docs) == 1
    assert "My document" in docs[0].page_content
    assert "id" in docs[0].metadata and "year" in docs[0].metadata
    assert "other_emb" not in docs[0].metadata

    # create a retriever with filters
    retriever = DocArrayRetriever(
        index=index,
        embeddings=embeddings,
        search_field="title_embedding",
        content_field="title",
        filters=filter_query,
    )

    docs = retriever.get_relevant_documents("my docs")

    assert len(docs) == 1
    assert "My document" in docs[0].page_content
    assert "id" in docs[0].metadata and "year" in docs[0].metadata
    assert "other_emb" not in docs[0].metadata
    assert docs[0].metadata["year"] <= 90

    # create a retriever with MMR search
    retriever = DocArrayRetriever(
        index=index,
        embeddings=embeddings,
        search_field="title_embedding",
        search_type="mmr",
        content_field="title",
        filters=filter_query,
    )

    docs = retriever.get_relevant_documents("my docs")

    assert len(docs) == 1
    assert "My document" in docs[0].page_content
    assert "id" in docs[0].metadata and "year" in docs[0].metadata
    assert "other_emb" not in docs[0].metadata
    assert docs[0].metadata["year"] <= 90