From 82f3e32d8dee7131bf412b481734df6427b93ff3 Mon Sep 17 00:00:00 2001 From: UmerHA <40663591+UmerHA@users.noreply.github.com> Date: Fri, 14 Jul 2023 05:04:40 +0200 Subject: [PATCH] [Small upgrade] Allow document limit in AzureCognitiveSearchRetriever (#7690) Multiple people have asked in #5081 for a way to limit the documents returned from an AzureCognitiveSearchRetriever. This PR adds the `top_k` parameter to allow that. Twitter handle: [@UmerHAdil](https://twitter.com/umerHAdil) --- .../integrations/azure_cognitive_search.ipynb | 32 ++++++++++++++++++- .../retrievers/azure_cognitive_search.py | 5 ++- .../retrievers/test_azure_cognitive_search.py | 4 +++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/docs/extras/modules/data_connection/retrievers/integrations/azure_cognitive_search.ipynb b/docs/extras/modules/data_connection/retrievers/integrations/azure_cognitive_search.ipynb index 7ceb431ea6..9b09e63464 100644 --- a/docs/extras/modules/data_connection/retrievers/integrations/azure_cognitive_search.ipynb +++ b/docs/extras/modules/data_connection/retrievers/integrations/azure_cognitive_search.ipynb @@ -91,7 +91,7 @@ "metadata": {}, "outputs": [], "source": [ - "retriever = AzureCognitiveSearchRetriever(content_key=\"content\")" + "retriever = AzureCognitiveSearchRetriever(content_key=\"content\", top_k=10)" ] }, { @@ -111,6 +111,36 @@ "source": [ "retriever.get_relevant_documents(\"what is langchain\")" ] + }, + { + "cell_type": "markdown", + "id": "72eca08e", + "metadata": {}, + "source": [ + "You can change the number of results returned with the `top_k` parameter. The default value is `None`, which returns all results. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "097146c5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d9963f5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "dc120696", + "metadata": {}, + "source": [] } ], "metadata": { diff --git a/langchain/retrievers/azure_cognitive_search.py b/langchain/retrievers/azure_cognitive_search.py index efaa0a8792..02c92b8c19 100644 --- a/langchain/retrievers/azure_cognitive_search.py +++ b/langchain/retrievers/azure_cognitive_search.py @@ -33,6 +33,8 @@ class AzureCognitiveSearchRetriever(BaseRetriever): """ClientSession, in case we want to reuse connection for better performance.""" content_key: str = "content" """Key in a retrieved result to set as the Document page_content.""" + top_k: Optional[int] = None + """Number of results to retrieve. Set to None to retrieve all results.""" class Config: extra = Extra.forbid @@ -55,7 +57,8 @@ class AzureCognitiveSearchRetriever(BaseRetriever): def _build_search_url(self, query: str) -> str: base_url = f"https://{self.service_name}.search.windows.net/" endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}" - return base_url + endpoint_path + f"&search={query}" + top_param = f"&$top={self.top_k}" if self.top_k else "" + return base_url + endpoint_path + f"&search={query}" + top_param @property def _headers(self) -> Dict[str, str]: diff --git a/tests/integration_tests/retrievers/test_azure_cognitive_search.py b/tests/integration_tests/retrievers/test_azure_cognitive_search.py index 64d35550cb..effa1b7932 100644 --- a/tests/integration_tests/retrievers/test_azure_cognitive_search.py +++ b/tests/integration_tests/retrievers/test_azure_cognitive_search.py @@ -13,6 +13,10 @@ def test_azure_cognitive_search_get_relevant_documents() -> None: assert isinstance(doc, Document) assert doc.page_content + retriever = 
AzureCognitiveSearchRetriever(top_k=1) + documents = retriever.get_relevant_documents("what is langchain") + assert len(documents) <= 1 + @pytest.mark.asyncio async def test_azure_cognitive_search_aget_relevant_documents() -> None: