From ec0dd6e34a9e102468990cf45f6b01032b588853 Mon Sep 17 00:00:00 2001 From: Alec Flett Date: Wed, 7 Jun 2023 20:25:21 -0700 Subject: [PATCH] propagate callbacks to ConversationalRetrievalChain (#5572) # Allow callbacks to monitor ConversationalRetrievalChain I ran into an issue where load_qa_chain was not passing the callbacks down to the child LLM chains, and so made sure that callbacks are propagated. There are probably more improvements to do here but this seemed like a good place to stop. Note that I saw a lot of references to callbacks_manager, which seems to be deprecated. I left that code alone for now. ## Before submitting ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: @agola11 --- .../chains/conversational_retrieval/base.py | 16 +++++++++++-- .../chains/question_answering/__init__.py | 24 +++++++++++++++++-- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index baa3f6dd..396f2092 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -12,6 +12,7 @@ from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, + Callbacks, ) from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -204,6 +205,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): verbose: bool = False, condense_question_llm: Optional[BaseLanguageModel] = None, combine_docs_chain_kwargs: Optional[Dict] = None, + callbacks: Callbacks = None, **kwargs: Any, ) -> BaseConversationalRetrievalChain: """Load chain from LLM.""" @@ -212,17 +214,22 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): llm, chain_type=chain_type, verbose=verbose, + callbacks=callbacks, **combine_docs_chain_kwargs, ) _llm = condense_question_llm or llm condense_question_chain = LLMChain( - llm=_llm, prompt=condense_question_prompt, verbose=verbose + llm=_llm, + prompt=condense_question_prompt, + verbose=verbose, + callbacks=callbacks, ) return cls( retriever=retriever, combine_docs_chain=doc_chain, question_generator=condense_question_chain, + callbacks=callbacks, **kwargs, ) @@ -264,6 +271,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain): condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT, chain_type: str = "stuff", combine_docs_chain_kwargs: Optional[Dict] = None, + callbacks: Callbacks = None, **kwargs: Any, ) -> BaseConversationalRetrievalChain: """Load chain from LLM.""" @@ -271,12 +279,16 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain): doc_chain = load_qa_chain( llm, chain_type=chain_type, + callbacks=callbacks, **combine_docs_chain_kwargs, ) - condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt) + condense_question_chain = LLMChain( + llm=llm, prompt=condense_question_prompt, callbacks=callbacks + ) return cls( vectorstore=vectorstore, combine_docs_chain=doc_chain, question_generator=condense_question_chain, + callbacks=callbacks, **kwargs, ) diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py index 95c24f0a..107415cd 100644 --- a/langchain/chains/question_answering/__init__.py +++ b/langchain/chains/question_answering/__init__.py @@ -3,6 +3,7 @@ from typing import Any, Mapping, Optional, Protocol from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain @@ -35,10 +36,15 @@ def _load_map_rerank_chain( rank_key: str = "score", answer_key: str = "answer", callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, **kwargs: Any, ) -> MapRerankDocumentsChain: llm_chain = LLMChain( - llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager + llm=llm, + prompt=prompt, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, ) return MapRerankDocumentsChain( llm_chain=llm_chain, @@ -57,11 +63,16 @@ def _load_stuff_chain( document_variable_name: str = "context", verbose: Optional[bool] = None, callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, **kwargs: Any, ) -> StuffDocumentsChain: _prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm) llm_chain = LLMChain( - llm=llm, prompt=_prompt, verbose=verbose, callback_manager=callback_manager + llm=llm, + prompt=_prompt, + verbose=verbose, + callback_manager=callback_manager, + callbacks=callbacks, ) # TODO: document prompt return StuffDocumentsChain( @@ -84,6 +95,7 @@ def _load_map_reduce_chain( collapse_llm: Optional[BaseLanguageModel] = None, verbose: Optional[bool] = None, callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, **kwargs: Any, ) -> MapReduceDocumentsChain: _question_prompt = ( @@ -97,6 +109,7 @@ def _load_map_reduce_chain( prompt=_question_prompt, verbose=verbose, callback_manager=callback_manager, + callbacks=callbacks, ) _reduce_llm = reduce_llm or llm reduce_chain = LLMChain( @@ -104,6 +117,7 @@ def _load_map_reduce_chain( prompt=_combine_prompt, verbose=verbose, callback_manager=callback_manager, + callbacks=callbacks, ) # TODO: document prompt combine_document_chain = StuffDocumentsChain( @@ -111,6 +125,7 @@ def _load_map_reduce_chain( document_variable_name=combine_document_variable_name, verbose=verbose, callback_manager=callback_manager, + callbacks=callbacks, ) if collapse_prompt is None: collapse_chain = None @@ -127,6 +142,7 @@ def _load_map_reduce_chain( prompt=collapse_prompt, verbose=verbose, callback_manager=callback_manager, + callbacks=callbacks, ), document_variable_name=combine_document_variable_name, verbose=verbose, @@ -139,6 +155,7 @@ def _load_map_reduce_chain( collapse_document_chain=collapse_chain, verbose=verbose, callback_manager=callback_manager, + callbacks=callbacks, **kwargs, ) @@ -152,6 +169,7 @@ def _load_refine_chain( refine_llm: Optional[BaseLanguageModel] = None, verbose: Optional[bool] = None, callback_manager: Optional[BaseCallbackManager] = None, + callbacks: Callbacks = None, **kwargs: Any, ) -> RefineDocumentsChain: _question_prompt = ( @@ -165,6 +183,7 @@ def _load_refine_chain( prompt=_question_prompt, verbose=verbose, callback_manager=callback_manager, + callbacks=callbacks, ) _refine_llm = refine_llm or llm refine_chain = LLMChain( @@ -172,6 +191,7 @@ def _load_refine_chain( prompt=_refine_prompt, verbose=verbose, callback_manager=callback_manager, + callbacks=callbacks, ) return RefineDocumentsChain( initial_llm_chain=initial_chain,