community[patch]: Added type hinting to OpenSearch clients (#27946)

Description:
* While building OpenSearchGraphVectorStore (coming soon) on top of
OpenSearchVectorSearch, I noticed that the underlying OpenSearch clients
had no type hints. This change adds them.
* Confirmed that the tests still pass with these code changes.

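Note that this introduces some additional code duplication: the
`_import_*` helper functions are removed, and each place that needs
`opensearchpy` now performs its own guarded import (try/except
ImportError). I think this approach is cleaner overall.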
This commit is contained in:
Eric Pinzur 2024-11-08 20:04:57 +01:00 committed by GitHub
parent 4c2392e55c
commit c421997caa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import uuid
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple
import numpy as np
from langchain_core.documents import Document
@ -23,57 +23,18 @@ SCRIPT_SCORING_SEARCH = "script_scoring"
PAINLESS_SCRIPTING_SEARCH = "painless_scripting"
MATCH_ALL_QUERY = {"match_all": {}} # type: Dict
def _import_opensearch() -> Any:
"""Import OpenSearch if available, otherwise raise error."""
try:
from opensearchpy import OpenSearch
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
return OpenSearch
if TYPE_CHECKING:
from opensearchpy import AsyncOpenSearch, OpenSearch
def _import_async_opensearch() -> Any:
"""Import AsyncOpenSearch if available, otherwise raise error."""
try:
from opensearchpy import AsyncOpenSearch
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
return AsyncOpenSearch
def _import_bulk() -> Any:
"""Import bulk if available, otherwise raise error."""
try:
from opensearchpy.helpers import bulk
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
return bulk
def _import_async_bulk() -> Any:
"""Import async_bulk if available, otherwise raise error."""
try:
from opensearchpy.helpers import async_bulk
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
return async_bulk
def _import_not_found_error() -> Any:
"""Import not found error if available, otherwise raise error."""
try:
from opensearchpy.exceptions import NotFoundError
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
return NotFoundError
def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> OpenSearch:
"""Get OpenSearch client from the opensearch_url, otherwise raise error."""
try:
opensearch = _import_opensearch()
client = opensearch(opensearch_url, **kwargs)
from opensearchpy import OpenSearch
client = OpenSearch(opensearch_url, **kwargs)
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
except ValueError as e:
raise ImportError(
f"OpenSearch client string provided is not in proper format. "
@ -82,11 +43,14 @@ def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
return client
def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> AsyncOpenSearch:
"""Get AsyncOpenSearch client from the opensearch_url, otherwise raise error."""
try:
async_opensearch = _import_async_opensearch()
client = async_opensearch(opensearch_url, **kwargs)
from opensearchpy import AsyncOpenSearch
client = AsyncOpenSearch(opensearch_url, **kwargs)
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
except ValueError as e:
raise ImportError(
f"AsyncOpenSearch client string provided is not in proper format. "
@ -127,7 +91,7 @@ def _is_aoss_enabled(http_auth: Any) -> bool:
def _bulk_ingest_embeddings(
client: Any,
client: OpenSearch,
index_name: str,
embeddings: List[List[float]],
texts: Iterable[str],
@ -142,16 +106,19 @@ def _bulk_ingest_embeddings(
"""Bulk Ingest Embeddings into given index."""
if not mapping:
mapping = dict()
try:
from opensearchpy.exceptions import NotFoundError
from opensearchpy.helpers import bulk
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
bulk = _import_bulk()
not_found_error = _import_not_found_error()
requests = []
return_ids = []
mapping = mapping
try:
client.indices.get(index=index_name)
except not_found_error:
except NotFoundError:
client.indices.create(index=index_name, body=mapping)
for i, text in enumerate(texts):
@ -177,7 +144,7 @@ def _bulk_ingest_embeddings(
async def _abulk_ingest_embeddings(
client: Any,
client: AsyncOpenSearch,
index_name: str,
embeddings: List[List[float]],
texts: Iterable[str],
@ -193,14 +160,18 @@ async def _abulk_ingest_embeddings(
if not mapping:
mapping = dict()
async_bulk = _import_async_bulk()
not_found_error = _import_not_found_error()
try:
from opensearchpy.exceptions import NotFoundError
from opensearchpy.helpers import async_bulk
except ImportError:
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
requests = []
return_ids = []
try:
await client.indices.get(index=index_name)
except not_found_error:
except NotFoundError:
await client.indices.create(index=index_name, body=mapping)
for i, text in enumerate(texts):
@ -230,7 +201,7 @@ async def _abulk_ingest_embeddings(
def _default_scripting_text_mapping(
dim: int,
vector_field: str = "vector_field",
) -> Dict:
) -> Dict[str, Any]:
"""For Painless Scripting or Script Scoring,the default mapping to create index."""
return {
"mappings": {
@ -249,7 +220,7 @@ def _default_text_mapping(
ef_construction: int = 512,
m: int = 16,
vector_field: str = "vector_field",
) -> Dict:
) -> Dict[str, Any]:
"""For Approximate k-NN Search, this is the default mapping to create index."""
return {
"settings": {"index": {"knn": True, "knn.algo_param.ef_search": ef_search}},
@ -275,7 +246,7 @@ def _default_approximate_search_query(
k: int = 4,
vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0,
) -> Dict:
) -> Dict[str, Any]:
"""For Approximate k-NN Search, this is the default query."""
return {
"size": k,
@ -291,7 +262,7 @@ def _approximate_search_query_with_boolean_filter(
vector_field: str = "vector_field",
subquery_clause: str = "must",
score_threshold: Optional[float] = 0.0,
) -> Dict:
) -> Dict[str, Any]:
"""For Approximate k-NN Search, with Boolean Filter."""
return {
"size": k,
@ -313,7 +284,7 @@ def _approximate_search_query_with_efficient_filter(
k: int = 4,
vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0,
) -> Dict:
) -> Dict[str, Any]:
"""For Approximate k-NN Search, with Efficient Filter for Lucene and
Faiss Engines."""
search_query = _default_approximate_search_query(
@ -330,7 +301,7 @@ def _default_script_query(
pre_filter: Optional[Dict] = None,
vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0,
) -> Dict:
) -> Dict[str, Any]:
"""For Script Scoring Search, this is the default query."""
if not pre_filter:
@ -376,7 +347,7 @@ def _default_painless_scripting_query(
pre_filter: Optional[Dict] = None,
vector_field: str = "vector_field",
score_threshold: Optional[float] = 0.0,
) -> Dict:
) -> Dict[str, Any]:
"""For Painless Scripting Search, this is the default query."""
if not pre_filter:
@ -692,7 +663,10 @@ class OpenSearchVectorSearch(VectorStore):
refresh_indices: Whether to refresh the index
after deleting documents. Defaults to True.
"""
bulk = _import_bulk()
try:
from opensearchpy.helpers import bulk
except ImportError:
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
body = []