diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py
index baa3f6dd..396f2092 100644
--- a/langchain/chains/conversational_retrieval/base.py
+++ b/langchain/chains/conversational_retrieval/base.py
@@ -12,6 +12,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
+    Callbacks,
 )
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -204,6 +205,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
         verbose: bool = False,
         condense_question_llm: Optional[BaseLanguageModel] = None,
         combine_docs_chain_kwargs: Optional[Dict] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> BaseConversationalRetrievalChain:
         """Load chain from LLM."""
@@ -212,17 +214,22 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
             llm,
             chain_type=chain_type,
             verbose=verbose,
+            callbacks=callbacks,
             **combine_docs_chain_kwargs,
         )
 
         _llm = condense_question_llm or llm
         condense_question_chain = LLMChain(
-            llm=_llm, prompt=condense_question_prompt, verbose=verbose
+            llm=_llm,
+            prompt=condense_question_prompt,
+            verbose=verbose,
+            callbacks=callbacks,
         )
         return cls(
             retriever=retriever,
             combine_docs_chain=doc_chain,
             question_generator=condense_question_chain,
+            callbacks=callbacks,
             **kwargs,
         )
 
@@ -264,6 +271,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
         condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
         chain_type: str = "stuff",
         combine_docs_chain_kwargs: Optional[Dict] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> BaseConversationalRetrievalChain:
         """Load chain from LLM."""
@@ -271,12 +279,16 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
         doc_chain = load_qa_chain(
             llm,
             chain_type=chain_type,
+            callbacks=callbacks,
             **combine_docs_chain_kwargs,
         )
-        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
+        condense_question_chain = LLMChain(
+            llm=llm, prompt=condense_question_prompt, callbacks=callbacks
+        )
         return cls(
             vectorstore=vectorstore,
             combine_docs_chain=doc_chain,
             question_generator=condense_question_chain,
+            callbacks=callbacks,
             **kwargs,
         )
diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py
index 95c24f0a..107415cd 100644
--- a/langchain/chains/question_answering/__init__.py
+++ b/langchain/chains/question_answering/__init__.py
@@ -3,6 +3,7 @@ from typing import Any, Mapping, Optional, Protocol
 
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -35,10 +36,15 @@ def _load_map_rerank_chain(
     rank_key: str = "score",
     answer_key: str = "answer",
     callback_manager: Optional[BaseCallbackManager] = None,
+    callbacks: Callbacks = None,
     **kwargs: Any,
 ) -> MapRerankDocumentsChain:
     llm_chain = LLMChain(
-        llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
+        llm=llm,
+        prompt=prompt,
+        verbose=verbose,
+        callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     return MapRerankDocumentsChain(
         llm_chain=llm_chain,
@@ -57,11 +63,16 @@ def _load_stuff_chain(
     document_variable_name: str = "context",
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
+    callbacks: Callbacks = None,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
     _prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
     llm_chain = LLMChain(
-        llm=llm, prompt=_prompt, verbose=verbose, callback_manager=callback_manager
+        llm=llm,
+        prompt=_prompt,
+        verbose=verbose,
+        callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     # TODO: document prompt
     return StuffDocumentsChain(
@@ -84,6 +95,7 @@ def _load_map_reduce_chain(
     collapse_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
+    callbacks: Callbacks = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     _question_prompt = (
@@ -97,6 +109,7 @@
         prompt=_question_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(
@@ -104,6 +117,7 @@
         prompt=_combine_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     # TODO: document prompt
     combine_document_chain = StuffDocumentsChain(
@@ -111,6 +125,7 @@
         document_variable_name=combine_document_variable_name,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     if collapse_prompt is None:
         collapse_chain = None
@@ -127,6 +142,7 @@
                 prompt=collapse_prompt,
                 verbose=verbose,
                 callback_manager=callback_manager,
+                callbacks=callbacks,
             ),
             document_variable_name=combine_document_variable_name,
             verbose=verbose,
@@ -139,6 +155,7 @@
         collapse_document_chain=collapse_chain,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
         **kwargs,
     )
 
@@ -152,6 +169,7 @@ def _load_refine_chain(
     refine_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
+    callbacks: Callbacks = None,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
     _question_prompt = (
@@ -165,6 +183,7 @@
         prompt=_question_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     _refine_llm = refine_llm or llm
     refine_chain = LLMChain(
@@ -172,6 +191,7 @@
         prompt=_refine_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     return RefineDocumentsChain(
         initial_llm_chain=initial_chain,
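# --- Usage sketch (not part of the patch) ----------------------------------
# A minimal, hypothetical example of what this change enables, assuming a
# LangChain build that contains the patch above (plus faiss-cpu and OpenAI
# credentials): handlers passed via the new `callbacks` parameter now reach
# the question-generator and combine-docs sub-chains rather than only the
# outer chain. The sample text and question are made up for illustration.
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

handler = StdOutCallbackHandler()
vectorstore = FAISS.from_texts(
    ["Callbacks trace every LLM call inside a chain."], OpenAIEmbeddings()
)

# Per the diff, `callbacks` is forwarded to load_qa_chain(), to the
# condense-question LLMChain, and to the chain constructor itself, so this
# one handler list fires for every sub-chain run instead of just the
# top-level call.
qa = ConversationalRetrievalChain.from_llm(
    llm=ChatOpenAI(temperature=0),
    retriever=vectorstore.as_retriever(),
    callbacks=[handler],
)
result = qa({"question": "What do callbacks trace?", "chat_history": []})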