propagate callbacks to ConversationalRetrievalChain (#5572)

# Allow callbacks to monitor ConversationalRetrievalChain


I ran into an issue where `load_qa_chain` was not passing callbacks down to
its child `LLMChain`s, so this change makes sure callbacks are propagated
through `ConversationalRetrievalChain`, `ChatVectorDBChain`, and the chains
built by `load_qa_chain`. There are probably more improvements to make here,
but this seemed like a good place to stop.

Note that I saw a lot of references to `callback_manager`, which appears to
be deprecated; I left that code alone for now.
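
As a usage sketch (illustrative only, not part of this PR: the handler,
embeddings, vectorstore, and texts are stand-ins), callbacks passed to
`from_llm` now reach the question-generator and combine-docs sub-chains
instead of stopping at the outer chain:

```python
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

# Toy index; any retriever works. Assumes an OpenAI API key is configured
# and the faiss package is installed.
vectorstore = FAISS.from_texts(
    ["Callbacks are now propagated to child chains."],
    OpenAIEmbeddings(),
)

qa = ConversationalRetrievalChain.from_llm(
    ChatOpenAI(temperature=0),
    retriever=vectorstore.as_retriever(),
    callbacks=[StdOutCallbackHandler()],  # now forwarded to the sub-chains
)

# The handler logs entry/exit for the condense-question LLMChain and the
# QA chain as well, not just the outer ConversationalRetrievalChain.
qa({"question": "What does this change?", "chat_history": []})
```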




## Who can review?

Community members can review the PR once tests pass. Tag
maintainers/contributors who might be interested:
@agola11

langchain/chains/conversational_retrieval/base.py:

@@ -12,6 +12,7 @@ from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
+    Callbacks,
 )
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -204,6 +205,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
         verbose: bool = False,
         condense_question_llm: Optional[BaseLanguageModel] = None,
         combine_docs_chain_kwargs: Optional[Dict] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> BaseConversationalRetrievalChain:
         """Load chain from LLM."""
@@ -212,17 +214,22 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
             llm,
             chain_type=chain_type,
             verbose=verbose,
+            callbacks=callbacks,
             **combine_docs_chain_kwargs,
         )
         _llm = condense_question_llm or llm
         condense_question_chain = LLMChain(
-            llm=_llm, prompt=condense_question_prompt, verbose=verbose
+            llm=_llm,
+            prompt=condense_question_prompt,
+            verbose=verbose,
+            callbacks=callbacks,
         )
         return cls(
             retriever=retriever,
             combine_docs_chain=doc_chain,
             question_generator=condense_question_chain,
+            callbacks=callbacks,
             **kwargs,
         )
@@ -264,6 +271,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
         condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
         chain_type: str = "stuff",
         combine_docs_chain_kwargs: Optional[Dict] = None,
+        callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> BaseConversationalRetrievalChain:
         """Load chain from LLM."""
@@ -271,12 +279,16 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
         doc_chain = load_qa_chain(
             llm,
             chain_type=chain_type,
+            callbacks=callbacks,
             **combine_docs_chain_kwargs,
         )
-        condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
+        condense_question_chain = LLMChain(
+            llm=llm, prompt=condense_question_prompt, callbacks=callbacks
+        )
         return cls(
             vectorstore=vectorstore,
             combine_docs_chain=doc_chain,
             question_generator=condense_question_chain,
+            callbacks=callbacks,
             **kwargs,
         )
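
One way to observe the effect of the changes above is a counting handler
(hypothetical, written for this illustration; not part of the PR):

```python
from typing import Any, Dict

from langchain.callbacks.base import BaseCallbackHandler


class CountingHandler(BaseCallbackHandler):
    """Illustrative handler that counts how many chains start."""

    def __init__(self) -> None:
        self.chain_starts = 0

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        self.chain_starts += 1
```

Passing `callbacks=[CountingHandler()]` to `from_llm` previously recorded
only the outer chain's run; with this change the condense-question
`LLMChain` and the combine-docs chain are counted too.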

langchain/chains/question_answering/__init__.py:

@@ -3,6 +3,7 @@ from typing import Any, Mapping, Optional, Protocol
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
 from langchain.chains.combine_documents.map_rerank import MapRerankDocumentsChain
@@ -35,10 +36,15 @@ def _load_map_rerank_chain(
     rank_key: str = "score",
     answer_key: str = "answer",
     callback_manager: Optional[BaseCallbackManager] = None,
+    callbacks: Callbacks = None,
     **kwargs: Any,
 ) -> MapRerankDocumentsChain:
     llm_chain = LLMChain(
-        llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
+        llm=llm,
+        prompt=prompt,
+        verbose=verbose,
+        callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     return MapRerankDocumentsChain(
         llm_chain=llm_chain,
@@ -57,11 +63,16 @@ def _load_stuff_chain(
     document_variable_name: str = "context",
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
+    callbacks: Callbacks = None,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
     _prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
     llm_chain = LLMChain(
-        llm=llm, prompt=_prompt, verbose=verbose, callback_manager=callback_manager
+        llm=llm,
+        prompt=_prompt,
+        verbose=verbose,
+        callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     # TODO: document prompt
     return StuffDocumentsChain(
@@ -84,6 +95,7 @@ def _load_map_reduce_chain(
     collapse_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
+    callbacks: Callbacks = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     _question_prompt = (
@@ -97,6 +109,7 @@
         prompt=_question_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(
@@ -104,6 +117,7 @@
         prompt=_combine_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     # TODO: document prompt
     combine_document_chain = StuffDocumentsChain(
@@ -111,6 +125,7 @@
         document_variable_name=combine_document_variable_name,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     if collapse_prompt is None:
         collapse_chain = None
@@ -127,6 +142,7 @@
                 prompt=collapse_prompt,
                 verbose=verbose,
                 callback_manager=callback_manager,
+                callbacks=callbacks,
             ),
             document_variable_name=combine_document_variable_name,
             verbose=verbose,
@@ -139,6 +155,7 @@
         collapse_document_chain=collapse_chain,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
         **kwargs,
     )
@@ -152,6 +169,7 @@ def _load_refine_chain(
     refine_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
+    callbacks: Callbacks = None,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
     _question_prompt = (
@@ -165,6 +183,7 @@
         prompt=_question_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     _refine_llm = refine_llm or llm
     refine_chain = LLMChain(
@@ -172,6 +191,7 @@
         prompt=_refine_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
+        callbacks=callbacks,
     )
     return RefineDocumentsChain(
         initial_llm_chain=initial_chain,