core[patch]: Add similarity_score_threshold to VectorStore search types (#22477)

This commit is contained in:
Christophe Bornet 2024-06-04 22:43:55 +02:00 committed by GitHub
parent 9120cf5df2
commit 8ba868d3b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -156,12 +156,17 @@ class VectorStore(ABC):
"""Return docs most similar to query using specified search type.""" """Return docs most similar to query using specified search type."""
if search_type == "similarity": if search_type == "similarity":
return self.similarity_search(query, **kwargs) return self.similarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
docs_and_similarities = self.similarity_search_with_relevance_scores(
query, **kwargs
)
return [doc for doc, _ in docs_and_similarities]
elif search_type == "mmr": elif search_type == "mmr":
return self.max_marginal_relevance_search(query, **kwargs) return self.max_marginal_relevance_search(query, **kwargs)
else: else:
raise ValueError( raise ValueError(
f"search_type of {search_type} not allowed. Expected " f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity' or 'mmr'." "search_type to be 'similarity', 'similarity_score_threshold' or 'mmr'."
) )
async def asearch( async def asearch(
@ -170,12 +175,17 @@ class VectorStore(ABC):
"""Return docs most similar to query using specified search type.""" """Return docs most similar to query using specified search type."""
if search_type == "similarity": if search_type == "similarity":
return await self.asimilarity_search(query, **kwargs) return await self.asimilarity_search(query, **kwargs)
elif search_type == "similarity_score_threshold":
docs_and_similarities = await self.asimilarity_search_with_relevance_scores(
query, **kwargs
)
return [doc for doc, _ in docs_and_similarities]
elif search_type == "mmr": elif search_type == "mmr":
return await self.amax_marginal_relevance_search(query, **kwargs) return await self.amax_marginal_relevance_search(query, **kwargs)
else: else:
raise ValueError( raise ValueError(
f"search_type of {search_type} not allowed. Expected " f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity' or 'mmr'." "search_type to be 'similarity', 'similarity_score_threshold' or 'mmr'."
) )
@abstractmethod @abstractmethod