diff --git a/libs/core/langchain_core/runnables/config.py b/libs/core/langchain_core/runnables/config.py index c9e7904a5a..15f2006aab 100644 --- a/libs/core/langchain_core/runnables/config.py +++ b/libs/core/langchain_core/runnables/config.py @@ -305,12 +305,12 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: base["callbacks"] = mngr else: # base_callbacks is also a manager - base["callbacks"] = base_callbacks.__class__( + + manager = base_callbacks.__class__( parent_run_id=base_callbacks.parent_run_id or these_callbacks.parent_run_id, - handlers=base_callbacks.handlers + these_callbacks.handlers, - inheritable_handlers=base_callbacks.inheritable_handlers - + these_callbacks.inheritable_handlers, + handlers=[], + inheritable_handlers=[], tags=list(set(base_callbacks.tags + these_callbacks.tags)), inheritable_tags=list( set( @@ -323,6 +323,20 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig: **these_callbacks.metadata, }, ) + + handlers = base_callbacks.handlers + these_callbacks.handlers + inheritable_handlers = ( + base_callbacks.inheritable_handlers + + these_callbacks.inheritable_handlers + ) + + for handler in handlers: + manager.add_handler(handler) + + for handler in inheritable_handlers: + manager.add_handler(handler, inherit=True) + + base["callbacks"] = manager else: base[key] = config[key] or base.get(key) # type: ignore return base diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 88d19f3372..d03bd195c4 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -1876,3 +1876,34 @@ async def test_runnable_generator() -> None: "tags": [], }, ] + + +async def test_with_explicit_config() -> None: + """Test astream events with explicit callbacks being passed.""" + infinite_cycle = cycle([AIMessage(content="hello world", id="ai3")]) + model = GenericFakeChatModel(messages=infinite_cycle) + + @tool + async def say_hello(query: str, callbacks: Callbacks) -> BaseMessage: + """Use this tool to look up which items are in the given place.""" + + @RunnableLambda + def passthrough_to_trigger_issue(x: str) -> str: + """Add passthrough to trigger issue.""" + return x + + chain = passthrough_to_trigger_issue | model.with_config( + {"tags": ["hello"], "callbacks": callbacks} + ) + + return await chain.ainvoke(query) + + events = await _collect_events( + say_hello.astream_events("meow", version="v2") # type: ignore + ) + + assert [ + event["data"]["chunk"].content + for event in events + if event["event"] == "on_chat_model_stream" + ] == ["hello", " ", "world"]