diff --git a/libs/langchain/langchain/vectorstores/opensearch_vector_search.py b/libs/langchain/langchain/vectorstores/opensearch_vector_search.py index fe3cb45ad7..1e688a5c65 100644 --- a/libs/langchain/langchain/vectorstores/opensearch_vector_search.py +++ b/libs/langchain/langchain/vectorstores/opensearch_vector_search.py @@ -235,6 +235,7 @@ def _approximate_search_query_with_efficient_filter( def _default_script_query( query_vector: List[float], + k: int = 4, space_type: str = "l2", pre_filter: Optional[Dict] = None, vector_field: str = "vector_field", @@ -245,6 +246,7 @@ def _default_script_query( pre_filter = MATCH_ALL_QUERY return { + "size": k, "query": { "script_score": { "query": pre_filter, @@ -258,7 +260,7 @@ def _default_script_query( }, }, } - } + }, } @@ -283,6 +285,7 @@ def __get_painless_scripting_source( def _default_painless_scripting_query( query_vector: List[float], + k: int = 4, space_type: str = "l2Squared", pre_filter: Optional[Dict] = None, vector_field: str = "vector_field", @@ -292,8 +295,11 @@ def _default_painless_scripting_query( if not pre_filter: pre_filter = MATCH_ALL_QUERY - source = __get_painless_scripting_source(space_type, query_vector) + source = __get_painless_scripting_source( + space_type, query_vector, vector_field=vector_field + ) return { + "size": k, "query": { "script_score": { "query": pre_filter, @@ -305,7 +311,7 @@ def _default_painless_scripting_query( }, }, } - } + }, } @@ -593,20 +599,20 @@ class OpenSearchVectorSearch(VectorStore): space_type = _get_kwargs_value(kwargs, "space_type", "l2") pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY) search_query = _default_script_query( - embedding, space_type, pre_filter, vector_field + embedding, k, space_type, pre_filter, vector_field ) elif search_type == PAINLESS_SCRIPTING_SEARCH: space_type = _get_kwargs_value(kwargs, "space_type", "l2Squared") pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY) search_query = _default_painless_scripting_query( - embedding, space_type, pre_filter, vector_field + embedding, k, space_type, pre_filter, vector_field ) else: raise ValueError("Invalid `search_type` provided as an argument") response = self.client.search(index=self.index_name, body=search_query) - return [hit for hit in response["hits"]["hits"][:k]] + return [hit for hit in response["hits"]["hits"]] def max_marginal_relevance_search( self,