adding max_marginal_relevance_search method to MongoDBAtlasVectorSearch (#7310)

Adding a maximal_marginal_relevance method to the
MongoDBAtlasVectorSearch vectorstore enhances the user experience by
providing more diverse search results

Issue: #7304
pull/7487/head
Paul-Emile Brotons 1 year ago committed by GitHub
parent 04cddfba0d
commit d2cf0d16b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,10 +1,14 @@
{
"cells":[
{
"attachments": {},
"attachments":{
},
"cell_type":"markdown",
"id":"683953b3",
"metadata": {},
"metadata":{
},
"source":[
"# MongoDB Atlas\n",
"\n",
@ -23,9 +27,13 @@
"execution_count":null,
"id":"b4c41cad-08ef-4f72-a545-2151e4598efe",
"metadata":{
"tags": []
"tags":[
]
},
"outputs": [],
"outputs":[
],
"source":[
"!pip install pymongo"
]
@ -34,21 +42,28 @@
"cell_type":"code",
"execution_count":null,
"id":"c1e38361-c1fe-4ac6-86e9-c90ebaf7ae87",
"metadata": {},
"outputs": [],
"metadata":{
},
"outputs":[
],
"source":[
"import os\n",
"import getpass\n",
"\n",
"MONGODB_ATLAS_CLUSTER_URI = getpass.getpass(\"MongoDB Atlas Cluster URI:\")\n",
"MONGODB_ATLAS_CLUSTER_URI = os.environ[\"MONGODB_ATLAS_CLUSTER_URI\"]"
"MONGODB_ATLAS_CLUSTER_URI = getpass.getpass(\"MongoDB Atlas Cluster URI:\")\n"
]
},
{
"attachments": {},
"attachments":{
},
"cell_type":"markdown",
"id":"457ace44-1d95-4001-9dd5-78811ab208ad",
"metadata": {},
"metadata":{
},
"source":[
"We want to use `OpenAIEmbeddings` so we need to set up our OpenAI API Key. "
]
@ -57,18 +72,25 @@
"cell_type":"code",
"execution_count":null,
"id":"2d8f240d",
"metadata": {},
"outputs": [],
"metadata":{
},
"outputs":[
],
"source":[
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n",
"OPENAI_API_KEY = os.environ[\"OPENAI_API_KEY\"]"
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")\n"
]
},
{
"attachments": {},
"attachments":{
},
"cell_type":"markdown",
"id":"1f3ecc42",
"metadata": {},
"metadata":{
},
"source":[
"Now, let's create a vector search index on your cluster. In the below example, `embedding` is the name of the field that contains the embedding vector. Please refer to the [documentation](https://www.mongodb.com/docs/atlas/atlas-search/define-field-mappings-for-vector-search) to get more details on how to define an Atlas Vector Search index.\n",
"You can name the index `langchain_demo` and create the index on the namespace `lanchain_db.langchain_col`. Finally, write the following definition in the JSON editor on MongoDB Atlas:\n",
@ -94,23 +116,17 @@
"execution_count":2,
"id":"aac9563e",
"metadata":{
"tags": []
"tags":[
]
},
"outputs": [],
"outputs":[
],
"source":[
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain.vectorstores import MongoDBAtlasVectorSearch\n",
"from langchain.document_loaders import TextLoader"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a3c3999a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.document_loaders import TextLoader\n",
"\n",
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
@ -125,8 +141,12 @@
"cell_type":"code",
"execution_count":null,
"id":"6e104aee",
"metadata": {},
"outputs": [],
"metadata":{
},
"outputs":[
],
"source":[
"from pymongo import MongoClient\n",
"\n",
@ -152,51 +172,48 @@
"cell_type":"code",
"execution_count":null,
"id":"9c608226",
"metadata": {},
"outputs": [],
"metadata":{
},
"outputs":[
],
"source":[
"print(docs[0].page_content)"
]
},
{
"attachments": {},
"attachments":{
},
"cell_type":"markdown",
"id":"851a2ec9-9390-49a4-8412-3e132c9f789d",
"metadata": {},
"metadata":{
},
"source":[
"You can reuse the vector search index you created, make sure the `OPENAI_API_KEY` environment variable is set up, then execute another query."
"You can also instantiate the vector store directly and execute a query as follows:"
]
},
{
"cell_type":"code",
"execution_count":null,
"id":"6336fe79-3e73-48be-b20a-0ff1bb6a4399",
"metadata": {},
"outputs": [],
"metadata":{
},
"outputs":[
],
"source":[
"from pymongo import MongoClient\n",
"from langchain.vectorstores import MongoDBAtlasVectorSearch\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n",
"import os\n",
"\n",
"MONGODB_ATLAS_URI = os.environ[\"MONGODB_ATLAS_URI\"]\n",
"\n",
"# initialize MongoDB python client\n",
"client = MongoClient(MONGODB_ATLAS_URI)\n",
"\n",
"db_name = \"langchain_db\"\n",
"collection_name = \"langchain_col\"\n",
"collection = client[db_name][collection_name]\n",
"index_name = \"langchain_demo\"\n",
"\n",
"# initialize vector store\n",
"vectorStore = MongoDBAtlasVectorSearch(\n",
"vectorstore = MongoDBAtlasVectorSearch(\n",
" collection, OpenAIEmbeddings(), index_name=index_name\n",
")\n",
"\n",
"# perform a similarity search between a query and the ingested documents\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = vectorStore.similarity_search(query)\n",
"docs = vectorstore.similarity_search(query)\n",
"\n",
"print(docs[0].page_content)"
]

@ -14,9 +14,12 @@ from typing import (
Union,
)
import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from pymongo.collection import Collection
@ -137,6 +140,39 @@ class MongoDBAtlasVectorSearch(VectorStore):
insert_result = self._collection.insert_many(to_insert)
return insert_result.inserted_ids
def _similarity_search_with_score(
self,
embedding: List[float],
k: int = 4,
pre_filter: Optional[dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
) -> List[Tuple[Document, float]]:
knn_beta = {
"vector": embedding,
"path": self._embedding_key,
"k": k,
}
if pre_filter:
knn_beta["filter"] = pre_filter
pipeline = [
{
"$search": {
"index": self._index_name,
"knnBeta": knn_beta,
}
},
{"$set": {"score": {"$meta": "searchScore"}}},
]
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)
cursor = self._collection.aggregate(pipeline)
docs = []
for res in cursor:
text = res.pop(self._text_key)
score = res.pop("score")
docs.append((Document(page_content=text, metadata=res), score))
return docs
def similarity_search_with_score(
self,
query: str,
@ -165,30 +201,13 @@ class MongoDBAtlasVectorSearch(VectorStore):
Returns:
List of Documents most similar to the query and score for each
"""
knn_beta = {
"vector": self._embedding.embed_query(query),
"path": self._embedding_key,
"k": k,
}
if pre_filter:
knn_beta["filter"] = pre_filter
pipeline = [
{
"$search": {
"index": self._index_name,
"knnBeta": knn_beta,
}
},
{"$project": {"score": {"$meta": "searchScore"}, self._embedding_key: 0}},
]
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)
cursor = self._collection.aggregate(pipeline)
docs = []
for res in cursor:
text = res.pop(self._text_key)
score = res.pop("score")
docs.append((Document(page_content=text, metadata=res), score))
embedding = self._embedding.embed_query(query)
docs = self._similarity_search_with_score(
embedding,
k=k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
)
return docs
def similarity_search(
@ -227,6 +246,53 @@ class MongoDBAtlasVectorSearch(VectorStore):
)
return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
pre_filter: Optional[dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Optional Number of Documents to return. Defaults to 4.
fetch_k: Optional Number of Documents to fetch before passing to MMR
algorithm. Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
pre_filter: Optional Dictionary of argument(s) to prefilter on document
fields.
post_filter_pipeline: Optional Pipeline of MongoDB aggregation stages
following the knnBeta search.
Returns:
List of Documents selected by maximal marginal relevance.
"""
query_embedding = self._embedding.embed_query(query)
docs = self._similarity_search_with_score(
query_embedding,
k=fetch_k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
)
mmr_doc_indexes = maximal_marginal_relevance(
np.array(query_embedding),
[doc.metadata[self._embedding_key] for doc, _ in docs],
k=k,
lambda_mult=lambda_mult,
)
mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
return mmr_docs
@classmethod
def from_texts(
cls,
@ -252,7 +318,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
from langchain.vectorstores import MongoDBAtlasVectorSearch
from langchain.embeddings import OpenAIEmbeddings
client = MongoClient("<YOUR-CONNECTION-STRING>")
mongo_client = MongoClient("<YOUR-CONNECTION-STRING>")
collection = mongo_client["<db_name>"]["<collection_name>"]
embeddings = OpenAIEmbeddings()
vectorstore = MongoDBAtlasVectorSearch.from_texts(
@ -264,6 +330,6 @@ class MongoDBAtlasVectorSearch(VectorStore):
"""
if collection is None:
raise ValueError("Must provide 'collection' named parameter.")
vecstore = cls(collection, embedding, **kwargs)
vecstore.add_texts(texts, metadatas=metadatas)
return vecstore
vectorstore = cls(collection, embedding, **kwargs)
vectorstore.add_texts(texts, metadatas=metadatas)
return vectorstore

@ -119,3 +119,18 @@ class TestMongoDBAtlasVectorSearch:
"Sandwich", k=1, pre_filter={"range": {"lte": 0, "path": "c"}}
)
assert output == []
def test_mmr(self, embedding_openai: Embeddings) -> None:
texts = ["foo", "foo", "fou", "foy"]
vectorstore = MongoDBAtlasVectorSearch.from_texts(
texts,
embedding_openai,
collection=collection,
index_name=INDEX_NAME,
)
sleep(1) # waits for mongot to update Lucene's index
query = "foo"
output = vectorstore.max_marginal_relevance_search(query, k=10, lambda_mult=0.1)
assert len(output) == len(texts)
assert output[0].page_content == "foo"
assert output[1].page_content != "foo"

Loading…
Cancel
Save