From 2542a09abc63fe0dfd86c3626c9c78fb37dfc46f Mon Sep 17 00:00:00 2001
From: Massimiliano Pronesti
Date: Thu, 18 Apr 2024 22:06:47 +0200
Subject: [PATCH] community[patch]: AzureSearch incorrectly converted to retriever (#20601)

Closes #20600. Please see the issue for more details.
---
 .../vectorstores/azuresearch.py | 34 +++++++++++++++++--
 1 file changed, 32 insertions(+), 2 deletions(-)

diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py
index e2b23e8a19..b07628d5a7 100644
--- a/libs/community/langchain_community/vectorstores/azuresearch.py
+++ b/libs/community/langchain_community/vectorstores/azuresearch.py
@@ -654,6 +654,31 @@ class AzureSearch(VectorStore):
         azure_search.add_texts(texts, metadatas, **kwargs)
         return azure_search
 
+    def as_retriever(self, **kwargs: Any) -> AzureSearchVectorStoreRetriever:  # type: ignore
+        """Return AzureSearchVectorStoreRetriever initialized from this VectorStore.
+
+        Args:
+            search_type (Optional[str]): Defines the type of search that
+                the Retriever should perform.
+                Can be "similarity" (default), "hybrid", or
+                    "semantic_hybrid".
+            search_kwargs (Optional[Dict]): Keyword arguments to pass to the
+                search function. Can include things like:
+                    k: Amount of documents to return (Default: 4)
+                    score_threshold: Minimum relevance threshold
+                        for similarity_score_threshold
+                    fetch_k: Amount of documents to pass to MMR algorithm (Default: 20)
+                    lambda_mult: Diversity of results returned by MMR;
+                        1 for minimum diversity and 0 for maximum. (Default: 0.5)
+                    filter: Filter by document metadata
+
+        Returns:
+            AzureSearchVectorStoreRetriever: Retriever class for VectorStore.
+        """
+        tags = kwargs.pop("tags", None) or []
+        tags.extend(self._get_retriever_tags())
+        return AzureSearchVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
+
 
 class AzureSearchVectorStoreRetriever(BaseRetriever):
     """Retriever that uses `Azure Cognitive Search`."""
@@ -676,8 +701,13 @@ class AzureSearchVectorStoreRetriever(BaseRetriever):
         """Validate search type."""
         if "search_type" in values:
             search_type = values["search_type"]
-            if search_type not in ("similarity", "hybrid", "semantic_hybrid"):
-                raise ValueError(f"search_type of {search_type} not allowed.")
+            if search_type not in (
+                allowed_search_types := ("similarity", "hybrid", "semantic_hybrid")
+            ):
+                raise ValueError(
+                    f"search_type of {search_type} not allowed. Valid values are: "
+                    f"{allowed_search_types}"
+                )
         return values
 
     def _get_relevant_documents(