Issue 8081 Fix query results size bug. Other bug: pass vector_field param. (#8085)

@baskaryan
#8081 

Likely the reason why the issue occurred is that OpenSearch's default k
is 10, so it needs to be specified.

Here's a similar question about its cousin ElasticSearch

https://discuss.elastic.co/t/elasticsearch-returns-only-10-records-but-the-hit-is-507/136605

I tested this manually and also fixed the same issue in
`_default_painless_scripting_query`. In addition,
`_default_painless_scripting_query` was not passing the `vector_field`
name to a sub call, so I fixed that too.


![image](https://github.com/hwchase17/langchain/assets/32244272/cfb7aad1-f701-49d9-9beb-a723aa276817)

I also tested this in the aws opensearch developer tools.


![image](https://github.com/hwchase17/langchain/assets/32244272/24544682-1578-4bbb-9eb5-980463c5b41b)

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
aerickson-clt 2023-08-04 01:41:11 -04:00 committed by GitHub
parent 812419d946
commit c7ea6e9ff8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -235,6 +235,7 @@ def _approximate_search_query_with_efficient_filter(
def _default_script_query( def _default_script_query(
query_vector: List[float], query_vector: List[float],
k: int = 4,
space_type: str = "l2", space_type: str = "l2",
pre_filter: Optional[Dict] = None, pre_filter: Optional[Dict] = None,
vector_field: str = "vector_field", vector_field: str = "vector_field",
@ -245,6 +246,7 @@ def _default_script_query(
pre_filter = MATCH_ALL_QUERY pre_filter = MATCH_ALL_QUERY
return { return {
"size": k,
"query": { "query": {
"script_score": { "script_score": {
"query": pre_filter, "query": pre_filter,
@ -258,7 +260,7 @@ def _default_script_query(
}, },
}, },
} }
} },
} }
@ -283,6 +285,7 @@ def __get_painless_scripting_source(
def _default_painless_scripting_query( def _default_painless_scripting_query(
query_vector: List[float], query_vector: List[float],
k: int = 4,
space_type: str = "l2Squared", space_type: str = "l2Squared",
pre_filter: Optional[Dict] = None, pre_filter: Optional[Dict] = None,
vector_field: str = "vector_field", vector_field: str = "vector_field",
@ -292,8 +295,11 @@ def _default_painless_scripting_query(
if not pre_filter: if not pre_filter:
pre_filter = MATCH_ALL_QUERY pre_filter = MATCH_ALL_QUERY
source = __get_painless_scripting_source(space_type, query_vector) source = __get_painless_scripting_source(
space_type, query_vector, vector_field=vector_field
)
return { return {
"size": k,
"query": { "query": {
"script_score": { "script_score": {
"query": pre_filter, "query": pre_filter,
@ -305,7 +311,7 @@ def _default_painless_scripting_query(
}, },
}, },
} }
} },
} }
@ -593,20 +599,20 @@ class OpenSearchVectorSearch(VectorStore):
space_type = _get_kwargs_value(kwargs, "space_type", "l2") space_type = _get_kwargs_value(kwargs, "space_type", "l2")
pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY) pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY)
search_query = _default_script_query( search_query = _default_script_query(
embedding, space_type, pre_filter, vector_field embedding, k, space_type, pre_filter, vector_field
) )
elif search_type == PAINLESS_SCRIPTING_SEARCH: elif search_type == PAINLESS_SCRIPTING_SEARCH:
space_type = _get_kwargs_value(kwargs, "space_type", "l2Squared") space_type = _get_kwargs_value(kwargs, "space_type", "l2Squared")
pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY) pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY)
search_query = _default_painless_scripting_query( search_query = _default_painless_scripting_query(
embedding, space_type, pre_filter, vector_field embedding, k, space_type, pre_filter, vector_field
) )
else: else:
raise ValueError("Invalid `search_type` provided as an argument") raise ValueError("Invalid `search_type` provided as an argument")
response = self.client.search(index=self.index_name, body=search_query) response = self.client.search(index=self.index_name, body=search_query)
return [hit for hit in response["hits"]["hits"][:k]] return [hit for hit in response["hits"]["hits"]]
def max_marginal_relevance_search( def max_marginal_relevance_search(
self, self,