langchain[patch]: Simplify ensemble retriever (#14427)

- **Description:** code simplification to improve readability and remove
unnecessary memory allocations.
  - **Tag maintainer**: @baskaryan, @eyurtsev, @hwchase17.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Ahmed Moubtahij 2024-03-29 19:49:49 -04:00 committed by GitHub
parent b36f4147b0
commit f5d4ce840f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,20 @@ Ensemble retriever that ensembles the results of
multiple retrievers by using weighted Reciprocal Rank Fusion
"""
import asyncio
from typing import Any, Dict, List, Optional, cast
from collections import defaultdict
from collections.abc import Hashable
from itertools import chain
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
TypeVar,
cast,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForRetrieverRun,
@ -20,6 +33,17 @@ from langchain_core.runnables.utils import (
get_unique_config_specs,
)
T = TypeVar("T")
H = TypeVar("H", bound=Hashable)


def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
    """Yield the elements of ``iterable`` whose key has not been seen before.

    Elements are yielded lazily and in their original order. When several
    elements map to the same key, only the first one encountered is yielded.

    Args:
        iterable: The elements to deduplicate.
        key: Function computing the deduplication key of an element. The
            returned key must be hashable because it is stored in a set.

    Yields:
        Each element whose key differs from the keys of all elements that
        came before it.
    """
    # Only the keys are remembered, not the elements themselves.
    seen = set()
    for e in iterable:
        if (k := key(e)) not in seen:
            seen.add(k)
            yield e
class EnsembleRetriever(BaseRetriever):
"""Retriever that ensembles the multiple retrievers.
@ -267,32 +291,18 @@ class EnsembleRetriever(BaseRetriever):
"Number of rank lists must be equal to the number of weights."
)
# Create a union of all unique documents in the input doc_lists
all_documents = set()
for doc_list in doc_lists:
for doc in doc_list:
all_documents.add(doc.page_content)
# Initialize the RRF score dictionary for each document
rrf_score_dic = {doc: 0.0 for doc in all_documents}
# Calculate RRF scores for each document
# Associate each doc's content with its RRF score for later sorting by it
# Duplicated contents across retrievers are collapsed & scored cumulatively
rrf_score: Dict[str, float] = defaultdict(float)
for doc_list, weight in zip(doc_lists, self.weights):
for rank, doc in enumerate(doc_list, start=1):
rrf_score = weight * (1 / (rank + self.c))
rrf_score_dic[doc.page_content] += rrf_score
rrf_score[doc.page_content] += weight / (rank + self.c)
# Sort documents by their RRF scores in descending order
sorted_documents = sorted(
rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True
# Docs are deduplicated by their contents then sorted by their scores
all_docs = chain.from_iterable(doc_lists)
sorted_docs = sorted(
unique_by_key(all_docs, lambda doc: doc.page_content),
reverse=True,
key=lambda doc: rrf_score[doc.page_content],
)
# Map the sorted page_content back to the original document objects
page_content_to_doc_map = {
doc.page_content: doc for doc_list in doc_lists for doc in doc_list
}
sorted_docs = [
page_content_to_doc_map[page_content] for page_content in sorted_documents
]
return sorted_docs