mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
propagate RetrievalQA chain callbacks through its own LLMChain and StuffDocumentsChain (#7853)
This is another case, similar to #5572 and #7565 where the callbacks are getting dropped during construction of the chains. tagging @hwchase17 and @agola11 for callbacks propagation <!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
47eea32f6a
commit
404d103c41
@ -11,6 +11,7 @@ from pydantic import Extra, Field, root_validator
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
@ -65,11 +66,12 @@ class BaseRetrievalQA(Chain):
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[PromptTemplate] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseRetrievalQA:
|
||||
"""Initialize from LLM."""
|
||||
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
|
||||
llm_chain = LLMChain(llm=llm, prompt=_prompt)
|
||||
llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks)
|
||||
document_prompt = PromptTemplate(
|
||||
input_variables=["page_content"], template="Context:\n{page_content}"
|
||||
)
|
||||
@ -77,9 +79,14 @@ class BaseRetrievalQA(Chain):
|
||||
llm_chain=llm_chain,
|
||||
document_variable_name="context",
|
||||
document_prompt=document_prompt,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
|
||||
return cls(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
callbacks=callbacks,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_chain_type(
|
||||
|
Loading…
Reference in New Issue
Block a user