From 2b3b4e0f600285d139d4e10a498530810d15b288 Mon Sep 17 00:00:00 2001 From: Vijay Date: Sun, 18 Jun 2023 22:19:56 +0200 Subject: [PATCH] Add the ability to run the map_reduce chains process results step as async (#6181) This adds the ability to attach an AsyncCallbackManager (handler) for the reducer chain, which would be able to stream the tokens via the `async def on_llm_new_token` callback method Fixes issue [#5532](https://github.com/hwchase17/langchain/issues/5532) @hwchase17 @agola11 The following code snippet explains how this change would be used to enable `reduce_llm` with streaming support in a `map_reduce` chain I have tested this change and it works for the streaming use-case of reducer responses. I am happy to share more information if this solution makes sense. ``` AsyncHandler .......................... class StreamingLLMCallbackHandler(AsyncCallbackHandler): """Callback handler for streaming LLM responses.""" def __init__(self, websocket): self.websocket = websocket # This callback method is to be executed in async async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: resp = ChatResponse(sender="bot", message=token, type="stream") await self.websocket.send_json(resp.dict()) Chain .......... 
stream_handler = StreamingLLMCallbackHandler(websocket) stream_manager = AsyncCallbackManager([stream_handler]) streaming_llm = ChatOpenAI( streaming=True, callback_manager=stream_manager, verbose=False, temperature=0, ) main_llm = OpenAI( temperature=0, verbose=False, ) doc_chain = load_qa_chain( llm=main_llm, reduce_llm=streaming_llm, chain_type="map_reduce", callback_manager=manager ) qa_chain = ConversationalRetrievalChain( retriever=vectorstore.as_retriever(), combine_docs_chain=doc_chain, question_generator=question_generator, callback_manager=manager, ) # Here `acall` will trigger `acombine_docs` on `map_reduce` which should then call `_aprocess_result` which in turn will call `self.combine_document_chain.arun` hence async callback will be awaited result = await qa_chain.acall( {"question": question, "chat_history": chat_history} ) ``` --- .../chains/combine_documents/map_reduce.py | 36 +++++++++++++++++-- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index 06e87e03..84e49296 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -163,16 +163,18 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs], callbacks=callbacks, ) - return self._process_results(results, docs, callbacks=callbacks, **kwargs) + return await self._aprocess_results( + results, docs, callbacks=callbacks, **kwargs + ) - def _process_results( + def _process_results_common( self, results: List[Dict], docs: List[Document], token_max: int = 3000, callbacks: Callbacks = None, **kwargs: Any, - ) -> Tuple[str, dict]: + ) -> Tuple[List[Document], dict]: question_result_key = self.llm_chain.output_key result_docs = [ Document(page_content=r[question_result_key], metadata=docs[i].metadata) @@ -201,11 +203,39 @@ class 
MapReduceDocumentsChain(BaseCombineDocumentsChain): extra_return_dict = {"intermediate_steps": _results} else: extra_return_dict = {} + return result_docs, extra_return_dict + + def _process_results( + self, + results: List[Dict], + docs: List[Document], + token_max: int = 3000, + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Tuple[str, dict]: + result_docs, extra_return_dict = self._process_results_common( + results, docs, token_max, callbacks=callbacks, **kwargs + ) output = self.combine_document_chain.run( input_documents=result_docs, callbacks=callbacks, **kwargs ) return output, extra_return_dict + async def _aprocess_results( + self, + results: List[Dict], + docs: List[Document], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Tuple[str, dict]: + result_docs, extra_return_dict = self._process_results_common( + results, docs, callbacks=callbacks, **kwargs + ) + output = await self.combine_document_chain.arun( + input_documents=result_docs, callbacks=callbacks, **kwargs + ) + return output, extra_return_dict + @property def _chain_type(self) -> str: return "map_reduce_documents_chain"