diff --git a/docs/docs/integrations/vectorstores/azuresearch.ipynb b/docs/docs/integrations/vectorstores/azuresearch.ipynb index 06c8c81263..6398b46bf3 100644 --- a/docs/docs/integrations/vectorstores/azuresearch.ipynb +++ b/docs/docs/integrations/vectorstores/azuresearch.ipynb @@ -169,6 +169,23 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Specify additional properties for the Azure client such as the following https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/README.md#configurations\n", + "vector_store: AzureSearch = AzureSearch(\n", + " azure_search_endpoint=vector_store_address,\n", + " azure_search_key=vector_store_password,\n", + " index_name=index_name,\n", + " embedding_function=embeddings.embed_query,\n", + " # Configure max retries for the Azure client\n", + " additional_search_client_options={\"retry_total\": 4},\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/libs/community/langchain_community/vectorstores/azuresearch.py b/libs/community/langchain_community/vectorstores/azuresearch.py index 9543120c3c..8f1946a7f5 100644 --- a/libs/community/langchain_community/vectorstores/azuresearch.py +++ b/libs/community/langchain_community/vectorstores/azuresearch.py @@ -86,6 +86,7 @@ def _get_search_client( user_agent: Optional[str] = "langchain", cors_options: Optional[CorsOptions] = None, async_: bool = False, + additional_search_client_options: Optional[Dict[str, Any]] = None, ) -> Union[SearchClient, AsyncSearchClient]: from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ResourceNotFoundError @@ -109,6 +110,7 @@ def _get_search_client( VectorSearchProfile, ) + additional_search_client_options = additional_search_client_options or {} default_fields = default_fields or [] if key is None: credential = DefaultAzureCredential() @@ -225,6 +227,7 @@ def _get_search_client( index_name=index_name, credential=credential, user_agent=user_agent, + **additional_search_client_options, ) else: return AsyncSearchClient( @@ -232,6 +235,7 @@ def _get_search_client( index_name=index_name, credential=credential, user_agent=user_agent, + **additional_search_client_options, ) @@ -256,6 +260,7 @@ class AzureSearch(VectorStore): cors_options: Optional[CorsOptions] = None, *, vector_search_dimensions: Optional[int] = None, + additional_search_client_options: Optional[Dict[str, Any]] = None, **kwargs: Any, ): try: @@ -320,6 +325,7 @@ class AzureSearch(VectorStore): default_fields=default_fields, user_agent=user_agent, cors_options=cors_options, + additional_search_client_options=additional_search_client_options, ) self.search_type = search_type self.semantic_configuration_name = semantic_configuration_name diff --git a/libs/community/tests/unit_tests/vectorstores/test_azure_search.py b/libs/community/tests/unit_tests/vectorstores/test_azure_search.py index 25aabb8759..a06fbfd151 100644 --- a/libs/community/tests/unit_tests/vectorstores/test_azure_search.py +++ b/libs/community/tests/unit_tests/vectorstores/test_azure_search.py @@ -1,5 +1,5 @@ import json -from typing import List, Optional +from typing import Any, Dict, List, Optional from unittest.mock import patch import pytest @@ -121,12 +121,15 @@ def mock_default_index(*args, **kwargs): # type: ignore[no-untyped-def] ) -def create_vector_store() -> AzureSearch: +def create_vector_store( + additional_search_client_options: Optional[Dict[str, Any]] = None, +) -> AzureSearch: return AzureSearch( azure_search_endpoint=DEFAULT_ENDPOINT, azure_search_key=DEFAULT_KEY, index_name=DEFAULT_INDEX_NAME, embedding_function=DEFAULT_EMBEDDING_MODEL, + additional_search_client_options=additional_search_client_options, ) @@ -168,3 +171,20 @@ def test_init_new_index() -> None: assert json.dumps(created_index.as_dict()) == json.dumps( mock_default_index().as_dict() ) + + +@pytest.mark.requires("azure.search.documents") +def test_additional_search_options() -> None: + from azure.search.documents.indexes import SearchIndexClient + + def mock_create_index() -> None: + pytest.fail("Should not create index in this test") + + with patch.multiple( + SearchIndexClient, get_index=mock_default_index, create_index=mock_create_index + ): + vector_store = create_vector_store( + additional_search_client_options={"api_version": "test"} + ) + assert vector_store.client is not None + assert vector_store.client._api_version == "test"