Add the ability to run the map_reduce chains process results step as async (#6181)

This will add the ability to add an AsyncCallbackManager (handler) for
the reducer chain, which would be able to stream the tokens via the
`async def on_llm_new_token` callback method



Fixes # (issue)
[5532](https://github.com/hwchase17/langchain/issues/5532)


 @hwchase17  @agola11 
The following code snippet explains how this change would be used to
enable `reduce_llm` with streaming support in a `map_reduce` chain

I have tested this change and it works for the streaming use-case of
reducer responses. I am happy to share more information if this makes
solution sense.

```

AsyncHandler
..........................
class StreamingLLMCallbackHandler(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses."""

    def __init__(self, websocket):
        self.websocket = websocket
    
    # This callback method is to be executed in async
    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        resp = ChatResponse(sender="bot", message=token, type="stream")
        await self.websocket.send_json(resp.dict())


Chain
..........
stream_handler = StreamingLLMCallbackHandler(websocket)
stream_manager = AsyncCallbackManager([stream_handler])

streaming_llm = ChatOpenAI(
        streaming=True,
        callback_manager=stream_manager,
        verbose=False,
        temperature=0,
    )
    main_llm = OpenAI(
        temperature=0,
        verbose=False,
    )

    doc_chain = load_qa_chain(
        llm=main_llm,
        reduce_llm=streaming_llm,
        chain_type="map_reduce", 
        callback_manager=manager
    )
    qa_chain = ConversationalRetrievalChain(
        retriever=vectorstore.as_retriever(),
        combine_docs_chain=doc_chain,
        question_generator=question_generator,
        callback_manager=manager,
    )
    
    # Here `acall` will trigger `acombine_docs` on `map_reduce` which should then call `_aprocess_result` which in turn will call `self.combine_document_chain.arun` hence async callback will be awaited
    result = await qa_chain.acall(
         {"question": question, "chat_history": chat_history}
      )
```
master
Vijay 11 months ago committed by GitHub
parent e0dea577ee
commit 2b3b4e0f60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -163,16 +163,18 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
callbacks=callbacks,
)
return self._process_results(results, docs, callbacks=callbacks, **kwargs)
return await self._aprocess_results(
results, docs, callbacks=callbacks, **kwargs
)
def _process_results(
def _process_results_common(
self,
results: List[Dict],
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
) -> Tuple[List[Document], dict]:
question_result_key = self.llm_chain.output_key
result_docs = [
Document(page_content=r[question_result_key], metadata=docs[i].metadata)
@ -201,11 +203,39 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
extra_return_dict = {"intermediate_steps": _results}
else:
extra_return_dict = {}
return result_docs, extra_return_dict
def _process_results(
self,
results: List[Dict],
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
result_docs, extra_return_dict = self._process_results_common(
results, docs, token_max, callbacks=callbacks, **kwargs
)
output = self.combine_document_chain.run(
input_documents=result_docs, callbacks=callbacks, **kwargs
)
return output, extra_return_dict
async def _aprocess_results(
self,
results: List[Dict],
docs: List[Document],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
result_docs, extra_return_dict = self._process_results_common(
results, docs, callbacks=callbacks, **kwargs
)
output = await self.combine_document_chain.arun(
input_documents=result_docs, callbacks=callbacks, **kwargs
)
return output, extra_return_dict
@property
def _chain_type(self) -> str:
return "map_reduce_documents_chain"

Loading…
Cancel
Save