Remove `_get_kwarg_value` function (#13184)

`_get_kwarg_value` function is useless, one can rely on python builtin
functionalities to do the exact same thing.

- **Description:** Removed `_get_kwarg_value`. Helps with code
readability.
  - **Issue:** the issue # it fixes (if applicable),
  - **Twitter handle:** @Guillem_96
pull/8920/head
Guillem Orellana Trullols 11 months ago committed by GitHub
parent e1c020dfe1
commit 0f31cd8b49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -306,13 +306,6 @@ def _default_painless_scripting_query(
}
def _get_kwargs_value(kwargs: Any, key: str, default_value: Any) -> Any:
"""Get the value of the key if present. Else get the default_value."""
if key in kwargs:
return kwargs.get(key)
return default_value
class OpenSearchVectorSearch(VectorStore):
"""`Amazon OpenSearch Vector Engine` vector store.
@ -338,10 +331,10 @@ class OpenSearchVectorSearch(VectorStore):
"""Initialize with necessary components."""
self.embedding_function = embedding_function
self.index_name = index_name
http_auth = _get_kwargs_value(kwargs, "http_auth", None)
http_auth = kwargs.get("http_auth")
self.is_aoss = _is_aoss_enabled(http_auth=http_auth)
self.client = _get_opensearch_client(opensearch_url, **kwargs)
self.engine = _get_kwargs_value(kwargs, "engine", None)
self.engine = kwargs.get("engine")
@property
def embeddings(self) -> Embeddings:
@ -357,16 +350,16 @@ class OpenSearchVectorSearch(VectorStore):
**kwargs: Any,
) -> List[str]:
_validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
index_name = _get_kwargs_value(kwargs, "index_name", self.index_name)
text_field = _get_kwargs_value(kwargs, "text_field", "text")
index_name = kwargs.get("index_name", self.index_name)
text_field = kwargs.get("text_field", "text")
dim = len(embeddings[0])
engine = _get_kwargs_value(kwargs, "engine", "nmslib")
space_type = _get_kwargs_value(kwargs, "space_type", "l2")
ef_search = _get_kwargs_value(kwargs, "ef_search", 512)
ef_construction = _get_kwargs_value(kwargs, "ef_construction", 512)
m = _get_kwargs_value(kwargs, "m", 16)
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
max_chunk_bytes = _get_kwargs_value(kwargs, "max_chunk_bytes", 1 * 1024 * 1024)
engine = kwargs.get("engine", "nmslib")
space_type = kwargs.get("space_type", "l2")
ef_search = kwargs.get("ef_search", 512)
ef_construction = kwargs.get("ef_construction", 512)
m = kwargs.get("m", 16)
vector_field = kwargs.get("vector_field", "vector_field")
max_chunk_bytes = kwargs.get("max_chunk_bytes", 1 * 1024 * 1024)
_validate_aoss_with_engines(self.is_aoss, engine)
@ -542,8 +535,8 @@ class OpenSearchVectorSearch(VectorStore):
same as `similarity_search`
"""
text_field = _get_kwargs_value(kwargs, "text_field", "text")
metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
text_field = kwargs.get("text_field", "text")
metadata_field = kwargs.get("metadata_field", "metadata")
hits = self._raw_similarity_search_with_score(query=query, k=k, **kwargs)
@ -581,10 +574,10 @@ class OpenSearchVectorSearch(VectorStore):
same as `similarity_search`
"""
embedding = self.embedding_function.embed_query(query)
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
index_name = _get_kwargs_value(kwargs, "index_name", self.index_name)
filter = _get_kwargs_value(kwargs, "filter", {})
search_type = kwargs.get("search_type", "approximate_search")
vector_field = kwargs.get("vector_field", "vector_field")
index_name = kwargs.get("index_name", self.index_name)
filter = kwargs.get("filter", {})
if (
self.is_aoss
@ -597,11 +590,11 @@ class OpenSearchVectorSearch(VectorStore):
)
if search_type == "approximate_search":
boolean_filter = _get_kwargs_value(kwargs, "boolean_filter", {})
subquery_clause = _get_kwargs_value(kwargs, "subquery_clause", "must")
efficient_filter = _get_kwargs_value(kwargs, "efficient_filter", {})
boolean_filter = kwargs.get("boolean_filter", {})
subquery_clause = kwargs.get("subquery_clause", "must")
efficient_filter = kwargs.get("efficient_filter", {})
# `lucene_filter` is deprecated, added for Backwards Compatibility
lucene_filter = _get_kwargs_value(kwargs, "lucene_filter", {})
lucene_filter = kwargs.get("lucene_filter", {})
if boolean_filter != {} and efficient_filter != {}:
raise ValueError(
@ -657,14 +650,14 @@ class OpenSearchVectorSearch(VectorStore):
embedding, k=k, vector_field=vector_field
)
elif search_type == SCRIPT_SCORING_SEARCH:
space_type = _get_kwargs_value(kwargs, "space_type", "l2")
pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY)
space_type = kwargs.get("space_type", "l2")
pre_filter = kwargs.get("pre_filter", MATCH_ALL_QUERY)
search_query = _default_script_query(
embedding, k, space_type, pre_filter, vector_field
)
elif search_type == PAINLESS_SCRIPTING_SEARCH:
space_type = _get_kwargs_value(kwargs, "space_type", "l2Squared")
pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY)
space_type = kwargs.get("space_type", "l2Squared")
pre_filter = kwargs.get("pre_filter", MATCH_ALL_QUERY)
search_query = _default_painless_scripting_query(
embedding, k, space_type, pre_filter, vector_field
)
@ -701,9 +694,9 @@ class OpenSearchVectorSearch(VectorStore):
List of Documents selected by maximal marginal relevance.
"""
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
text_field = _get_kwargs_value(kwargs, "text_field", "text")
metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
vector_field = kwargs.get("vector_field", "vector_field")
text_field = kwargs.get("text_field", "text")
metadata_field = kwargs.get("metadata_field", "metadata")
# Get embedding of the user query
embedding = self.embedding_function.embed_query(query)
@ -874,11 +867,11 @@ class OpenSearchVectorSearch(VectorStore):
index_name = get_from_dict_or_env(
kwargs, "index_name", "OPENSEARCH_INDEX_NAME", default=uuid.uuid4().hex
)
is_appx_search = _get_kwargs_value(kwargs, "is_appx_search", True)
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
text_field = _get_kwargs_value(kwargs, "text_field", "text")
max_chunk_bytes = _get_kwargs_value(kwargs, "max_chunk_bytes", 1 * 1024 * 1024)
http_auth = _get_kwargs_value(kwargs, "http_auth", None)
is_appx_search = kwargs.get("is_appx_search", True)
vector_field = kwargs.get("vector_field", "vector_field")
text_field = kwargs.get("text_field", "text")
max_chunk_bytes = kwargs.get("max_chunk_bytes", 1 * 1024 * 1024)
http_auth = kwargs.get("http_auth")
is_aoss = _is_aoss_enabled(http_auth=http_auth)
engine = None
@ -889,11 +882,11 @@ class OpenSearchVectorSearch(VectorStore):
)
if is_appx_search:
engine = _get_kwargs_value(kwargs, "engine", "nmslib")
space_type = _get_kwargs_value(kwargs, "space_type", "l2")
ef_search = _get_kwargs_value(kwargs, "ef_search", 512)
ef_construction = _get_kwargs_value(kwargs, "ef_construction", 512)
m = _get_kwargs_value(kwargs, "m", 16)
engine = kwargs.get("engine", "nmslib")
space_type = kwargs.get("space_type", "l2")
ef_search = kwargs.get("ef_search", 512)
ef_construction = kwargs.get("ef_construction", 512)
m = kwargs.get("m", 16)
_validate_aoss_with_engines(is_aoss, engine)

Loading…
Cancel
Save