diff --git a/docs/extras/modules/data_connection/retrievers/ensemble.ipynb b/docs/extras/modules/data_connection/retrievers/ensemble.ipynb new file mode 100644 index 0000000000..f06349c7e5 --- /dev/null +++ b/docs/extras/modules/data_connection/retrievers/ensemble.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ensemble Retriever\n", + "\n", + "The `EnsembleRetriever` takes a list of retrievers as input and ensemble the results of their get_relevant_documents() methods and rerank the results based on the [Reciprocal Rank Fusion](https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf) algorithm.\n", + "\n", + "By leveraging the strengths of different algorithms, the `EnsembleRetriever` can achieve better performance than any single algorithm. \n", + "\n", + "The most common pattern is to combine a sparse retriever(like BM25) with a dense retriever(like Embedding similarity), because their strengths are complementary. It is also known as \"hybrid search\".The sparse retriever is good at finding relevant documents based on keywords, while the dense retriever is good at finding relevant documents based on semantic similarity." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers import BM25Retriever, EnsembleRetriever\n", + "from langchain.vectorstores import FAISS" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "doc_list = [\n", + " \"I like apples\",\n", + " \"I like oranges\",\n", + " \"Apples and oranges are fruits\",\n", + "]\n", + "\n", + "# initialize the bm25 retriever and faiss retriever\n", + "bm25_retriever = BM25Retriever.from_texts(doc_list)\n", + "bm25_retriever.k = 2\n", + "\n", + "embedding = OpenAIEmbeddings()\n", + "faiss_vectorstore = FAISS.from_texts(doc_list, embedding)\n", + "faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={\"k\": 2})\n", + "\n", + "# initialize the ensemble retriever\n", + "ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever], weights=[0.5, 0.5])" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='I like apples', metadata={}),\n", + " Document(page_content='Apples and oranges are fruits', metadata={})]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "docs = ensemble_retriever.get_relevant_documents(\"apples\")\n", + "docs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/langchain/langchain/retrievers/__init__.py b/libs/langchain/langchain/retrievers/__init__.py index 360a111e38..dc4cac35a5 100644 --- a/libs/langchain/langchain/retrievers/__init__.py +++ b/libs/langchain/langchain/retrievers/__init__.py @@ -6,6 +6,7 @@ from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.docarray import DocArrayRetriever from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever +from langchain.retrievers.ensemble import EnsembleRetriever from langchain.retrievers.google_cloud_enterprise_search import ( GoogleCloudEnterpriseSearchRetriever, ) @@ -64,4 +65,5 @@ __all__ = [ "ZepRetriever", "ZillizRetriever", "DocArrayRetriever", + "EnsembleRetriever", ] diff --git a/libs/langchain/langchain/retrievers/ensemble.py b/libs/langchain/langchain/retrievers/ensemble.py new file mode 100644 index 0000000000..b01a33fab6 --- /dev/null +++ b/libs/langchain/langchain/retrievers/ensemble.py @@ -0,0 +1,184 @@ +""" +Ensemble retriever that ensemble the results of +multiple retrievers by using weighted Reciprocal Rank Fusion +""" +from typing import Any, Dict, List + +from pydantic import root_validator + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, +) +from langchain.schema import BaseRetriever, Document + + +class EnsembleRetriever(BaseRetriever): + """ + This class ensemble the results of multiple retrievers by using rank fusion. + + Args: + retrievers: A list of retrievers to ensemble. + weights: A list of weights corresponding to the retrievers. Defaults to equal + weighting for all retrievers. + c: A constant added to the rank, controlling the balance between the importance + of high-ranked items and the consideration given to lower-ranked items. + Default is 60. + """ + + retrievers: List[BaseRetriever] + weights: List[float] + c: int = 60 + + @root_validator(pre=True) + def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if not values.get("weights"): + n_retrievers = len(values["retrievers"]) + values["weights"] = [1 / n_retrievers] * n_retrievers + return values + + def _get_relevant_documents( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + ) -> List[Document]: + """ + Get the relevant documents for a given query. + + Args: + query: The query to search for. + + Returns: + A list of reranked documents. + """ + + # Get fused result of the retrievers. + fused_documents = self.rank_fusion(query, run_manager) + + return fused_documents + + async def _aget_relevant_documents( + self, + query: str, + *, + run_manager: AsyncCallbackManagerForRetrieverRun, + ) -> List[Document]: + """ + Asynchronously get the relevant documents for a given query. + + Args: + query: The query to search for. + + Returns: + A list of reranked documents. + """ + + # Get fused result of the retrievers. + fused_documents = await self.arank_fusion(query, run_manager) + + return fused_documents + + def rank_fusion( + self, query: str, run_manager: CallbackManagerForRetrieverRun + ) -> List[Document]: + """ + Retrieve the results of the retrievers and use rank_fusion_func to get + the final result. + + Args: + query: The query to search for. + + Returns: + A list of reranked documents. + """ + + # Get the results of all retrievers. + retriever_docs = [ + retriever.get_relevant_documents( + query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}") + ) + for i, retriever in enumerate(self.retrievers) + ] + + # apply rank fusion + fused_documents = self.weighted_reciprocal_rank(retriever_docs) + + return fused_documents + + async def arank_fusion( + self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun + ) -> List[Document]: + """ + Asynchronously retrieve the results of the retrievers + and use rank_fusion_func to get the final result. + + Args: + query: The query to search for. + + Returns: + A list of reranked documents. + """ + + # Get the results of all retrievers. + retriever_docs = [ + await retriever.aget_relevant_documents( + query, callbacks=run_manager.get_child(tag=f"retriever_{i+1}") + ) + for i, retriever in enumerate(self.retrievers) + ] + + # apply rank fusion + fused_documents = self.weighted_reciprocal_rank(retriever_docs) + + return fused_documents + + def weighted_reciprocal_rank( + self, doc_lists: List[List[Document]] + ) -> List[Document]: + """ + Perform weighted Reciprocal Rank Fusion on multiple rank lists. + You can find more details about RRF here: + https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf + + Args: + doc_lists: A list of rank lists, where each rank list contains unique items. + + Returns: + list: The final aggregated list of items sorted by their weighted RRF + scores in descending order. + """ + if len(doc_lists) != len(self.weights): + raise ValueError( + "Number of rank lists must be equal to the number of weights." + ) + + # Create a union of all unique documents in the input doc_lists + all_documents = set() + for doc_list in doc_lists: + for doc in doc_list: + all_documents.add(doc.page_content) + + # Initialize the RRF score dictionary for each document + rrf_score_dic = {doc: 0.0 for doc in all_documents} + + # Calculate RRF scores for each document + for doc_list, weight in zip(doc_lists, self.weights): + for rank, doc in enumerate(doc_list, start=1): + rrf_score = weight * (1 / (rank + self.c)) + rrf_score_dic[doc.page_content] += rrf_score + + # Sort documents by their RRF scores in descending order + sorted_documents = sorted( + rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True + ) + + # Map the sorted page_content back to the original document objects + page_content_to_doc_map = { + doc.page_content: doc for doc_list in doc_lists for doc in doc_list + } + sorted_docs = [ + page_content_to_doc_map[page_content] for page_content in sorted_documents + ] + + return sorted_docs diff --git a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py new file mode 100644 index 0000000000..2488ff0643 --- /dev/null +++ b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py @@ -0,0 +1,42 @@ +import pytest + +from langchain.retrievers.bm25 import BM25Retriever +from langchain.retrievers.ensemble import EnsembleRetriever +from langchain.schema import Document + + +@pytest.mark.requires("rank_bm25") +def test_ensemble_retriever_get_relevant_docs() -> None: + doc_list = [ + "I like apples", + "I like oranges", + "Apples and oranges are fruits", + ] + + dummy_retriever = BM25Retriever.from_texts(doc_list) + dummy_retriever.k = 1 + + ensemble_retriever = EnsembleRetriever( + retrievers=[dummy_retriever, dummy_retriever] + ) + docs = ensemble_retriever.get_relevant_documents("I like apples") + assert len(docs) == 1 + + +@pytest.mark.requires("rank_bm25") +def test_weighted_reciprocal_rank() -> None: + doc1 = Document(page_content="1") + doc2 = Document(page_content="2") + + dummy_retriever = BM25Retriever.from_texts(["1", "2"]) + ensemble_retriever = EnsembleRetriever( + retrievers=[dummy_retriever, dummy_retriever], weights=[0.4, 0.5], c=0 + ) + result = ensemble_retriever.weighted_reciprocal_rank([[doc1, doc2], [doc2, doc1]]) + assert result[0].page_content == "2" + assert result[1].page_content == "1" + + ensemble_retriever.weights = [0.5, 0.4] + result = ensemble_retriever.weighted_reciprocal_rank([[doc1, doc2], [doc2, doc1]]) + assert result[0].page_content == "1" + assert result[1].page_content == "2"