mirror of
https://github.com/hwchase17/langchain
synced 2024-11-02 09:40:22 +00:00
core[patch]: Deduplicate of callback handlers in merge_configs (#22478)
This PR adds deduplication of callback handlers in merge_configs. Fix for this issue: https://github.com/langchain-ai/langchain/issues/22227 The issue appears when the code is: 1) running python >=3.11 2) invokes a runnable from within a runnable 3) binds the callbacks to the child runnable from the parent runnable using with_config In this case, the same callbacks end up appearing twice: (1) the first time from with_config, (2) the second time with langchain automatically propagating them on behalf of the user. Prior to this PR this will emit duplicate events: ```python @tool async def get_items(question: str, callbacks: Callbacks): # <--- Accept callbacks """Ask question""" template = ChatPromptTemplate.from_messages( [ ( "human", "'{question}" ) ] ) chain = template | chat_model.with_config( { "callbacks": callbacks, # <-- Propagate callbacks } ) return await chain.ainvoke({"question": question}) ``` Prior to this PR this will work work correctly (no duplicate events): ```python @tool async def get_items(question: str, callbacks: Callbacks): # <--- Accept callbacks """Ask question""" template = ChatPromptTemplate.from_messages( [ ( "human", "'{question}" ) ] ) chain = template | chat_model return await chain.ainvoke({"question": question}, {"callbacks": callbacks}) ``` This will also work (as long as the user is using python >= 3.11) -- as langchain will automatically propagate callbacks ```python @tool async def get_items(question: str,): """Ask question""" template = ChatPromptTemplate.from_messages( [ ( "human", "'{question}" ) ] ) chain = template | chat_model return await chain.ainvoke({"question": question}) ```
This commit is contained in:
parent
64dbc52cae
commit
9120cf5df2
@ -305,12 +305,12 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
base["callbacks"] = mngr
|
base["callbacks"] = mngr
|
||||||
else:
|
else:
|
||||||
# base_callbacks is also a manager
|
# base_callbacks is also a manager
|
||||||
base["callbacks"] = base_callbacks.__class__(
|
|
||||||
|
manager = base_callbacks.__class__(
|
||||||
parent_run_id=base_callbacks.parent_run_id
|
parent_run_id=base_callbacks.parent_run_id
|
||||||
or these_callbacks.parent_run_id,
|
or these_callbacks.parent_run_id,
|
||||||
handlers=base_callbacks.handlers + these_callbacks.handlers,
|
handlers=[],
|
||||||
inheritable_handlers=base_callbacks.inheritable_handlers
|
inheritable_handlers=[],
|
||||||
+ these_callbacks.inheritable_handlers,
|
|
||||||
tags=list(set(base_callbacks.tags + these_callbacks.tags)),
|
tags=list(set(base_callbacks.tags + these_callbacks.tags)),
|
||||||
inheritable_tags=list(
|
inheritable_tags=list(
|
||||||
set(
|
set(
|
||||||
@ -323,6 +323,20 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
|||||||
**these_callbacks.metadata,
|
**these_callbacks.metadata,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
handlers = base_callbacks.handlers + these_callbacks.handlers
|
||||||
|
inheritable_handlers = (
|
||||||
|
base_callbacks.inheritable_handlers
|
||||||
|
+ these_callbacks.inheritable_handlers
|
||||||
|
)
|
||||||
|
|
||||||
|
for handler in handlers:
|
||||||
|
manager.add_handler(handler)
|
||||||
|
|
||||||
|
for handler in inheritable_handlers:
|
||||||
|
manager.add_handler(handler, inherit=True)
|
||||||
|
|
||||||
|
base["callbacks"] = manager
|
||||||
else:
|
else:
|
||||||
base[key] = config[key] or base.get(key) # type: ignore
|
base[key] = config[key] or base.get(key) # type: ignore
|
||||||
return base
|
return base
|
||||||
|
@ -1876,3 +1876,34 @@ async def test_runnable_generator() -> None:
|
|||||||
"tags": [],
|
"tags": [],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_with_explicit_config() -> None:
|
||||||
|
"""Test astream events with explicit callbacks being passed."""
|
||||||
|
infinite_cycle = cycle([AIMessage(content="hello world", id="ai3")])
|
||||||
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def say_hello(query: str, callbacks: Callbacks) -> BaseMessage:
|
||||||
|
"""Use this tool to look up which items are in the given place."""
|
||||||
|
|
||||||
|
@RunnableLambda
|
||||||
|
def passthrough_to_trigger_issue(x: str) -> str:
|
||||||
|
"""Add passthrough to trigger issue."""
|
||||||
|
return x
|
||||||
|
|
||||||
|
chain = passthrough_to_trigger_issue | model.with_config(
|
||||||
|
{"tags": ["hello"], "callbacks": callbacks}
|
||||||
|
)
|
||||||
|
|
||||||
|
return await chain.ainvoke(query)
|
||||||
|
|
||||||
|
events = await _collect_events(
|
||||||
|
say_hello.astream_events("meow", version="v2") # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [
|
||||||
|
event["data"]["chunk"].content
|
||||||
|
for event in events
|
||||||
|
if event["event"] == "on_chat_model_stream"
|
||||||
|
] == ["hello", " ", "world"]
|
||||||
|
Loading…
Reference in New Issue
Block a user