community[patch]: AzureSearch incorrectly converted to retriever (#20601)

Closes #20600.

Please see the issue for more details.
pull/20012/head^2
Massimiliano Pronesti 3 months ago committed by GitHub
parent 520ef24fb9
commit 2542a09abc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -654,6 +654,31 @@ class AzureSearch(VectorStore):
azure_search.add_texts(texts, metadatas, **kwargs)
return azure_search
def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever: # type: ignore
"""Return AzureSearchVectorStoreRetriever initialized from this VectorStore.
Args:
search_type (Optional[str]): Defines the type of search that
the Retriever should perform.
Can be "similarity" (default), "hybrid", or
"semantic_hybrid".
search_kwargs (Optional[Dict]): Keyword arguments to pass to the
search function. Can include things like:
k: Amount of documents to return (Default: 4)
score_threshold: Minimum relevance threshold
for similarity_score_threshold
fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
lambda_mult: Diversity of results returned by MMR;
1 for minimum diversity and 0 for maximum. (Default: 0.5)
filter: Filter by document metadata
Returns:
AzureSearchVectorStoreRetriever: Retriever class for VectorStore.
"""
tags = kwargs.pop("tags", None) or []
tags.extend(self._get_retriever_tags())
return AzureSearchVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
class AzureSearchVectorStoreRetriever(BaseRetriever):
"""Retriever that uses `Azure Cognitive Search`."""
@ -676,8 +701,13 @@ class AzureSearchVectorStoreRetriever(BaseRetriever):
"""Validate search type."""
if "search_type" in values:
search_type = values["search_type"]
if search_type not in ("similarity", "hybrid", "semantic_hybrid"):
raise ValueError(f"search_type of {search_type} not allowed.")
if search_type not in (
allowed_search_types := ("similarity", "hybrid", "semantic_hybrid")
):
raise ValueError(
f"search_type of {search_type} not allowed. Valid values are: "
f"{allowed_search_types}"
)
return values
def _get_relevant_documents(

Loading…
Cancel
Save