propagate callbacks through load_summarize_chain (#7565)

This lets you pass callbacks when you create the summarize chain:

```
summarize = load_summarize_chain(llm, chain_type="map_reduce", callbacks=[my_callbacks])
summary = summarize(documents)
```
See #5572 for a similar surgical fix.

tagging @hwchase17 for callbacks work

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
pull/8738/head
Alec Flett 1 year ago committed by GitHub
parent 404d103c41
commit 5d765408ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,7 @@
"""Load summarizing chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
@ -49,16 +50,22 @@ def _load_map_reduce_chain(
collapse_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
token_max: int = 3000,
callbacks: Callbacks = None,
**kwargs: Any,
) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
map_chain = LLMChain(
llm=llm, prompt=map_prompt, verbose=verbose, callbacks=callbacks
)
_reduce_llm = reduce_llm or llm
reduce_chain = LLMChain(llm=_reduce_llm, prompt=combine_prompt, verbose=verbose)
reduce_chain = LLMChain(
llm=_reduce_llm, prompt=combine_prompt, verbose=verbose, callbacks=callbacks
)
# TODO: document prompt
combine_documents_chain = StuffDocumentsChain(
llm_chain=reduce_chain,
document_variable_name=combine_document_variable_name,
verbose=verbose,
callbacks=callbacks,
)
if collapse_prompt is None:
collapse_chain = None
@ -74,6 +81,7 @@ def _load_map_reduce_chain(
llm=_collapse_llm,
prompt=collapse_prompt,
verbose=verbose,
callbacks=callbacks,
),
document_variable_name=combine_document_variable_name,
)
@ -82,12 +90,14 @@ def _load_map_reduce_chain(
collapse_documents_chain=collapse_chain,
token_max=token_max,
verbose=verbose,
callbacks=callbacks,
)
return MapReduceDocumentsChain(
llm_chain=map_chain,
reduce_documents_chain=reduce_documents_chain,
document_variable_name=map_reduce_document_variable_name,
verbose=verbose,
callbacks=callbacks,
**kwargs,
)

Loading…
Cancel
Save