community:support additional Azure Search Options (#24134)

- **Description:** Support additional kwargs options for the Azure
Search client (Described here
https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/README.md#configurations)
    - **Issue:** N/A
    - **Dependencies:** No additional Dependencies

---------
This commit is contained in:
Matt 2024-07-11 11:22:36 -07:00 committed by GitHub
parent 122e80e04d
commit 8327925ab7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 2 deletions

View File

@ -169,6 +169,23 @@
")" ")"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Specify additional properties for the Azure client such as the following https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/README.md#configurations\n",
"vector_store: AzureSearch = AzureSearch(\n",
" azure_search_endpoint=vector_store_address,\n",
" azure_search_key=vector_store_password,\n",
" index_name=index_name,\n",
" embedding_function=embeddings.embed_query,\n",
" # Configure max retries for the Azure client\n",
" additional_search_client_options={\"retry_total\": 4},\n",
")"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},

View File

@ -86,6 +86,7 @@ def _get_search_client(
user_agent: Optional[str] = "langchain", user_agent: Optional[str] = "langchain",
cors_options: Optional[CorsOptions] = None, cors_options: Optional[CorsOptions] = None,
async_: bool = False, async_: bool = False,
additional_search_client_options: Optional[Dict[str, Any]] = None,
) -> Union[SearchClient, AsyncSearchClient]: ) -> Union[SearchClient, AsyncSearchClient]:
from azure.core.credentials import AzureKeyCredential from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import ResourceNotFoundError from azure.core.exceptions import ResourceNotFoundError
@ -109,6 +110,7 @@ def _get_search_client(
VectorSearchProfile, VectorSearchProfile,
) )
additional_search_client_options = additional_search_client_options or {}
default_fields = default_fields or [] default_fields = default_fields or []
if key is None: if key is None:
credential = DefaultAzureCredential() credential = DefaultAzureCredential()
@ -225,6 +227,7 @@ def _get_search_client(
index_name=index_name, index_name=index_name,
credential=credential, credential=credential,
user_agent=user_agent, user_agent=user_agent,
**additional_search_client_options,
) )
else: else:
return AsyncSearchClient( return AsyncSearchClient(
@ -232,6 +235,7 @@ def _get_search_client(
index_name=index_name, index_name=index_name,
credential=credential, credential=credential,
user_agent=user_agent, user_agent=user_agent,
**additional_search_client_options,
) )
@ -256,6 +260,7 @@ class AzureSearch(VectorStore):
cors_options: Optional[CorsOptions] = None, cors_options: Optional[CorsOptions] = None,
*, *,
vector_search_dimensions: Optional[int] = None, vector_search_dimensions: Optional[int] = None,
additional_search_client_options: Optional[Dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
): ):
try: try:
@ -320,6 +325,7 @@ class AzureSearch(VectorStore):
default_fields=default_fields, default_fields=default_fields,
user_agent=user_agent, user_agent=user_agent,
cors_options=cors_options, cors_options=cors_options,
additional_search_client_options=additional_search_client_options,
) )
self.search_type = search_type self.search_type = search_type
self.semantic_configuration_name = semantic_configuration_name self.semantic_configuration_name = semantic_configuration_name

View File

@ -1,5 +1,5 @@
import json import json
from typing import List, Optional from typing import Any, Dict, List, Optional
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -121,12 +121,15 @@ def mock_default_index(*args, **kwargs): # type: ignore[no-untyped-def]
) )
def create_vector_store() -> AzureSearch: def create_vector_store(
additional_search_client_options: Optional[Dict[str, Any]] = None,
) -> AzureSearch:
return AzureSearch( return AzureSearch(
azure_search_endpoint=DEFAULT_ENDPOINT, azure_search_endpoint=DEFAULT_ENDPOINT,
azure_search_key=DEFAULT_KEY, azure_search_key=DEFAULT_KEY,
index_name=DEFAULT_INDEX_NAME, index_name=DEFAULT_INDEX_NAME,
embedding_function=DEFAULT_EMBEDDING_MODEL, embedding_function=DEFAULT_EMBEDDING_MODEL,
additional_search_client_options=additional_search_client_options,
) )
@ -168,3 +171,20 @@ def test_init_new_index() -> None:
assert json.dumps(created_index.as_dict()) == json.dumps( assert json.dumps(created_index.as_dict()) == json.dumps(
mock_default_index().as_dict() mock_default_index().as_dict()
) )
@pytest.mark.requires("azure.search.documents")
def test_additional_search_options() -> None:
from azure.search.documents.indexes import SearchIndexClient
def mock_create_index() -> None:
pytest.fail("Should not create index in this test")
with patch.multiple(
SearchIndexClient, get_index=mock_default_index, create_index=mock_create_index
):
vector_store = create_vector_store(
additional_search_client_options={"api_version": "test"}
)
assert vector_store.client is not None
assert vector_store.client._api_version == "test"