community[patch]: Implement similarity_score_threshold for MongoDB Vector Store (#14740)

Adds the option for `similarity_score_threshold` when using
`MongoDBAtlasVectorSearch` as a vector store retriever.

Example use:

```
vector_search = MongoDBAtlasVectorSearch.from_documents(...)

qa_retriever = vector_search.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={
        "score_threshold": 0.5,
    }
)

qa = RetrievalQA.from_chain_type(
	llm=OpenAI(), 
	chain_type="stuff", 
	retriever=qa_retriever,
)

docs = qa({"query": "..."})
```

I've tested this feature locally, using a MongoDB Atlas Cluster with a
vector search index.
This commit is contained in:
Noah Stapp 2023-12-15 16:49:21 -08:00 committed by GitHub
parent dcead816df
commit 34e6f3ff72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,6 +4,7 @@ import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable,
Dict, Dict,
Generator, Generator,
Iterable, Iterable,
@ -60,6 +61,7 @@ class MongoDBAtlasVectorSearch(VectorStore):
index_name: str = "default", index_name: str = "default",
text_key: str = "text", text_key: str = "text",
embedding_key: str = "embedding", embedding_key: str = "embedding",
relevance_score_fn: str = "cosine",
): ):
""" """
Args: Args:
@ -70,17 +72,32 @@ class MongoDBAtlasVectorSearch(VectorStore):
embedding_key: MongoDB field that will contain the embedding for embedding_key: MongoDB field that will contain the embedding for
each document. each document.
index_name: Name of the Atlas Search index. index_name: Name of the Atlas Search index.
relevance_score_fn: The similarity score used for the index.
Currently supported: Euclidean, cosine, and dot product.
""" """
self._collection = collection self._collection = collection
self._embedding = embedding self._embedding = embedding
self._index_name = index_name self._index_name = index_name
self._text_key = text_key self._text_key = text_key
self._embedding_key = embedding_key self._embedding_key = embedding_key
self._relevance_score_fn = relevance_score_fn
@property @property
def embeddings(self) -> Embeddings: def embeddings(self) -> Embeddings:
return self._embedding return self._embedding
def _select_relevance_score_fn(self) -> Callable[[float], float]:
if self._relevance_score_fn == "euclidean":
return self._euclidean_relevance_score_fn
elif self._relevance_score_fn == "dotProduct":
return self._max_inner_product_relevance_score_fn
elif self._relevance_score_fn == "cosine":
return self._cosine_relevance_score_fn
else:
raise NotImplementedError(
f"No relevance score function for ${self._relevance_score_fn}"
)
@classmethod @classmethod
def from_connection_string( def from_connection_string(
cls, cls,
@ -198,7 +215,6 @@ class MongoDBAtlasVectorSearch(VectorStore):
def similarity_search_with_score( def similarity_search_with_score(
self, self,
query: str, query: str,
*,
k: int = 4, k: int = 4,
pre_filter: Optional[Dict] = None, pre_filter: Optional[Dict] = None,
post_filter_pipeline: Optional[List[Dict]] = None, post_filter_pipeline: Optional[List[Dict]] = None,