mongodb[patch]: Remove embedding retrieval from mongodb payload (#19035)

## Description
Returning the embedding is not necessary in the vector search
functionality unless specified as a debugging step. This change defaults
the behavior such that the server _only_ returns the embedding key if
explicitly requested, such as in the case of
`max_marginal_relevance_search`.


- [x] **Add tests and docs**: If you're adding a new integration, please
include
* Added `test_from_documents_no_embedding_return`


- [x] **Lint and test**: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified. See contribution
guidelines for more: https://python.langchain.com/docs/contributing/

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Jib 2024-03-18 15:43:50 -04:00 committed by GitHub
parent 366ba77459
commit 866d6408af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 11 deletions

View File

@ -183,6 +183,8 @@ class MongoDBAtlasVectorSearch(VectorStore):
k: int = 4,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
include_embedding: bool = False,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
params = {
"queryVector": embedding,
@ -199,6 +201,11 @@ class MongoDBAtlasVectorSearch(VectorStore):
query,
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
# Exclude the embedding key from the return payload
if not include_embedding:
pipeline.append({"$project": {self._embedding_key: 0}})
if post_filter_pipeline is not None:
pipeline.extend(post_filter_pipeline)
cursor = self._collection.aggregate(pipeline) # type: ignore[arg-type]
@ -215,6 +222,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
k: int = 4,
pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return MongoDB documents most similar to the given query and their scores.
@ -238,6 +246,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
k=k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
**kwargs,
)
return docs
@ -271,6 +280,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
k=k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
**kwargs,
)
if additional and "similarity_score" in additional:
@ -310,20 +320,15 @@ class MongoDBAtlasVectorSearch(VectorStore):
List of documents selected by maximal marginal relevance.
"""
query_embedding = self._embedding.embed_query(query)
docs = self._similarity_search_with_score(
query_embedding,
k=fetch_k,
return self.max_marginal_relevance_search_by_vector(
embedding=query_embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
**kwargs,
)
mmr_doc_indexes = maximal_marginal_relevance(
np.array(query_embedding),
[doc.metadata[self._embedding_key] for doc, _ in docs],
k=k,
lambda_mult=lambda_mult,
)
mmr_docs = [docs[i][0] for i in mmr_doc_indexes]
return mmr_docs
@classmethod
def from_texts(
@ -433,6 +438,8 @@ class MongoDBAtlasVectorSearch(VectorStore):
k=fetch_k,
pre_filter=pre_filter,
post_filter_pipeline=post_filter_pipeline,
include_embedding=kwargs.pop("include_embedding", True),
**kwargs,
)
mmr_doc_indexes = maximal_marginal_relevance(
np.array(embedding),

View File

@ -91,6 +91,52 @@ class TestMongoDBAtlasVectorSearch:
# Check for the presence of the metadata key
assert any([key.page_content == output[0].page_content for key in documents])
def test_from_documents_no_embedding_return(
self, embedding_openai: Embeddings, collection: Any
) -> None:
"""Test end to end construction and search."""
documents = [
Document(page_content="Dogs are tough.", metadata={"a": 1}),
Document(page_content="Cats have fluff.", metadata={"b": 1}),
Document(page_content="What is a sandwich?", metadata={"c": 1}),
Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
documents,
embedding_openai,
collection=collection,
index_name=INDEX_NAME,
)
output = vectorstore.similarity_search("Sandwich", k=1)
assert len(output) == 1
# Check for presence of embedding in each document
assert all(["embedding" not in key.metadata for key in output])
# Check for the presence of the metadata key
assert any([key.page_content == output[0].page_content for key in documents])
def test_from_documents_embedding_return(
self, embedding_openai: Embeddings, collection: Any
) -> None:
"""Test end to end construction and search."""
documents = [
Document(page_content="Dogs are tough.", metadata={"a": 1}),
Document(page_content="Cats have fluff.", metadata={"b": 1}),
Document(page_content="What is a sandwich?", metadata={"c": 1}),
Document(page_content="That fence is purple.", metadata={"d": 1, "e": 2}),
]
vectorstore = PatchedMongoDBAtlasVectorSearch.from_documents(
documents,
embedding_openai,
collection=collection,
index_name=INDEX_NAME,
)
output = vectorstore.similarity_search("Sandwich", k=1, include_embedding=True)
assert len(output) == 1
# Check for presence of embedding in each document
assert all([key.metadata.get("embedding") for key in output])
# Check for the presence of the metadata key
assert any([key.page_content == output[0].page_content for key in documents])
def test_from_texts(self, embedding_openai: Embeddings, collection: Any) -> None:
texts = [
"Dogs are tough.",