added rrf argument in ApproxRetrievalStrategy class __init__() (#11987)

- **Description: To handle the hybrid search with RRF(Reciprocal Rank
Fusion) in the Elasticsearch, rrf argument was added for adjusting
'rank_constant' and 'window_size' to combine multiple result sets with
different relevance indicators into a single result set. (ref:
https://www.elastic.co/kr/blog/whats-new-elastic-enterprise-search-8-9-0),
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** No dependencies changed,
  - **Tag maintainer:** @baskaryan,

Nice to meet you,
I'm a newbie for contributions and it's my first PR.

I only changed the langchain/vectorstores/elasticsearch.py file.
I did make format&lint 
I got this message,
```shell
make lint_diff  
./scripts/check_pydantic.sh .
./scripts/check_imports.sh
poetry run ruff .
[ "langchain/vectorstores/elasticsearch.py" = "" ] || poetry run black langchain/vectorstores/elasticsearch.py --check
All done!  🍰 
1 file would be left unchanged.
[ "langchain/vectorstores/elasticsearch.py" = "" ] || poetry run mypy langchain/vectorstores/elasticsearch.py
langchain/__init__.py: error: Source file found twice under different module names: "mvp.nlp.langchain.libs.langchain.langchain" and "langchain"
Found 1 error in 1 file (errors prevented further checking)
make: *** [lint_diff] Error 2
```

Thank you

---------

Co-authored-by: 황중원 <jwhwang@amorepacific.com>
pull/12452/head
HwangJohn 11 months ago committed by GitHub
parent 2c58dca5f0
commit d38c8369b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -117,10 +117,16 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
self,
query_model_id: Optional[str] = None,
hybrid: Optional[bool] = False,
rrf: Optional[Union[dict, bool]] = True,
):
self.query_model_id = query_model_id
self.hybrid = hybrid
# RRF has two optional parameters
# 'rank_constant', 'window_size'
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
self.rrf = rrf
def query(
self,
query_vector: Union[List[float], None],
@ -161,8 +167,10 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
# If hybrid, add a query to the knn query
# RRF is used to even the score from the knn query and text query
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
# https://www.elastic.co/guide/en/elasticsearch/reference/current/rrf.html
if self.hybrid:
return {
query_body = {
"knn": knn,
"query": {
"bool": {
@ -178,8 +186,14 @@ class ApproxRetrievalStrategy(BaseRetrievalStrategy):
"filter": filter,
}
},
"rank": {"rrf": {}},
}
if isinstance(self.rrf, dict):
query_body["rank"] = {"rrf": self.rrf}
elif isinstance(self.rrf, bool) and self.rrf is True:
query_body["rank"] = {"rrf": {}}
return query_body
else:
return {"knn": knn}
@ -587,6 +601,7 @@ class ElasticsearchStore(VectorStore):
self,
query: str,
k: int = 4,
fetch_k: int = 50,
filter: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[Document]:
@ -595,6 +610,7 @@ class ElasticsearchStore(VectorStore):
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k (int): Number of Documents to fetch to pass to knn num_candidates.
filter: Array of Elasticsearch filter clauses to apply to the query.
Returns:
@ -602,7 +618,9 @@ class ElasticsearchStore(VectorStore):
in descending order of similarity.
"""
results = self._search(query=query, k=k, filter=filter, **kwargs)
results = self._search(
query=query, k=k, fetch_k=fetch_k, filter=filter, **kwargs
)
return [doc for doc, _ in results]
def max_marginal_relevance_search(
@ -1187,6 +1205,7 @@ class ElasticsearchStore(VectorStore):
def ApproxRetrievalStrategy(
query_model_id: Optional[str] = None,
hybrid: Optional[bool] = False,
rrf: Optional[Union[dict, bool]] = True,
) -> "ApproxRetrievalStrategy":
"""Used to perform approximate nearest neighbor search
using the HNSW algorithm.
@ -1209,8 +1228,16 @@ class ElasticsearchStore(VectorStore):
hybrid: Optional. If True, will perform a hybrid search
using both the knn query and a text query.
Defaults to False.
rrf: Optional. rrf is Reciprocal Rank Fusion.
When `hybrid` is True,
and `rrf` is True, then rrf: {}.
and `rrf` is False, then rrf is omitted.
and isinstance(rrf, dict) is True, then pass in the dict values.
rrf could be passed for adjusting 'rank_constant' and 'window_size'.
"""
return ApproxRetrievalStrategy(query_model_id=query_model_id, hybrid=hybrid)
return ApproxRetrievalStrategy(
query_model_id=query_model_id, hybrid=hybrid, rrf=rrf
)
@staticmethod
def SparseVectorRetrievalStrategy(

@ -481,6 +481,115 @@ class TestElasticsearch:
output = docsearch.similarity_search("foo", k=1, custom_query=assert_query)
assert output == [Document(page_content="foo")]
def test_similarity_search_approx_with_hybrid_search_rrf(
self, es_client: Any, elasticsearch_connection: dict, index_name: str
) -> None:
"""Test end to end construction and rrf hybrid search with metadata."""
from functools import partial
from typing import Optional
# 1. check query_body is okay
rrf_test_cases: List[Optional[Union[dict, bool]]] = [
True,
False,
{"rank_constant": 1, "window_size": 5},
]
for rrf_test_case in rrf_test_cases:
texts = ["foo", "bar", "baz"]
docsearch = ElasticsearchStore.from_texts(
texts,
FakeEmbeddings(),
**elasticsearch_connection,
index_name=index_name,
strategy=ElasticsearchStore.ApproxRetrievalStrategy(
hybrid=True, rrf=rrf_test_case
),
)
def assert_query(
query_body: dict,
query: str,
rrf: Optional[Union[dict, bool]] = True,
) -> dict:
cmp_query_body = {
"knn": {
"field": "vector",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
0.0,
],
},
"query": {
"bool": {
"filter": [],
"must": [{"match": {"text": {"query": "foo"}}}],
}
},
}
if isinstance(rrf, dict):
cmp_query_body["rank"] = {"rrf": rrf}
elif isinstance(rrf, bool) and rrf is True:
cmp_query_body["rank"] = {"rrf": {}}
assert query_body == cmp_query_body
return query_body
## without fetch_k parameter
output = docsearch.similarity_search(
"foo", k=3, custom_query=partial(assert_query, rrf=rrf_test_case)
)
# 2. check query result is okay
es_output = es_client.search(
index=index_name,
query={
"bool": {
"filter": [],
"must": [{"match": {"text": {"query": "foo"}}}],
}
},
knn={
"field": "vector",
"filter": [],
"k": 3,
"num_candidates": 50,
"query_vector": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
},
size=3,
rank={"rrf": {"rank_constant": 1, "window_size": 5}},
)
assert [o.page_content for o in output] == [
e["_source"]["text"] for e in es_output["hits"]["hits"]
]
# 3. check rrf default option is okay
docsearch = ElasticsearchStore.from_texts(
texts,
FakeEmbeddings(),
**elasticsearch_connection,
index_name=index_name,
strategy=ElasticsearchStore.ApproxRetrievalStrategy(hybrid=True),
)
## with fetch_k parameter
output = docsearch.similarity_search(
"foo", k=3, fetch_k=50, custom_query=assert_query
)
def test_similarity_search_approx_with_custom_query_fn(
self, elasticsearch_connection: dict, index_name: str
) -> None:

Loading…
Cancel
Save