From 3e835a1aa115bf98e87d6dc705258cb190dd1dba Mon Sep 17 00:00:00 2001 From: shimajiroxyz Date: Tue, 18 Jun 2024 12:29:17 +0900 Subject: [PATCH] langchain: add id_key option to EnsembleRetriever for metadata-based document merging (#22950) **Description:** - What I changed - When `id_key` is specified at initialization of `EnsembleRetriever`, documents are matched for score merging and deduplication by the metadata value stored under `id_key` instead of by `page_content`. Below is an example of how to use the modified `EnsembleRetriever`: ```python retriever = EnsembleRetriever(retrievers=[ret1, ret2], id_key="id") # Every Document returned by each retriever must have the "id" key in its metadata. ``` - Additionally, I added a unit test (`tests/unit_tests/retrievers/test_ensemble.py`) that checks the behavior of the `invoke` method of the modified `EnsembleRetriever`. - Why I changed - There are cases where you may want to treat Documents with different `page_content` as the same document when calculating scores with `EnsembleRetriever`, for example when ensembling search results for the same document written in two different languages. - The previous `EnsembleRetriever` always used `page_content` as the basis for score aggregation, which made the above usage difficult. Therefore, when `id_key` is set, the score is now aggregated by the value of that key in the Document's metadata; when `id_key` is left as `None`, `page_content` is still used. 
**Twitter handle:** @shimajiroxyz --- .../langchain/retrievers/ensemble.py | 20 ++++- .../unit_tests/retrievers/test_ensemble.py | 88 +++++++++++++++++++ 2 files changed, 105 insertions(+), 3 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/retrievers/test_ensemble.py diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py index 61f2bfdcb1..0eb779544d 100644 --- a/libs/langchain/langchain/retrievers/ensemble.py +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -66,11 +66,14 @@ class EnsembleRetriever(BaseRetriever): c: A constant added to the rank, controlling the balance between the importance of high-ranked items and the consideration given to lower-ranked items. Default is 60. + id_key: The key in the document's metadata used to determine unique documents. + If not specified, page_content is used. """ retrievers: List[RetrieverLike] weights: List[float] c: int = 60 + id_key: Optional[str] = None @property def config_specs(self) -> List[ConfigurableFieldSpec]: @@ -305,13 +308,24 @@ class EnsembleRetriever(BaseRetriever): rrf_score: Dict[str, float] = defaultdict(float) for doc_list, weight in zip(doc_lists, self.weights): for rank, doc in enumerate(doc_list, start=1): - rrf_score[doc.page_content] += weight / (rank + self.c) + rrf_score[ + doc.page_content + if self.id_key is None + else doc.metadata[self.id_key] + ] += weight / (rank + self.c) # Docs are deduplicated by their contents then sorted by their scores all_docs = chain.from_iterable(doc_lists) sorted_docs = sorted( - unique_by_key(all_docs, lambda doc: doc.page_content), + unique_by_key( + all_docs, + lambda doc: doc.page_content + if self.id_key is None + else doc.metadata[self.id_key], + ), reverse=True, - key=lambda doc: rrf_score[doc.page_content], + key=lambda doc: rrf_score[ + doc.page_content if self.id_key is None else doc.metadata[self.id_key] + ], ) return sorted_docs diff --git 
a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py new file mode 100644 index 0000000000..4c5e9837c0 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py @@ -0,0 +1,88 @@ +from typing import List, Optional + +from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever + +from langchain.retrievers.ensemble import EnsembleRetriever + + +class MockRetriever(BaseRetriever): + docs: List[Document] + + def _get_relevant_documents( + self, + query: str, + *, + run_manager: Optional[CallbackManagerForRetrieverRun] = None, + ) -> List[Document]: + """Return the documents""" + return self.docs + + +def test_invoke() -> None: + documents1 = [ + Document(page_content="a", metadata={"id": 1}), + Document(page_content="b", metadata={"id": 2}), + Document(page_content="c", metadata={"id": 3}), + ] + documents2 = [Document(page_content="b")] + + retriever1 = MockRetriever(docs=documents1) + retriever2 = MockRetriever(docs=documents2) + + ensemble_retriever = EnsembleRetriever( + retrievers=[retriever1, retriever2], weights=[0.5, 0.5], id_key=None + ) + ranked_documents = ensemble_retriever.invoke("_") + + # The document with page_content "b" in documents2 + # will be merged with the document with page_content "b" + # in documents1, so the length of ranked_documents should be 3. + # Additionally, the document with page_content "b" will be ranked 1st. 
+ assert len(ranked_documents) == 3 + assert ranked_documents[0].page_content == "b" + + documents1 = [ + Document(page_content="a", metadata={"id": 1}), + Document(page_content="b", metadata={"id": 2}), + Document(page_content="c", metadata={"id": 3}), + ] + documents2 = [Document(page_content="d")] + + retriever1 = MockRetriever(docs=documents1) + retriever2 = MockRetriever(docs=documents2) + + ensemble_retriever = EnsembleRetriever( + retrievers=[retriever1, retriever2], weights=[0.5, 0.5], id_key=None + ) + ranked_documents = ensemble_retriever.invoke("_") + + # The document with page_content "d" in documents2 will not be merged + # with any document in documents1, so the length of ranked_documents + # should be 4. The document with page_content "a" and the document + # with page_content "d" will have the same score, but the document + # with page_content "a" will be ranked 1st because retriever1 has a smaller index. + assert len(ranked_documents) == 4 + assert ranked_documents[0].page_content == "a" + + documents1 = [ + Document(page_content="a", metadata={"id": 1}), + Document(page_content="b", metadata={"id": 2}), + Document(page_content="c", metadata={"id": 3}), + ] + documents2 = [Document(page_content="d", metadata={"id": 2})] + + retriever1 = MockRetriever(docs=documents1) + retriever2 = MockRetriever(docs=documents2) + + ensemble_retriever = EnsembleRetriever( + retrievers=[retriever1, retriever2], weights=[0.5, 0.5], id_key="id" + ) + ranked_documents = ensemble_retriever.invoke("_") + + # Since id_key is specified, the document with id 2 will be merged. + # Therefore, the length of ranked_documents should be 3. + # Additionally, the document with page_content "b" will be ranked 1st. + assert len(ranked_documents) == 3 + assert ranked_documents[0].page_content == "b"