From 404d103c41e6ce208a64cf4c985a592fc40465a8 Mon Sep 17 00:00:00 2001
From: Alec Flett
Date: Thu, 3 Aug 2023 20:11:58 -0700
Subject: [PATCH] propagate RetrievalQA chain callbacks through its own
 LLMChain and StuffDocumentsChain (#7853)

This is another case, similar to #5572 and #7565, where callbacks were
getting dropped during construction of the chains.

Tagging @hwchase17 and @agola11 for callback propagation.
---
 libs/langchain/langchain/chains/retrieval_qa/base.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py
index 3d9ef22ce8..e8385d6fe7 100644
--- a/libs/langchain/langchain/chains/retrieval_qa/base.py
+++ b/libs/langchain/langchain/chains/retrieval_qa/base.py
@@ -11,6 +11,7 @@ from pydantic import Extra, Field, root_validator
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
+    Callbacks,
 )
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -65,11 +66,12 @@ class BaseRetrievalQA(Chain):
         cls,
         llm: BaseLanguageModel,
         prompt: Optional[PromptTemplate] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> BaseRetrievalQA:
         """Initialize from LLM."""
         _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
-        llm_chain = LLMChain(llm=llm, prompt=_prompt)
+        llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks)
         document_prompt = PromptTemplate(
             input_variables=["page_content"], template="Context:\n{page_content}"
         )
@@ -77,9 +79,14 @@ class BaseRetrievalQA(Chain):
             llm_chain=llm_chain,
             document_variable_name="context",
             document_prompt=document_prompt,
+            callbacks=callbacks,
         )
-        return cls(combine_documents_chain=combine_documents_chain, **kwargs)
+        return cls(
+            combine_documents_chain=combine_documents_chain,
+            callbacks=callbacks,
+            **kwargs,
+        )

     @classmethod
     def from_chain_type(
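
For context (not part of the patch): a minimal sketch of what this change enables, assuming the langchain 0.0.x API from around the time of this PR. The FAISS index, `OpenAIEmbeddings`, and `OpenAI` LLM below are illustrative stand-ins (they require `faiss` and an OpenAI key); any retriever and LLM would do.

```python
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores import FAISS

# Build a tiny in-memory index; the text is a placeholder for illustration.
vectorstore = FAISS.from_texts(
    ["Callbacks are passed through to the inner chains as of #7853."],
    embedding=OpenAIEmbeddings(),
)

handler = StdOutCallbackHandler()
qa = RetrievalQA.from_llm(
    llm=OpenAI(),
    retriever=vectorstore.as_retriever(),
    callbacks=[handler],  # before this patch, only the outer chain saw these
)

# With the patch applied, `handler` also fires for the inner LLMChain and
# StuffDocumentsChain runs, not just the top-level RetrievalQA chain.
qa.run("Where do the callbacks go?")
```

Previously, `from_llm` constructed the inner `LLMChain` and `StuffDocumentsChain` without passing `callbacks`, so handlers given at construction time only observed the outer `RetrievalQA` run.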