core[patch]: Deduplicate callback handlers in merge_configs (#22478)

This PR adds deduplication of callback handlers in merge_configs.

Fix for this issue:
https://github.com/langchain-ai/langchain/issues/22227

The issue appears when the code:

1) is running on Python >= 3.11,
2) invokes a runnable from within a runnable, and
3) binds the callbacks to the child runnable from the parent runnable using with_config.

In this case, the same callbacks end up appearing twice: once passed explicitly via with_config, and a second time when langchain automatically propagates them on behalf of the user.


Prior to this PR, this will emit duplicate events:

```python
from langchain_core.callbacks import Callbacks
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool

# `chat_model` is assumed to be a chat model instance defined elsewhere.


@tool
async def get_items(question: str, callbacks: Callbacks):  # <--- Accept callbacks
    """Ask question"""
    template = ChatPromptTemplate.from_messages(
        [
            (
                "human",
                "{question}"
            )
        ]
    )
    chain = template | chat_model.with_config(
        {
            "callbacks": callbacks,  # <-- Propagate callbacks
        }
    )
    return await chain.ainvoke({"question": question})
```

Prior to this PR, the following already works correctly (no duplicate events):

```python
@tool
async def get_items(question: str, callbacks: Callbacks):  # <--- Accept callbacks
    """Ask question"""
    template = ChatPromptTemplate.from_messages(
        [
            (
                "human",
                "'{question}"
            )
        ]
    )
    chain = template | chat_model
    return await chain.ainvoke({"question": question}, {"callbacks": callbacks})
```

This will also work (as long as the user is running Python >= 3.11), since langchain will automatically propagate the callbacks:

```python
@tool
async def get_items(question: str):
    """Ask question"""
    template = ChatPromptTemplate.from_messages(
        [
            (
                "human",
                "'{question}"
            )
        ]
    )
    chain = template | chat_model
    return await chain.ainvoke({"question": question})
```
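
At the merge_configs level, the effect of this change can be illustrated with a small standalone sketch. This snippet is not part of the PR; it assumes CallbackManager and StdOutCallbackHandler from langchain_core.callbacks and constructs the duplicate-handler situation by hand:

```python
from langchain_core.callbacks import CallbackManager, StdOutCallbackHandler
from langchain_core.runnables.config import merge_configs

handler = StdOutCallbackHandler()

# The same handler shows up in two configs: once bound explicitly via
# with_config, once propagated automatically by langchain (Python >= 3.11).
explicit_config = {"callbacks": CallbackManager(handlers=[handler])}
propagated_config = {"callbacks": CallbackManager(handlers=[handler])}

merged = merge_configs(explicit_config, propagated_config)

# After this PR the merged manager contains the handler only once,
# so downstream events are no longer emitted twice.
assert merged["callbacks"].handlers.count(handler) == 1
```

Before this PR, the merged manager would contain the handler twice, which is what produced the duplicate events.
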
Author: Eugene Yurtsev, 2024-06-04 16:19:00 -04:00 (committed by GitHub)
Commit: 9120cf5df2 (parent: 64dbc52cae)
2 changed files with 49 additions and 4 deletions


```diff
@@ -305,12 +305,12 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
                         base["callbacks"] = mngr
                     else:
                         # base_callbacks is also a manager
-                        base["callbacks"] = base_callbacks.__class__(
+                        manager = base_callbacks.__class__(
                             parent_run_id=base_callbacks.parent_run_id
                             or these_callbacks.parent_run_id,
-                            handlers=base_callbacks.handlers + these_callbacks.handlers,
-                            inheritable_handlers=base_callbacks.inheritable_handlers
-                            + these_callbacks.inheritable_handlers,
+                            handlers=[],
+                            inheritable_handlers=[],
                             tags=list(set(base_callbacks.tags + these_callbacks.tags)),
                             inheritable_tags=list(
                                 set(
@@ -323,6 +323,20 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
                                 **these_callbacks.metadata,
                             },
                         )
+
+                        handlers = base_callbacks.handlers + these_callbacks.handlers
+                        inheritable_handlers = (
+                            base_callbacks.inheritable_handlers
+                            + these_callbacks.inheritable_handlers
+                        )
+
+                        for handler in handlers:
+                            manager.add_handler(handler)
+
+                        for handler in inheritable_handlers:
+                            manager.add_handler(handler, inherit=True)
+
+                        base["callbacks"] = manager
             else:
                 base[key] = config[key] or base.get(key)  # type: ignore
     return base
```
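
The handlers are now attached through add_handler rather than concatenated in the constructor call. To the best of my understanding, add_handler only registers a handler that is not already present, which is what deduplicates the callbacks. A simplified sketch of that behavior (an illustration, not the langchain_core source):

```python
# Simplified sketch (assumption, not the langchain_core implementation) of the
# deduplicating behavior that add_handler provides: a handler is only appended
# if it is not already registered, so the same handler arriving once via
# with_config and once via automatic propagation ends up registered only once.
class ManagerSketch:
    def __init__(self) -> None:
        self.handlers: list = []
        self.inheritable_handlers: list = []

    def add_handler(self, handler, inherit: bool = True) -> None:
        if handler not in self.handlers:
            self.handlers.append(handler)
        if inherit and handler not in self.inheritable_handlers:
            self.inheritable_handlers.append(handler)
```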


```diff
@@ -1876,3 +1876,34 @@ async def test_runnable_generator() -> None:
             "tags": [],
         },
     ]
+
+
+async def test_with_explicit_config() -> None:
+    """Test astream events with explicit callbacks being passed."""
+    infinite_cycle = cycle([AIMessage(content="hello world", id="ai3")])
+    model = GenericFakeChatModel(messages=infinite_cycle)
+
+    @tool
+    async def say_hello(query: str, callbacks: Callbacks) -> BaseMessage:
+        """Use this tool to look up which items are in the given place."""
+
+        @RunnableLambda
+        def passthrough_to_trigger_issue(x: str) -> str:
+            """Add passthrough to trigger issue."""
+            return x
+
+        chain = passthrough_to_trigger_issue | model.with_config(
+            {"tags": ["hello"], "callbacks": callbacks}
+        )
+
+        return await chain.ainvoke(query)
+
+    events = await _collect_events(
+        say_hello.astream_events("meow", version="v2")  # type: ignore
+    )
+
+    assert [
+        event["data"]["chunk"].content
+        for event in events
+        if event["event"] == "on_chat_model_stream"
+    ] == ["hello", " ", "world"]
```