From 66d5a7e7cfebfcf8468cef85e57a80a3ea1af989 Mon Sep 17 00:00:00 2001
From: German Martin
Date: Fri, 22 Sep 2023 12:33:20 -0300
Subject: [PATCH] Add async support to multi-query retriever. (#10873)

Added async support to the MultiQueryRetriever class.

---------

Co-authored-by: Nuno Campos
---
Note: relative to the original commit, the `Args:` entry in the
`_aget_relevant_documents` docstring is corrected from `question` to
`query` so that it matches the parameter name. No code lines differ.

 .../langchain/retrievers/multi_query.py | 64 ++++++++++++++++++-
 1 file changed, 63 insertions(+), 1 deletion(-)

diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py
index b99bb84a7e..5abdddc2e2 100644
--- a/libs/langchain/langchain/retrievers/multi_query.py
+++ b/libs/langchain/langchain/retrievers/multi_query.py
@@ -1,7 +1,11 @@
+import asyncio
 import logging
 from typing import List, Sequence
 
-from langchain.callbacks.manager import CallbackManagerForRetrieverRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForRetrieverRun,
+    CallbackManagerForRetrieverRun,
+)
 from langchain.chains.llm import LLMChain
 from langchain.llms.base import BaseLLM
 from langchain.output_parsers.pydantic import PydanticOutputParser
@@ -83,6 +87,64 @@ class MultiQueryRetriever(BaseRetriever):
             parser_key=parser_key,
         )
 
+    async def _aget_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: AsyncCallbackManagerForRetrieverRun,
+    ) -> List[Document]:
+        """Get relevant documents given a user query.
+
+        Args:
+            query: user query
+
+        Returns:
+            Unique union of relevant documents from all generated queries
+        """
+        queries = await self.agenerate_queries(query, run_manager)
+        documents = await self.aretrieve_documents(queries, run_manager)
+        return self.unique_union(documents)
+
+    async def agenerate_queries(
+        self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
+    ) -> List[str]:
+        """Generate queries based upon user input.
+
+        Args:
+            question: user query
+
+        Returns:
+            List of LLM generated queries that are similar to the user input
+        """
+        response = await self.llm_chain.acall(
+            inputs={"question": question}, callbacks=run_manager.get_child()
+        )
+        lines = getattr(response["text"], self.parser_key, [])
+        if self.verbose:
+            logger.info(f"Generated queries: {lines}")
+        return lines
+
+    async def aretrieve_documents(
+        self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
+    ) -> List[Document]:
+        """Run all LLM generated queries.
+
+        Args:
+            queries: query list
+
+        Returns:
+            List of retrieved Documents
+        """
+        document_lists = await asyncio.gather(
+            *(
+                self.retriever.aget_relevant_documents(
+                    query, callbacks=run_manager.get_child()
+                )
+                for query in queries
+            )
+        )
+        return [doc for docs in document_lists for doc in docs]
+
     def _get_relevant_documents(
         self,
         query: str,