community[patch]: support hybrid search with threshold in Azure AI Search Retriever (#20907)

Support hybrid search with a score threshold -- similar to what we do
for similarity search.
This commit is contained in:
Massimiliano Pronesti 2024-04-29 18:11:44 +02:00 committed by GitHub
parent b3efa38cc0
commit ce89b34fc0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,6 +8,8 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Callable, Callable,
ClassVar,
Collection,
Dict, Dict,
Iterable, Iterable,
List, List,
@ -519,6 +521,17 @@ class AzureSearch(VectorStore):
] ]
return docs return docs
def hybrid_search_with_relevance_scores(
    self, query: str, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
    """Run a hybrid search and optionally drop low-scoring results.

    Args:
        query: Text to search for.
        k: Number of results requested from ``hybrid_search_with_score``.
        **kwargs: Forwarded to ``hybrid_search_with_score``. An optional
            ``score_threshold`` entry is consumed here and not forwarded.

    Returns:
        The ``(document, score)`` pairs produced by
        ``hybrid_search_with_score``; when ``score_threshold`` is given,
        only pairs whose score is >= the threshold are kept.
    """
    # Pop the threshold so the underlying search never sees it.
    min_score = kwargs.pop("score_threshold", None)
    scored_docs = self.hybrid_search_with_score(query, k=k, **kwargs)
    if min_score is None:
        return scored_docs
    # Filtering happens after retrieval, so fewer than k pairs may remain.
    return [pair for pair in scored_docs if pair[1] >= min_score]
def semantic_hybrid_search( def semantic_hybrid_search(
self, query: str, k: int = 4, **kwargs: Any self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]: ) -> List[Document]:
@ -687,9 +700,16 @@ class AzureSearchVectorStoreRetriever(BaseRetriever):
"""Azure Search instance used to find similar documents.""" """Azure Search instance used to find similar documents."""
search_type: str = "hybrid" search_type: str = "hybrid"
"""Type of search to perform. Options are "similarity", "hybrid", """Type of search to perform. Options are "similarity", "hybrid",
"semantic_hybrid".""" "semantic_hybrid", "similarity_score_threshold", "hybrid_score_threshold"."""
k: int = 4 k: int = 4
"""Number of documents to return.""" """Number of documents to return."""
# Accepted values for ``search_type``. The search-type validator checks
# membership against this collection and echoes it in its error message.
# Declared as a ClassVar so it is a class-level constant rather than a
# per-instance field.
allowed_search_types: ClassVar[Collection[str]] = (
"similarity",
"similarity_score_threshold",
"hybrid",
"hybrid_score_threshold",
"semantic_hybrid",
)
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -701,17 +721,10 @@ class AzureSearchVectorStoreRetriever(BaseRetriever):
"""Validate search type.""" """Validate search type."""
if "search_type" in values: if "search_type" in values:
search_type = values["search_type"] search_type = values["search_type"]
if search_type not in ( if search_type not in cls.allowed_search_types:
allowed_search_types := (
"similarity",
"similarity_score_threshold",
"hybrid",
"semantic_hybrid",
)
):
raise ValueError( raise ValueError(
f"search_type of {search_type} not allowed. Valid values are: " f"search_type of {search_type} not allowed. Valid values are: "
f"{allowed_search_types}" f"{cls.allowed_search_types}"
) )
return values return values
@ -732,6 +745,13 @@ class AzureSearchVectorStoreRetriever(BaseRetriever):
] ]
elif self.search_type == "hybrid": elif self.search_type == "hybrid":
docs = self.vectorstore.hybrid_search(query, k=self.k, **kwargs) docs = self.vectorstore.hybrid_search(query, k=self.k, **kwargs)
elif self.search_type == "hybrid_score_threshold":
docs = [
doc
for doc, _ in self.vectorstore.hybrid_search_with_relevance_scores(
query, k=self.k, **kwargs
)
]
elif self.search_type == "semantic_hybrid": elif self.search_type == "semantic_hybrid":
docs = self.vectorstore.semantic_hybrid_search(query, k=self.k, **kwargs) docs = self.vectorstore.semantic_hybrid_search(query, k=self.k, **kwargs)
else: else: