mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
community:support additional Azure Search Options (#24134)
- **Description:** Support additional kwargs options for the Azure Search client (Described here https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/README.md#configurations) - **Issue:** N/A - **Dependencies:** No additional Dependencies ---------
This commit is contained in:
parent
122e80e04d
commit
8327925ab7
@ -169,6 +169,23 @@
|
|||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Specify additional properties for the Azure client such as the following https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/core/azure-core/README.md#configurations\n",
|
||||||
|
"vector_store: AzureSearch = AzureSearch(\n",
|
||||||
|
" azure_search_endpoint=vector_store_address,\n",
|
||||||
|
" azure_search_key=vector_store_password,\n",
|
||||||
|
" index_name=index_name,\n",
|
||||||
|
" embedding_function=embeddings.embed_query,\n",
|
||||||
|
" # Configure max retries for the Azure client\n",
|
||||||
|
" additional_search_client_options={\"retry_total\": 4},\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -86,6 +86,7 @@ def _get_search_client(
|
|||||||
user_agent: Optional[str] = "langchain",
|
user_agent: Optional[str] = "langchain",
|
||||||
cors_options: Optional[CorsOptions] = None,
|
cors_options: Optional[CorsOptions] = None,
|
||||||
async_: bool = False,
|
async_: bool = False,
|
||||||
|
additional_search_client_options: Optional[Dict[str, Any]] = None,
|
||||||
) -> Union[SearchClient, AsyncSearchClient]:
|
) -> Union[SearchClient, AsyncSearchClient]:
|
||||||
from azure.core.credentials import AzureKeyCredential
|
from azure.core.credentials import AzureKeyCredential
|
||||||
from azure.core.exceptions import ResourceNotFoundError
|
from azure.core.exceptions import ResourceNotFoundError
|
||||||
@ -109,6 +110,7 @@ def _get_search_client(
|
|||||||
VectorSearchProfile,
|
VectorSearchProfile,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
additional_search_client_options = additional_search_client_options or {}
|
||||||
default_fields = default_fields or []
|
default_fields = default_fields or []
|
||||||
if key is None:
|
if key is None:
|
||||||
credential = DefaultAzureCredential()
|
credential = DefaultAzureCredential()
|
||||||
@ -225,6 +227,7 @@ def _get_search_client(
|
|||||||
index_name=index_name,
|
index_name=index_name,
|
||||||
credential=credential,
|
credential=credential,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
|
**additional_search_client_options,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return AsyncSearchClient(
|
return AsyncSearchClient(
|
||||||
@ -232,6 +235,7 @@ def _get_search_client(
|
|||||||
index_name=index_name,
|
index_name=index_name,
|
||||||
credential=credential,
|
credential=credential,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
|
**additional_search_client_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -256,6 +260,7 @@ class AzureSearch(VectorStore):
|
|||||||
cors_options: Optional[CorsOptions] = None,
|
cors_options: Optional[CorsOptions] = None,
|
||||||
*,
|
*,
|
||||||
vector_search_dimensions: Optional[int] = None,
|
vector_search_dimensions: Optional[int] = None,
|
||||||
|
additional_search_client_options: Optional[Dict[str, Any]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
@ -320,6 +325,7 @@ class AzureSearch(VectorStore):
|
|||||||
default_fields=default_fields,
|
default_fields=default_fields,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
cors_options=cors_options,
|
cors_options=cors_options,
|
||||||
|
additional_search_client_options=additional_search_client_options,
|
||||||
)
|
)
|
||||||
self.search_type = search_type
|
self.search_type = search_type
|
||||||
self.semantic_configuration_name = semantic_configuration_name
|
self.semantic_configuration_name = semantic_configuration_name
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
from typing import List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -121,12 +121,15 @@ def mock_default_index(*args, **kwargs): # type: ignore[no-untyped-def]
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_vector_store() -> AzureSearch:
|
def create_vector_store(
|
||||||
|
additional_search_client_options: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> AzureSearch:
|
||||||
return AzureSearch(
|
return AzureSearch(
|
||||||
azure_search_endpoint=DEFAULT_ENDPOINT,
|
azure_search_endpoint=DEFAULT_ENDPOINT,
|
||||||
azure_search_key=DEFAULT_KEY,
|
azure_search_key=DEFAULT_KEY,
|
||||||
index_name=DEFAULT_INDEX_NAME,
|
index_name=DEFAULT_INDEX_NAME,
|
||||||
embedding_function=DEFAULT_EMBEDDING_MODEL,
|
embedding_function=DEFAULT_EMBEDDING_MODEL,
|
||||||
|
additional_search_client_options=additional_search_client_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -168,3 +171,20 @@ def test_init_new_index() -> None:
|
|||||||
assert json.dumps(created_index.as_dict()) == json.dumps(
|
assert json.dumps(created_index.as_dict()) == json.dumps(
|
||||||
mock_default_index().as_dict()
|
mock_default_index().as_dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("azure.search.documents")
|
||||||
|
def test_additional_search_options() -> None:
|
||||||
|
from azure.search.documents.indexes import SearchIndexClient
|
||||||
|
|
||||||
|
def mock_create_index() -> None:
|
||||||
|
pytest.fail("Should not create index in this test")
|
||||||
|
|
||||||
|
with patch.multiple(
|
||||||
|
SearchIndexClient, get_index=mock_default_index, create_index=mock_create_index
|
||||||
|
):
|
||||||
|
vector_store = create_vector_store(
|
||||||
|
additional_search_client_options={"api_version": "test"}
|
||||||
|
)
|
||||||
|
assert vector_store.client is not None
|
||||||
|
assert vector_store.client._api_version == "test"
|
||||||
|
Loading…
Reference in New Issue
Block a user