diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index 6fc30415eb..e4dbec3e85 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -219,7 +219,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): for i, r in enumerate(map_results) ] result, extra_return_dict = self.reduce_documents_chain.combine_docs( - result_docs, callbacks=callbacks, **kwargs + result_docs, token_max=token_max, callbacks=callbacks, **kwargs ) if self.return_intermediate_steps: intermediate_steps = [r[question_result_key] for r in map_results] @@ -227,7 +227,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): return result, extra_return_dict async def acombine_docs( - self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + self, + docs: List[Document], + token_max: int = 3000, + callbacks: Callbacks = None, + **kwargs: Any, ) -> Tuple[str, dict]: """Combine documents in a map reduce manner. @@ -246,7 +250,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): for i, r in enumerate(map_results) ] result, extra_return_dict = await self.reduce_documents_chain.acombine_docs( - result_docs, callbacks=callbacks, **kwargs + result_docs, token_max=token_max, callbacks=callbacks, **kwargs ) if self.return_intermediate_steps: intermediate_steps = [r[question_result_key] for r in map_results] diff --git a/langchain/chains/combine_documents/reduce.py b/langchain/chains/combine_documents/reduce.py index 1b7e632dde..9458c39491 100644 --- a/langchain/chains/combine_documents/reduce.py +++ b/langchain/chains/combine_documents/reduce.py @@ -196,7 +196,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain): ) async def acombine_docs( - self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any + self, + docs: List[Document], + token_max: int = 3000, + callbacks: Callbacks = None, + **kwargs: Any, ) -> Tuple[str, dict]: """Combine multiple documents recursively.