You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/langchain/retrievers/merger_retriever.py

111 lines
2.9 KiB
Python

from typing import List
from langchain.schema import BaseRetriever, Document
class MergerRetriever(BaseRetriever):
    """Retriever that merges the results of several underlying retrievers.

    Every wrapped retriever is asked the same query and the resulting
    lists are interleaved round-robin: the first document of each
    retriever (in the order the retrievers were given), then each
    retriever's second document, and so on. A retriever that returned
    fewer documents simply drops out of the later rounds. Documents are
    not de-duplicated.

    Args:
        retrievers: A list of retrievers to merge.
    """

    def __init__(
        self,
        retrievers: List[BaseRetriever],
    ):
        """Initialize the MergerRetriever class.

        Args:
            retrievers: A list of retrievers to merge.
        """
        self.retrievers = retrievers

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get the relevant documents for a given query.

        Args:
            query: The query to search for.

        Returns:
            The documents of all retrievers, merged by ``merge_documents``.
        """
        return self.merge_documents(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.

        Returns:
            The documents of all retrievers, merged by ``amerge_documents``.
        """
        return await self.amerge_documents(query)

    @staticmethod
    def _interleave(doc_lists: List[List[Document]]) -> List[Document]:
        """Interleave several document lists round-robin.

        Args:
            doc_lists: One list of documents per retriever, in retriever
                order.

        Returns:
            A flat list holding element 0 of every input list, then
            element 1 of every input list, and so on, skipping lists that
            are already exhausted. Empty when ``doc_lists`` is empty.
        """
        # ``default=0`` keeps ``max`` from raising ValueError when there are
        # no retrievers at all; the result is then simply an empty list.
        longest = max((len(docs) for docs in doc_lists), default=0)
        merged: List[Document] = []
        for position in range(longest):
            for docs in doc_lists:
                # Shorter lists have nothing left to contribute this round.
                if position < len(docs):
                    merged.append(docs[position])
        return merged

    def merge_documents(self, query: str) -> List[Document]:
        """Merge the results of the retrievers.

        Args:
            query: The query to search for.

        Returns:
            A list of merged documents, interleaved round-robin across the
            retrievers.
        """
        # Get the results of all retrievers.
        retriever_docs = [
            retriever.get_relevant_documents(query) for retriever in self.retrievers
        ]
        return self._interleave(retriever_docs)

    async def amerge_documents(self, query: str) -> List[Document]:
        """Asynchronously merge the results of the retrievers.

        The retrievers are awaited one after another, in the order they
        were given.

        Args:
            query: The query to search for.

        Returns:
            A list of merged documents, interleaved round-robin across the
            retrievers.
        """
        # Get the results of all retrievers.
        retriever_docs = [
            await retriever.aget_relevant_documents(query)
            for retriever in self.retrievers
        ]
        return self._interleave(retriever_docs)