mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
community[patch]: Added type hinting to OpenSearch clients (#27946)
Description: While working with OpenSearchVectorSearch to build OpenSearchGraphVectorStore (coming soon), I noticed that there was no type hinting for the underlying OpenSearch clients. This change fixes that issue. Tests were confirmed to still pass with the code changes. Note that there is some additional code duplication now, but I think this approach is cleaner overall.
This commit is contained in:
parent
4c2392e55c
commit
c421997caa
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.documents import Document
|
||||
@ -23,57 +23,18 @@ SCRIPT_SCORING_SEARCH = "script_scoring"
|
||||
PAINLESS_SCRIPTING_SEARCH = "painless_scripting"
|
||||
MATCH_ALL_QUERY = {"match_all": {}} # type: Dict
|
||||
|
||||
|
||||
def _import_opensearch() -> Any:
|
||||
"""Import OpenSearch if available, otherwise raise error."""
|
||||
try:
|
||||
from opensearchpy import OpenSearch
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
||||
return OpenSearch
|
||||
if TYPE_CHECKING:
|
||||
from opensearchpy import AsyncOpenSearch, OpenSearch
|
||||
|
||||
|
||||
def _import_async_opensearch() -> Any:
|
||||
"""Import AsyncOpenSearch if available, otherwise raise error."""
|
||||
try:
|
||||
from opensearchpy import AsyncOpenSearch
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
|
||||
return AsyncOpenSearch
|
||||
|
||||
|
||||
def _import_bulk() -> Any:
|
||||
"""Import bulk if available, otherwise raise error."""
|
||||
try:
|
||||
from opensearchpy.helpers import bulk
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
||||
return bulk
|
||||
|
||||
|
||||
def _import_async_bulk() -> Any:
|
||||
"""Import async_bulk if available, otherwise raise error."""
|
||||
try:
|
||||
from opensearchpy.helpers import async_bulk
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
|
||||
return async_bulk
|
||||
|
||||
|
||||
def _import_not_found_error() -> Any:
|
||||
"""Import not found error if available, otherwise raise error."""
|
||||
try:
|
||||
from opensearchpy.exceptions import NotFoundError
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
||||
return NotFoundError
|
||||
|
||||
|
||||
def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
|
||||
def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> OpenSearch:
|
||||
"""Get OpenSearch client from the opensearch_url, otherwise raise error."""
|
||||
try:
|
||||
opensearch = _import_opensearch()
|
||||
client = opensearch(opensearch_url, **kwargs)
|
||||
from opensearchpy import OpenSearch
|
||||
|
||||
client = OpenSearch(opensearch_url, **kwargs)
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
||||
except ValueError as e:
|
||||
raise ImportError(
|
||||
f"OpenSearch client string provided is not in proper format. "
|
||||
@ -82,11 +43,14 @@ def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
|
||||
return client
|
||||
|
||||
|
||||
def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any:
|
||||
def _get_async_opensearch_client(opensearch_url: str, **kwargs: Any) -> AsyncOpenSearch:
|
||||
"""Get AsyncOpenSearch client from the opensearch_url, otherwise raise error."""
|
||||
try:
|
||||
async_opensearch = _import_async_opensearch()
|
||||
client = async_opensearch(opensearch_url, **kwargs)
|
||||
from opensearchpy import AsyncOpenSearch
|
||||
|
||||
client = AsyncOpenSearch(opensearch_url, **kwargs)
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
|
||||
except ValueError as e:
|
||||
raise ImportError(
|
||||
f"AsyncOpenSearch client string provided is not in proper format. "
|
||||
@ -127,7 +91,7 @@ def _is_aoss_enabled(http_auth: Any) -> bool:
|
||||
|
||||
|
||||
def _bulk_ingest_embeddings(
|
||||
client: Any,
|
||||
client: OpenSearch,
|
||||
index_name: str,
|
||||
embeddings: List[List[float]],
|
||||
texts: Iterable[str],
|
||||
@ -142,16 +106,19 @@ def _bulk_ingest_embeddings(
|
||||
"""Bulk Ingest Embeddings into given index."""
|
||||
if not mapping:
|
||||
mapping = dict()
|
||||
try:
|
||||
from opensearchpy.exceptions import NotFoundError
|
||||
from opensearchpy.helpers import bulk
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
||||
|
||||
bulk = _import_bulk()
|
||||
not_found_error = _import_not_found_error()
|
||||
requests = []
|
||||
return_ids = []
|
||||
mapping = mapping
|
||||
|
||||
try:
|
||||
client.indices.get(index=index_name)
|
||||
except not_found_error:
|
||||
except NotFoundError:
|
||||
client.indices.create(index=index_name, body=mapping)
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
@ -177,7 +144,7 @@ def _bulk_ingest_embeddings(
|
||||
|
||||
|
||||
async def _abulk_ingest_embeddings(
|
||||
client: Any,
|
||||
client: AsyncOpenSearch,
|
||||
index_name: str,
|
||||
embeddings: List[List[float]],
|
||||
texts: Iterable[str],
|
||||
@ -193,14 +160,18 @@ async def _abulk_ingest_embeddings(
|
||||
if not mapping:
|
||||
mapping = dict()
|
||||
|
||||
async_bulk = _import_async_bulk()
|
||||
not_found_error = _import_not_found_error()
|
||||
try:
|
||||
from opensearchpy.exceptions import NotFoundError
|
||||
from opensearchpy.helpers import async_bulk
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_ASYNC_OPENSEARCH_PY_ERROR)
|
||||
|
||||
requests = []
|
||||
return_ids = []
|
||||
|
||||
try:
|
||||
await client.indices.get(index=index_name)
|
||||
except not_found_error:
|
||||
except NotFoundError:
|
||||
await client.indices.create(index=index_name, body=mapping)
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
@ -230,7 +201,7 @@ async def _abulk_ingest_embeddings(
|
||||
def _default_scripting_text_mapping(
|
||||
dim: int,
|
||||
vector_field: str = "vector_field",
|
||||
) -> Dict:
|
||||
) -> Dict[str, Any]:
|
||||
"""For Painless Scripting or Script Scoring,the default mapping to create index."""
|
||||
return {
|
||||
"mappings": {
|
||||
@ -249,7 +220,7 @@ def _default_text_mapping(
|
||||
ef_construction: int = 512,
|
||||
m: int = 16,
|
||||
vector_field: str = "vector_field",
|
||||
) -> Dict:
|
||||
) -> Dict[str, Any]:
|
||||
"""For Approximate k-NN Search, this is the default mapping to create index."""
|
||||
return {
|
||||
"settings": {"index": {"knn": True, "knn.algo_param.ef_search": ef_search}},
|
||||
@ -275,7 +246,7 @@ def _default_approximate_search_query(
|
||||
k: int = 4,
|
||||
vector_field: str = "vector_field",
|
||||
score_threshold: Optional[float] = 0.0,
|
||||
) -> Dict:
|
||||
) -> Dict[str, Any]:
|
||||
"""For Approximate k-NN Search, this is the default query."""
|
||||
return {
|
||||
"size": k,
|
||||
@ -291,7 +262,7 @@ def _approximate_search_query_with_boolean_filter(
|
||||
vector_field: str = "vector_field",
|
||||
subquery_clause: str = "must",
|
||||
score_threshold: Optional[float] = 0.0,
|
||||
) -> Dict:
|
||||
) -> Dict[str, Any]:
|
||||
"""For Approximate k-NN Search, with Boolean Filter."""
|
||||
return {
|
||||
"size": k,
|
||||
@ -313,7 +284,7 @@ def _approximate_search_query_with_efficient_filter(
|
||||
k: int = 4,
|
||||
vector_field: str = "vector_field",
|
||||
score_threshold: Optional[float] = 0.0,
|
||||
) -> Dict:
|
||||
) -> Dict[str, Any]:
|
||||
"""For Approximate k-NN Search, with Efficient Filter for Lucene and
|
||||
Faiss Engines."""
|
||||
search_query = _default_approximate_search_query(
|
||||
@ -330,7 +301,7 @@ def _default_script_query(
|
||||
pre_filter: Optional[Dict] = None,
|
||||
vector_field: str = "vector_field",
|
||||
score_threshold: Optional[float] = 0.0,
|
||||
) -> Dict:
|
||||
) -> Dict[str, Any]:
|
||||
"""For Script Scoring Search, this is the default query."""
|
||||
|
||||
if not pre_filter:
|
||||
@ -376,7 +347,7 @@ def _default_painless_scripting_query(
|
||||
pre_filter: Optional[Dict] = None,
|
||||
vector_field: str = "vector_field",
|
||||
score_threshold: Optional[float] = 0.0,
|
||||
) -> Dict:
|
||||
) -> Dict[str, Any]:
|
||||
"""For Painless Scripting Search, this is the default query."""
|
||||
|
||||
if not pre_filter:
|
||||
@ -692,7 +663,10 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
refresh_indices: Whether to refresh the index
|
||||
after deleting documents. Defaults to True.
|
||||
"""
|
||||
bulk = _import_bulk()
|
||||
try:
|
||||
from opensearchpy.helpers import bulk
|
||||
except ImportError:
|
||||
raise ImportError(IMPORT_OPENSEARCH_PY_ERROR)
|
||||
|
||||
body = []
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user