mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add the ability to run the map_reduce chain's process-results step as async (#6181)
This will add the ability to add an AsyncCallbackManager (handler) for the reducer chain, which would be able to stream the tokens via the `async def on_llm_new_token` callback method. Fixes issue [5532](https://github.com/hwchase17/langchain/issues/5532) @hwchase17 @agola11 The following code snippet explains how this change would be used to enable `reduce_llm` with streaming support in a `map_reduce` chain. I have tested this change and it works for the streaming use-case of reducer responses. I am happy to share more information if this solution makes sense. ``` AsyncHandler .......................... class StreamingLLMCallbackHandler(AsyncCallbackHandler): """Callback handler for streaming LLM responses.""" def __init__(self, websocket): self.websocket = websocket # This callback method is to be executed in async async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: resp = ChatResponse(sender="bot", message=token, type="stream") await self.websocket.send_json(resp.dict()) Chain .......... stream_handler = StreamingLLMCallbackHandler(websocket) stream_manager = AsyncCallbackManager([stream_handler]) streaming_llm = ChatOpenAI( streaming=True, callback_manager=stream_manager, verbose=False, temperature=0, ) main_llm = OpenAI( temperature=0, verbose=False, ) doc_chain = load_qa_chain( llm=main_llm, reduce_llm=streaming_llm, chain_type="map_reduce", callback_manager=manager ) qa_chain = ConversationalRetrievalChain( retriever=vectorstore.as_retriever(), combine_docs_chain=doc_chain, question_generator=question_generator, callback_manager=manager, ) # Here `acall` will trigger `acombine_docs` on `map_reduce` which should then call `_aprocess_result` which in turn will call `self.combine_document_chain.arun` hence async callback will be awaited result = await qa_chain.acall( {"question": question, "chat_history": chat_history} ) ```
This commit is contained in:
parent
e0dea577ee
commit
2b3b4e0f60
@ -163,16 +163,18 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
|
[{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
return self._process_results(results, docs, callbacks=callbacks, **kwargs)
|
return await self._aprocess_results(
|
||||||
|
results, docs, callbacks=callbacks, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def _process_results(
|
def _process_results_common(
|
||||||
self,
|
self,
|
||||||
results: List[Dict],
|
results: List[Dict],
|
||||||
docs: List[Document],
|
docs: List[Document],
|
||||||
token_max: int = 3000,
|
token_max: int = 3000,
|
||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Tuple[str, dict]:
|
) -> Tuple[List[Document], dict]:
|
||||||
question_result_key = self.llm_chain.output_key
|
question_result_key = self.llm_chain.output_key
|
||||||
result_docs = [
|
result_docs = [
|
||||||
Document(page_content=r[question_result_key], metadata=docs[i].metadata)
|
Document(page_content=r[question_result_key], metadata=docs[i].metadata)
|
||||||
@ -201,11 +203,39 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
extra_return_dict = {"intermediate_steps": _results}
|
extra_return_dict = {"intermediate_steps": _results}
|
||||||
else:
|
else:
|
||||||
extra_return_dict = {}
|
extra_return_dict = {}
|
||||||
|
return result_docs, extra_return_dict
|
||||||
|
|
||||||
|
def _process_results(
|
||||||
|
self,
|
||||||
|
results: List[Dict],
|
||||||
|
docs: List[Document],
|
||||||
|
token_max: int = 3000,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Tuple[str, dict]:
|
||||||
|
result_docs, extra_return_dict = self._process_results_common(
|
||||||
|
results, docs, token_max, callbacks=callbacks, **kwargs
|
||||||
|
)
|
||||||
output = self.combine_document_chain.run(
|
output = self.combine_document_chain.run(
|
||||||
input_documents=result_docs, callbacks=callbacks, **kwargs
|
input_documents=result_docs, callbacks=callbacks, **kwargs
|
||||||
)
|
)
|
||||||
return output, extra_return_dict
|
return output, extra_return_dict
|
||||||
|
|
||||||
|
async def _aprocess_results(
|
||||||
|
self,
|
||||||
|
results: List[Dict],
|
||||||
|
docs: List[Document],
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Tuple[str, dict]:
|
||||||
|
result_docs, extra_return_dict = self._process_results_common(
|
||||||
|
results, docs, callbacks=callbacks, **kwargs
|
||||||
|
)
|
||||||
|
output = await self.combine_document_chain.arun(
|
||||||
|
input_documents=result_docs, callbacks=callbacks, **kwargs
|
||||||
|
)
|
||||||
|
return output, extra_return_dict
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _chain_type(self) -> str:
|
def _chain_type(self) -> str:
|
||||||
return "map_reduce_documents_chain"
|
return "map_reduce_documents_chain"
|
||||||
|
Loading…
Reference in New Issue
Block a user