propagate RetrievalQA chain callbacks through its own LLMChain and StuffDocumentsChain (#7853)

This is another case, similar to #5572 and #7565 where the callbacks are
getting dropped during construction of the chains.

tagging @hwchase17 and @agola11 for callbacks propagation

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
This commit is contained in:
Alec Flett 2023-08-03 20:11:58 -07:00 committed by GitHub
parent 47eea32f6a
commit 404d103c41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -11,6 +11,7 @@ from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -65,11 +66,12 @@ class BaseRetrievalQA(Chain):
cls,
llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseRetrievalQA:
"""Initialize from LLM."""
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(llm=llm, prompt=_prompt)
llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks)
document_prompt = PromptTemplate(
input_variables=["page_content"], template="Context:\n{page_content}"
)
@ -77,9 +79,14 @@ class BaseRetrievalQA(Chain):
llm_chain=llm_chain,
document_variable_name="context",
document_prompt=document_prompt,
callbacks=callbacks,
)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
return cls(
combine_documents_chain=combine_documents_chain,
callbacks=callbacks,
**kwargs,
)
@classmethod
def from_chain_type(