[Small upgrade] Allow document limit in AzureCognitiveSearchRetriever (#7690)

Multiple people have asked in #5081 for a way to limit the number of
documents returned from an AzureCognitiveSearchRetriever. This PR adds the
`top_k` parameter to allow that.


Twitter handle:
 [@UmerHAdil](https://twitter.com/umerHAdil)
This commit is contained in:
UmerHA 2023-07-14 05:04:40 +02:00 committed by GitHub
parent af6d333147
commit 82f3e32d8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 2 deletions

View File

@ -91,7 +91,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"retriever = AzureCognitiveSearchRetriever(content_key=\"content\")" "retriever = AzureCognitiveSearchRetriever(content_key=\"content\", top_k=10)"
] ]
}, },
{ {
@ -111,6 +111,36 @@
"source": [ "source": [
"retriever.get_relevant_documents(\"what is langchain\")" "retriever.get_relevant_documents(\"what is langchain\")"
] ]
},
{
"cell_type": "markdown",
"id": "72eca08e",
"metadata": {},
"source": [
"You can change the number of results returned with the `top_k` parameter. The default value is `None`, which returns all results. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "097146c5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d9963f5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "dc120696",
"metadata": {},
"source": []
} }
], ],
"metadata": { "metadata": {

View File

@ -33,6 +33,8 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
"""ClientSession, in case we want to reuse connection for better performance.""" """ClientSession, in case we want to reuse connection for better performance."""
content_key: str = "content" content_key: str = "content"
"""Key in a retrieved result to set as the Document page_content.""" """Key in a retrieved result to set as the Document page_content."""
top_k: Optional[int] = None
"""Number of results to retrieve. Set to None to retrieve all results."""
class Config: class Config:
extra = Extra.forbid extra = Extra.forbid
@ -55,7 +57,8 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
def _build_search_url(self, query: str) -> str: def _build_search_url(self, query: str) -> str:
base_url = f"https://{self.service_name}.search.windows.net/" base_url = f"https://{self.service_name}.search.windows.net/"
endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}" endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
return base_url + endpoint_path + f"&search={query}" top_param = f"&$top={self.top_k}" if self.top_k else ""
return base_url + endpoint_path + f"&search={query}" + top_param
@property @property
def _headers(self) -> Dict[str, str]: def _headers(self) -> Dict[str, str]:

View File

@ -13,6 +13,10 @@ def test_azure_cognitive_search_get_relevant_documents() -> None:
assert isinstance(doc, Document) assert isinstance(doc, Document)
assert doc.page_content assert doc.page_content
retriever = AzureCognitiveSearchRetriever(top_k=1)
documents = retriever.get_relevant_documents("what is langchain")
assert len(documents) <= 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_azure_cognitive_search_aget_relevant_documents() -> None: async def test_azure_cognitive_search_aget_relevant_documents() -> None: