From ce89b34fc0c5a5a2c390c90ca4db2b20bca37320 Mon Sep 17 00:00:00 2001 From: Massimiliano Pronesti Date: Mon, 29 Apr 2024 18:11:44 +0200 Subject: [PATCH] community[patch]: support hybrid search with threshold in Azure AI Search Retriever (#20907) Support hybrid search with a score threshold -- similar to what we do for similarity search. --- .../vectorstores/azuresearch.py | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py index 906bb9d00d..5d154fa210 100644 --- a/libs/community/langchain_community/vectorstores/azuresearch.py +++ b/libs/community/langchain_community/vectorstores/azuresearch.py @@ -8,6 +8,8 @@ from typing import ( TYPE_CHECKING, Any, Callable, + ClassVar, + Collection, Dict, Iterable, List, @@ -519,6 +521,17 @@ class AzureSearch(VectorStore): ] return docs + def hybrid_search_with_relevance_scores( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Tuple[Document, float]]: + score_threshold = kwargs.pop("score_threshold", None) + result = self.hybrid_search_with_score(query, k=k, **kwargs) + return ( + result + if score_threshold is None + else [r for r in result if r[1] >= score_threshold] + ) + def semantic_hybrid_search( self, query: str, k: int = 4, **kwargs: Any ) -> List[Document]: @@ -687,9 +700,16 @@ class AzureSearchVectorStoreRetriever(BaseRetriever): """Azure Search instance used to find similar documents.""" search_type: str = "hybrid" """Type of search to perform. Options are "similarity", "hybrid", - "semantic_hybrid".""" + "semantic_hybrid", "similarity_score_threshold", "hybrid_score_threshold".""" k: int = 4 """Number of documents to return.""" + allowed_search_types: ClassVar[Collection[str]] = ( + "similarity", + "similarity_score_threshold", + "hybrid", + "hybrid_score_threshold", + "semantic_hybrid", + ) class Config: """Configuration for this pydantic object.""" @@ -701,17 +721,10 @@ class AzureSearchVectorStoreRetriever(BaseRetriever): """Validate search type.""" if "search_type" in values: search_type = values["search_type"] - if search_type not in ( - allowed_search_types := ( - "similarity", - "similarity_score_threshold", - "hybrid", - "semantic_hybrid", - ) - ): + if search_type not in cls.allowed_search_types: raise ValueError( f"search_type of {search_type} not allowed. Valid values are: " - f"{allowed_search_types}" + f"{cls.allowed_search_types}" ) return values @@ -732,6 +745,13 @@ class AzureSearchVectorStoreRetriever(BaseRetriever): ] elif self.search_type == "hybrid": docs = self.vectorstore.hybrid_search(query, k=self.k, **kwargs) + elif self.search_type == "hybrid_score_threshold": + docs = [ + doc + for doc, _ in self.vectorstore.hybrid_search_with_relevance_scores( + query, k=self.k, **kwargs + ) + ] elif self.search_type == "semantic_hybrid": docs = self.vectorstore.semantic_hybrid_search(query, k=self.k, **kwargs) else: