@ -61,6 +61,8 @@ class MultiQueryRetriever(BaseRetriever):
llm_chain : LLMChain
verbose : bool = True
parser_key : str = " lines "
include_original : bool = False
""" Whether to include the original query in the list of generated queries. """
@classmethod
def from_llm (
@ -69,12 +71,15 @@ class MultiQueryRetriever(BaseRetriever):
llm : BaseLLM ,
prompt : PromptTemplate = DEFAULT_QUERY_PROMPT ,
parser_key : str = " lines " ,
include_original : bool = False ,
) - > " MultiQueryRetriever " :
""" Initialize from llm using default template.
Args :
retriever : retriever to query documents from
llm : llm for query generation using DEFAULT_QUERY_PROMPT
include_original : Whether to include the original query in the list of
generated queries .
Returns :
MultiQueryRetriever
@ -85,6 +90,7 @@ class MultiQueryRetriever(BaseRetriever):
retriever = retriever ,
llm_chain = llm_chain ,
parser_key = parser_key ,
include_original = include_original ,
)
async def _aget_relevant_documents (
@ -102,6 +108,8 @@ class MultiQueryRetriever(BaseRetriever):
Unique union of relevant documents from all generated queries
"""
queries = await self . agenerate_queries ( query , run_manager )
if self . include_original :
queries . append ( query )
documents = await self . aretrieve_documents ( queries , run_manager )
return self . unique_union ( documents )
@ -160,6 +168,8 @@ class MultiQueryRetriever(BaseRetriever):
Unique union of relevant documents from all generated queries
"""
queries = self . generate_queries ( query , run_manager )
if self . include_original :
queries . append ( query )
documents = self . retrieve_documents ( queries , run_manager )
return self . unique_union ( documents )