diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index 06e87e03..84e49296 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -163,16 +163,18 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs], callbacks=callbacks, ) - return self._process_results(results, docs, callbacks=callbacks, **kwargs) + return await self._aprocess_results( + results, docs, callbacks=callbacks, **kwargs + ) - def _process_results( + def _process_results_common( self, results: List[Dict], docs: List[Document], token_max: int = 3000, callbacks: Callbacks = None, **kwargs: Any, - ) -> Tuple[str, dict]: + ) -> Tuple[List[Document], dict]: question_result_key = self.llm_chain.output_key result_docs = [ Document(page_content=r[question_result_key], metadata=docs[i].metadata) @@ -201,11 +203,39 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): extra_return_dict = {"intermediate_steps": _results} else: extra_return_dict = {} + return result_docs, extra_return_dict + + def _process_results( + self, + results: List[Dict], + docs: List[Document], + token_max: int = 3000, + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Tuple[str, dict]: + result_docs, extra_return_dict = self._process_results_common( + results, docs, token_max, callbacks=callbacks, **kwargs + ) output = self.combine_document_chain.run( input_documents=result_docs, callbacks=callbacks, **kwargs ) return output, extra_return_dict + async def _aprocess_results( + self, + results: List[Dict], + docs: List[Document], + callbacks: Callbacks = None, + **kwargs: Any, + ) -> Tuple[str, dict]: + result_docs, extra_return_dict = self._process_results_common( + results, docs, callbacks=callbacks, **kwargs + ) + output = await self.combine_document_chain.arun( + input_documents=result_docs, callbacks=callbacks, **kwargs + ) + return output, extra_return_dict + @property def _chain_type(self) -> str: return "map_reduce_documents_chain"