[Small upgrade] Allow document limit in AzureCognitiveSearchRetriever (#7690)

Multiple people have asked in #5081 for a way to limit the documents
returned from an AzureCognitiveSearchRetriever. This PR adds the `top_n`
parameter to allow that.


Twitter handle:
 [@UmerHAdil](twitter.com/umerHAdil)
pull/7694/head
UmerHA 1 year ago committed by GitHub
parent af6d333147
commit 82f3e32d8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -91,7 +91,7 @@
"metadata": {},
"outputs": [],
"source": [
"retriever = AzureCognitiveSearchRetriever(content_key=\"content\")"
"retriever = AzureCognitiveSearchRetriever(content_key=\"content\", top_k=10)"
]
},
{
@ -111,6 +111,36 @@
"source": [
"retriever.get_relevant_documents(\"what is langchain\")"
]
},
{
"cell_type": "markdown",
"id": "72eca08e",
"metadata": {},
"source": [
"You can change the number of results returned with the `top_k` parameter. The default value is `None`, which returns all results. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "097146c5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d9963f5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "dc120696",
"metadata": {},
"source": []
}
],
"metadata": {

@ -33,6 +33,8 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
"""ClientSession, in case we want to reuse connection for better performance."""
content_key: str = "content"
"""Key in a retrieved result to set as the Document page_content."""
top_k: Optional[int] = None
"""Number of results to retrieve. Set to None to retrieve all results."""
class Config:
extra = Extra.forbid
@ -55,7 +57,8 @@ class AzureCognitiveSearchRetriever(BaseRetriever):
def _build_search_url(self, query: str) -> str:
base_url = f"https://{self.service_name}.search.windows.net/"
endpoint_path = f"indexes/{self.index_name}/docs?api-version={self.api_version}"
return base_url + endpoint_path + f"&search={query}"
top_param = f"&$top={self.top_k}" if self.top_k else ""
return base_url + endpoint_path + f"&search={query}" + top_param
@property
def _headers(self) -> Dict[str, str]:

@ -13,6 +13,10 @@ def test_azure_cognitive_search_get_relevant_documents() -> None:
assert isinstance(doc, Document)
assert doc.page_content
retriever = AzureCognitiveSearchRetriever(top_k=1)
documents = retriever.get_relevant_documents("what is langchain")
assert len(documents) <= 1
@pytest.mark.asyncio
async def test_azure_cognitive_search_aget_relevant_documents() -> None:

Loading…
Cancel
Save