propogate token max (#7201)

This commit is contained in:
Harrison Chase 2023-07-05 10:25:48 -04:00 committed by GitHub
parent a94c4cca68
commit 1415966d64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 4 deletions

View File

@ -219,7 +219,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
for i, r in enumerate(map_results) for i, r in enumerate(map_results)
] ]
result, extra_return_dict = self.reduce_documents_chain.combine_docs( result, extra_return_dict = self.reduce_documents_chain.combine_docs(
result_docs, callbacks=callbacks, **kwargs result_docs, token_max=token_max, callbacks=callbacks, **kwargs
) )
if self.return_intermediate_steps: if self.return_intermediate_steps:
intermediate_steps = [r[question_result_key] for r in map_results] intermediate_steps = [r[question_result_key] for r in map_results]
@ -227,7 +227,11 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
return result, extra_return_dict return result, extra_return_dict
async def acombine_docs( async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]: ) -> Tuple[str, dict]:
"""Combine documents in a map reduce manner. """Combine documents in a map reduce manner.
@ -246,7 +250,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
for i, r in enumerate(map_results) for i, r in enumerate(map_results)
] ]
result, extra_return_dict = await self.reduce_documents_chain.acombine_docs( result, extra_return_dict = await self.reduce_documents_chain.acombine_docs(
result_docs, callbacks=callbacks, **kwargs result_docs, token_max=token_max, callbacks=callbacks, **kwargs
) )
if self.return_intermediate_steps: if self.return_intermediate_steps:
intermediate_steps = [r[question_result_key] for r in map_results] intermediate_steps = [r[question_result_key] for r in map_results]

View File

@ -196,7 +196,11 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
) )
async def acombine_docs( async def acombine_docs(
self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any self,
docs: List[Document],
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]: ) -> Tuple[str, dict]:
"""Combine multiple documents recursively. """Combine multiple documents recursively.