mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Issue 8081 Fix query results size bug. Other bug: pass vector_field param. (#8085)
@baskaryan #8081 Likely the reason why the issue occurred is that OpenSearch's default k is 10, so it needs to be specified. Here's a similar question about its cousin ElasticSearch https://discuss.elastic.co/t/elasticsearch-returns-only-10-records-but-the-hit-is-507/136605 I tested this manually and also fixed the same issue in `_default_painless_scripting_query`. In addition, `_default_painless_scripting_query` was not passing the `vector_field` name to a sub call, so I fixed that too. ![image](https://github.com/hwchase17/langchain/assets/32244272/cfb7aad1-f701-49d9-9beb-a723aa276817) I also tested this in the aws opensearch developer tools. ![image](https://github.com/hwchase17/langchain/assets/32244272/24544682-1578-4bbb-9eb5-980463c5b41b) --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
812419d946
commit
c7ea6e9ff8
@ -235,6 +235,7 @@ def _approximate_search_query_with_efficient_filter(
|
|||||||
|
|
||||||
def _default_script_query(
|
def _default_script_query(
|
||||||
query_vector: List[float],
|
query_vector: List[float],
|
||||||
|
k: int = 4,
|
||||||
space_type: str = "l2",
|
space_type: str = "l2",
|
||||||
pre_filter: Optional[Dict] = None,
|
pre_filter: Optional[Dict] = None,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
@ -245,6 +246,7 @@ def _default_script_query(
|
|||||||
pre_filter = MATCH_ALL_QUERY
|
pre_filter = MATCH_ALL_QUERY
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"size": k,
|
||||||
"query": {
|
"query": {
|
||||||
"script_score": {
|
"script_score": {
|
||||||
"query": pre_filter,
|
"query": pre_filter,
|
||||||
@ -258,7 +260,7 @@ def _default_script_query(
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -283,6 +285,7 @@ def __get_painless_scripting_source(
|
|||||||
|
|
||||||
def _default_painless_scripting_query(
|
def _default_painless_scripting_query(
|
||||||
query_vector: List[float],
|
query_vector: List[float],
|
||||||
|
k: int = 4,
|
||||||
space_type: str = "l2Squared",
|
space_type: str = "l2Squared",
|
||||||
pre_filter: Optional[Dict] = None,
|
pre_filter: Optional[Dict] = None,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
@ -292,8 +295,11 @@ def _default_painless_scripting_query(
|
|||||||
if not pre_filter:
|
if not pre_filter:
|
||||||
pre_filter = MATCH_ALL_QUERY
|
pre_filter = MATCH_ALL_QUERY
|
||||||
|
|
||||||
source = __get_painless_scripting_source(space_type, query_vector)
|
source = __get_painless_scripting_source(
|
||||||
|
space_type, query_vector, vector_field=vector_field
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
|
"size": k,
|
||||||
"query": {
|
"query": {
|
||||||
"script_score": {
|
"script_score": {
|
||||||
"query": pre_filter,
|
"query": pre_filter,
|
||||||
@ -305,7 +311,7 @@ def _default_painless_scripting_query(
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -593,20 +599,20 @@ class OpenSearchVectorSearch(VectorStore):
|
|||||||
space_type = _get_kwargs_value(kwargs, "space_type", "l2")
|
space_type = _get_kwargs_value(kwargs, "space_type", "l2")
|
||||||
pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY)
|
pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY)
|
||||||
search_query = _default_script_query(
|
search_query = _default_script_query(
|
||||||
embedding, space_type, pre_filter, vector_field
|
embedding, k, space_type, pre_filter, vector_field
|
||||||
)
|
)
|
||||||
elif search_type == PAINLESS_SCRIPTING_SEARCH:
|
elif search_type == PAINLESS_SCRIPTING_SEARCH:
|
||||||
space_type = _get_kwargs_value(kwargs, "space_type", "l2Squared")
|
space_type = _get_kwargs_value(kwargs, "space_type", "l2Squared")
|
||||||
pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY)
|
pre_filter = _get_kwargs_value(kwargs, "pre_filter", MATCH_ALL_QUERY)
|
||||||
search_query = _default_painless_scripting_query(
|
search_query = _default_painless_scripting_query(
|
||||||
embedding, space_type, pre_filter, vector_field
|
embedding, k, space_type, pre_filter, vector_field
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid `search_type` provided as an argument")
|
raise ValueError("Invalid `search_type` provided as an argument")
|
||||||
|
|
||||||
response = self.client.search(index=self.index_name, body=search_query)
|
response = self.client.search(index=self.index_name, body=search_query)
|
||||||
|
|
||||||
return [hit for hit in response["hits"]["hits"][:k]]
|
return [hit for hit in response["hits"]["hits"]]
|
||||||
|
|
||||||
def max_marginal_relevance_search(
|
def max_marginal_relevance_search(
|
||||||
self,
|
self,
|
||||||
|
Loading…
Reference in New Issue
Block a user