forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
2.9 KiB
Python
111 lines
2.9 KiB
Python
import asyncio
from typing import List

from langchain.schema import BaseRetriever, Document
|
|
|
|
|
|
class MergerRetriever(BaseRetriever):
    """
    This class merges the results of multiple retrievers.

    Results are combined round-robin: the first document from every
    retriever (in the order the retrievers were given), then the second
    document from every retriever, and so on. A retriever that runs out of
    documents is simply skipped in later rounds. Documents are not
    de-duplicated, so the same document returned by two retrievers appears
    twice in the merged result.

    Args:
        retrievers: A list of retrievers to merge.
    """

    def __init__(
        self,
        retrievers: List[BaseRetriever],
    ):
        """
        Initialize the MergerRetriever class.

        Args:
            retrievers: A list of retrievers to merge. An empty list is
                accepted and makes every query return no documents.
        """
        self.retrievers = retrievers

    def get_relevant_documents(self, query: str) -> List[Document]:
        """
        Get the relevant documents for a given query.

        Args:
            query: The query to search for.

        Returns:
            A list of relevant documents, interleaved across retrievers.
        """
        return self.merge_documents(query)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """
        Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.

        Returns:
            A list of relevant documents, interleaved across retrievers.
        """
        return await self.amerge_documents(query)

    def merge_documents(self, query: str) -> List[Document]:
        """
        Merge the results of the retrievers.

        Each retriever is queried in turn with the same query and the
        per-retriever result lists are interleaved round-robin.

        Args:
            query: The query to search for.

        Returns:
            A list of merged documents (empty if there are no retrievers).
        """
        retriever_docs = [
            retriever.get_relevant_documents(query) for retriever in self.retrievers
        ]
        return self._interleave(retriever_docs)

    async def amerge_documents(self, query: str) -> List[Document]:
        """
        Asynchronously merge the results of the retrievers.

        All retrievers are queried concurrently; the per-retriever result
        lists are then interleaved round-robin in the same order as the
        synchronous path.

        Args:
            query: The query to search for.

        Returns:
            A list of merged documents (empty if there are no retrievers).
        """
        # gather() returns results in the order the awaitables were passed,
        # so the merge order matches self.retrievers regardless of which
        # retriever finishes first. With no retrievers it returns [].
        # NOTE(review): assumes the retrievers' async methods are safe to run
        # concurrently — confirm for retrievers that share state.
        retriever_docs = await asyncio.gather(
            *(retriever.aget_relevant_documents(query) for retriever in self.retrievers)
        )
        return self._interleave(retriever_docs)

    @staticmethod
    def _interleave(retriever_docs: List[List[Document]]) -> List[Document]:
        """
        Round-robin merge of per-retriever document lists.

        Args:
            retriever_docs: One list of documents per retriever.

        Returns:
            docs[0] of each list, then docs[1] of each list, etc.; lists that
            are shorter than the longest one are skipped once exhausted.
        """
        merged_documents: List[Document] = []
        # default=0 keeps an empty retriever list from raising ValueError.
        max_docs = max((len(docs) for docs in retriever_docs), default=0)
        for i in range(max_docs):
            for docs in retriever_docs:
                if i < len(docs):
                    merged_documents.append(docs[i])
        return merged_documents