diff --git a/docs/modules/indexes/retrievers/examples/merger_retriever.ipynb b/docs/modules/indexes/retrievers/examples/merger_retriever.ipynb new file mode 100644 index 0000000000..0919dceec0 --- /dev/null +++ b/docs/modules/indexes/retrievers/examples/merger_retriever.ipynb @@ -0,0 +1,121 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "fc0db1bc", + "metadata": {}, + "source": [ + "# LOTR (Merger Retriever)\n", + "\n", + "Lord of the Retrievers, also known as MergerRetriever, takes a list of retrievers as input and merges the results of their get_relevant_documents() methods into a single list. The merged results will be a list of documents that are relevant to the query and that have been ranked by the different retrievers.\n", + "\n", + "The MergerRetriever class can be used to improve the accuracy of document retrieval in a number of ways. First, it can combine the results of multiple retrievers, which can help to reduce the risk of bias in the results. Second, it can rank the results of the different retrievers, which can help to ensure that the most relevant documents are returned first." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fbcc58f", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import chromadb\n", + "from langchain.retrievers.merger_retriever import MergerRetriever\n", + "from langchain.vectorstores import Chroma\n", + "from langchain.embeddings import HuggingFaceEmbeddings\n", + "from langchain.embeddings import OpenAIEmbeddings\n", + "from langchain.document_transformers import EmbeddingsRedundantFilter\n", + "from langchain.retrievers.document_compressors import DocumentCompressorPipeline\n", + "from langchain.retrievers import ContextualCompressionRetriever\n", + "\n", + "# Get 3 diff embeddings.\n", + "all_mini = HuggingFaceEmbeddings(model_name=\"all-MiniLM-L6-v2\")\n", + "multi_qa_mini = HuggingFaceEmbeddings(model_name=\"multi-qa-MiniLM-L6-dot-v1\")\n", + "filter_embeddings = OpenAIEmbeddings()\n", + "\n", + "ABS_PATH = os.path.dirname(os.path.abspath(__file__))\n", + "DB_DIR = os.path.join(ABS_PATH, \"db\")\n", + "\n", + "# Instantiate 2 diff cromadb indexs, each one with a diff embedding.\n", + "client_settings = chromadb.config.Settings(\n", + " chroma_db_impl=\"duckdb+parquet\",\n", + " persist_directory=DB_DIR,\n", + " anonymized_telemetry=False,\n", + ")\n", + "db_all = Chroma(\n", + " collection_name=\"project_store_all\",\n", + " persist_directory=DB_DIR,\n", + " client_settings=client_settings,\n", + " embedding_function=all_mini,\n", + ")\n", + "db_multi_qa = Chroma(\n", + " collection_name=\"project_store_multi\",\n", + " persist_directory=DB_DIR,\n", + " client_settings=client_settings,\n", + " embedding_function=multi_qa_mini,\n", + ")\n", + "\n", + "# Define 2 diff retrievers with 2 diff embeddings and diff search type.\n", + "retriever_all = db_all.as_retriever(\n", + " search_type=\"similarity\", search_kwargs={\"k\": 5, \"include_metadata\": True}\n", + ")\n", + "retriever_multi_qa = db_multi_qa.as_retriever(\n", + " search_type=\"mmr\", search_kwargs={\"k\": 5, \"include_metadata\": True}\n", + ")\n", + "\n", + "# The Lord of the Retrievers will hold the ouput of boths retrievers and can be used as any other \n", + "# retriever on different types of chains.\n", + "lotr = MergerRetriever(retrievers=[retriever_all, retriever_multi_qa])\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c152339d", + "metadata": {}, + "source": [ + "## Remove redundant results from the merged retrievers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "039faea6", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# We can remove redundant results from both retrievers using yet another embedding. \n", + "# Using multiples embeddings in diff steps could help reduce biases.\n", + "filter = EmbeddingsRedundantFilter(embeddings=filter_embeddings)\n", + "pipeline = DocumentCompressorPipeline(transformers=[filter])\n", + "compression_retriever = ContextualCompressionRetriever(\n", + " base_compressor=pipeline, base_retriever=lotr\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index bb3d2eac46..19b67f9245 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -6,6 +6,7 @@ from langchain.retrievers.contextual_compression import ContextualCompressionRet from langchain.retrievers.databerry import DataberryRetriever from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever from langchain.retrievers.knn import KNNRetriever +from langchain.retrievers.merger_retriever import MergerRetriever from langchain.retrievers.metal import MetalRetriever from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever from langchain.retrievers.pupmed import PubMedRetriever @@ -31,6 +32,7 @@ __all__ = [ "DataberryRetriever", "ElasticSearchBM25Retriever", "KNNRetriever", + "MergerRetriever", "MetalRetriever", "PineconeHybridSearchRetriever", "RemoteLangChainRetriever", diff --git a/langchain/retrievers/merger_retriever.py b/langchain/retrievers/merger_retriever.py new file mode 100644 index 0000000000..a9dccdc4ba --- /dev/null +++ b/langchain/retrievers/merger_retriever.py @@ -0,0 +1,110 @@ +from typing import List + +from langchain.schema import BaseRetriever, Document + + +class MergerRetriever(BaseRetriever): + """ + This class merges the results of multiple retrievers. + + Args: + retrievers: A list of retrievers to merge. + """ + + def __init__( + self, + retrievers: List[BaseRetriever], + ): + """ + Initialize the MergerRetriever class. + + Args: + retrievers: A list of retrievers to merge. + """ + + self.retrievers = retrievers + + def get_relevant_documents(self, query: str) -> List[Document]: + """ + Get the relevant documents for a given query. + + Args: + query: The query to search for. + + Returns: + A list of relevant documents. + """ + + # Merge the results of the retrievers. + merged_documents = self.merge_documents(query) + + return merged_documents + + async def aget_relevant_documents(self, query: str) -> List[Document]: + """ + Asynchronously get the relevant documents for a given query. + + Args: + query: The query to search for. + + Returns: + A list of relevant documents. + """ + + # Merge the results of the retrievers. + merged_documents = await self.amerge_documents(query) + + return merged_documents + + def merge_documents(self, query: str) -> List[Document]: + """ + Merge the results of the retrievers. + + Args: + query: The query to search for. + + Returns: + A list of merged documents. + """ + + # Get the results of all retrievers. + retriever_docs = [ + retriever.get_relevant_documents(query) for retriever in self.retrievers + ] + + # Merge the results of the retrievers. + merged_documents = [] + max_docs = max(len(docs) for docs in retriever_docs) + for i in range(max_docs): + for retriever, doc in zip(self.retrievers, retriever_docs): + if i < len(doc): + merged_documents.append(doc[i]) + + return merged_documents + + async def amerge_documents(self, query: str) -> List[Document]: + """ + Asynchronously merge the results of the retrievers. + + Args: + query: The query to search for. + + Returns: + A list of merged documents. + """ + + # Get the results of all retrievers. + retriever_docs = [ + await retriever.aget_relevant_documents(query) + for retriever in self.retrievers + ] + + # Merge the results of the retrievers. + merged_documents = [] + max_docs = max(len(docs) for docs in retriever_docs) + for i in range(max_docs): + for retriever, doc in zip(self.retrievers, retriever_docs): + if i < len(doc): + merged_documents.append(doc[i]) + + return merged_documents diff --git a/tests/integration_tests/retrievers/test_merger_retriever.py b/tests/integration_tests/retrievers/test_merger_retriever.py new file mode 100644 index 0000000000..f42f664478 --- /dev/null +++ b/tests/integration_tests/retrievers/test_merger_retriever.py @@ -0,0 +1,32 @@ +from langchain.embeddings import OpenAIEmbeddings +from langchain.retrievers.merger_retriever import MergerRetriever +from langchain.vectorstores import Chroma + + +def test_merger_retriever_get_relevant_docs() -> None: + """Test get_relevant_docs.""" + texts_group_a = [ + "This is a document about the Boston Celtics", + "Fly me to the moon is one of my favourite songs." + "I simply love going to the movies", + ] + texts_group_b = [ + "This is a document about the Poenix Suns", + "The Boston Celtics won the game by 20 points", + "Real stupidity beats artificial intelligence every time. TP", + ] + embeddings = OpenAIEmbeddings() + retriever_a = Chroma.from_texts(texts_group_a, embedding=embeddings).as_retriever( + search_kwargs={"k": 1} + ) + retriever_b = Chroma.from_texts(texts_group_b, embedding=embeddings).as_retriever( + search_kwargs={"k": 1} + ) + + # The Lord of the Retrievers. + lotr = MergerRetriever([retriever_a, retriever_b]) + + actual = lotr.get_relevant_documents("Tell me about the Celtics") + assert len(actual) == 2 + assert texts_group_a[0] in [d.page_content for d in actual] + assert texts_group_b[1] in [d.page_content for d in actual]