diff --git a/libs/langchain/langchain/chains/summarize/__init__.py b/libs/langchain/langchain/chains/summarize/__init__.py index 96d6279302..681019b107 100644 --- a/libs/langchain/langchain/chains/summarize/__init__.py +++ b/libs/langchain/langchain/chains/summarize/__init__.py @@ -1,6 +1,7 @@ """Load summarizing chains.""" from typing import Any, Mapping, Optional, Protocol +from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.reduce import ReduceDocumentsChain @@ -49,16 +50,22 @@ def _load_map_reduce_chain( collapse_llm: Optional[BaseLanguageModel] = None, verbose: Optional[bool] = None, token_max: int = 3000, + callbacks: Callbacks = None, **kwargs: Any, ) -> MapReduceDocumentsChain: - map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose) + map_chain = LLMChain( + llm=llm, prompt=map_prompt, verbose=verbose, callbacks=callbacks + ) _reduce_llm = reduce_llm or llm - reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose) + reduce_chain = LLMChain( + llm=_reduce_llm, prompt=combine_prompt, verbose=verbose, callbacks=callbacks + ) # TODO: document prompt combine_documents_chain = StuffDocumentsChain( llm_chain=reduce_chain, document_variable_name=combine_document_variable_name, verbose=verbose, + callbacks=callbacks, ) if collapse_prompt is None: collapse_chain = None @@ -74,6 +81,7 @@ def _load_map_reduce_chain( llm=_collapse_llm, prompt=collapse_prompt, verbose=verbose, + callbacks=callbacks, ), document_variable_name=combine_document_variable_name, ) @@ -82,12 +90,14 @@ def _load_map_reduce_chain( collapse_documents_chain=collapse_chain, token_max=token_max, verbose=verbose, + callbacks=callbacks, ) return MapReduceDocumentsChain( llm_chain=map_chain, reduce_documents_chain=reduce_documents_chain, document_variable_name=map_reduce_document_variable_name, verbose=verbose, + callbacks=callbacks, **kwargs, )