diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py index 3d9ef22ce8..e8385d6fe7 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/base.py +++ b/libs/langchain/langchain/chains/retrieval_qa/base.py @@ -11,6 +11,7 @@ from pydantic import Extra, Field, root_validator from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, + Callbacks, ) from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -65,11 +66,12 @@ class BaseRetrievalQA(Chain): cls, llm: BaseLanguageModel, prompt: Optional[PromptTemplate] = None, + callbacks: Callbacks = None, **kwargs: Any, ) -> BaseRetrievalQA: """Initialize from LLM.""" _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm) - llm_chain = LLMChain(llm=llm, prompt=_prompt) + llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks) document_prompt = PromptTemplate( input_variables=["page_content"], template="Context:\n{page_content}" ) @@ -77,9 +79,14 @@ class BaseRetrievalQA(Chain): llm_chain=llm_chain, document_variable_name="context", document_prompt=document_prompt, + callbacks=callbacks, ) - return cls(combine_documents_chain=combine_documents_chain, **kwargs) + return cls( + combine_documents_chain=combine_documents_chain, + callbacks=callbacks, + **kwargs, + ) @classmethod def from_chain_type(