core[patch]: runnable config ensure_config deep copy from var_child_runnable… (#24862)

**issue**: #24660 
`RunnableWithMessageHistory.stream` results in an error because the
[evaluation](https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/runnables/branch.py#L220)
of the branch
[condition](99eb31ec41/libs/core/langchain_core/runnables/history.py (L328C1-L329C1))
unexpectedly triggers the
"[on_end](99eb31ec41/libs/core/langchain_core/runnables/history.py (L332))"
(`exit_history`) callback of the default branch.


**description**
After a lot of investigation, I'm convinced that the root cause is as follows:
1. during the execution of the runnable, the
[var_child_runnable_config](99eb31ec41/libs/core/langchain_core/runnables/config.py (L122))
is shared between the branch
[condition](99eb31ec41/libs/core/langchain_core/runnables/history.py (L328C1-L329C1))
runnable and the [default branch
runnable](99eb31ec41/libs/core/langchain_core/runnables/history.py (L332))
within the same context
2. when the default branch runnable runs, it gets the
[var_child_runnable_config](99eb31ec41/libs/core/langchain_core/runnables/config.py (L163))
and may unintentionally [add more
handlers](99eb31ec41/libs/core/langchain_core/runnables/config.py (L325)) to
the callback manager of this config
3. when it is again the turn for the
[condition](99eb31ec41/libs/core/langchain_core/runnables/history.py (L328C1-L329C1))
to run, it gets the `var_child_runnable_config` whose callback manager
has the handlers added by the default branch. When it runs that handler
(`exit_history`), it leads to the error.
   
On the assumption that the `ensure_config` function actually does
want to create an immutable copy from `var_child_runnable_config` (because
it starts with an [`empty`
variable](99eb31ec41/libs/core/langchain_core/runnables/config.py (L156))),
I went ahead and made a deep copy to ensure that future modifications to the
returned value won't affect the `var_child_runnable_config` variable.
   
   Having said that, I actually
1. don't know if this is a proper fix
2. don't know whether it will lead to other unintended consequences
3. don't know why only "stream" runs into this issue while "invoke" runs
without problems

So @nfcampos @hwchase17, please help review. Thanks!

---------

Co-authored-by: Lifu Wu <lifu@nextbillion.ai>
Co-authored-by: Nuno Campos <nuno@langchain.dev>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
WU LIFU 2024-08-02 08:30:32 +08:00 committed by GitHub
parent 3ab09d87d6
commit ad16eed119
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 98 additions and 12 deletions

View File

@ -116,6 +116,13 @@ CONFIG_KEYS = [
"run_id",
]
COPIABLE_KEYS = [
"tags",
"metadata",
"callbacks",
"configurable",
]
DEFAULT_RECURSION_LIMIT = 25
@ -162,15 +169,30 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
)
if var_config := var_child_runnable_config.get():
empty.update(
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
cast(
RunnableConfig,
{
k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined]
for k, v in var_config.items()
if v is not None
},
)
)
if config is not None:
empty.update(
cast(
RunnableConfig,
{
k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined]
for k, v in config.items()
if v is not None and k in CONFIG_KEYS
},
)
)
if config is not None:
for k, v in config.items():
if v is not None:
if k in CONFIG_KEYS:
empty[k] = v # type: ignore[literal-required]
else:
empty["configurable"][k] = v
if k not in CONFIG_KEYS and v is not None:
empty["configurable"][k] = v
for key, value in empty.get("configurable", {}).items():
if (
not key.startswith("__")
@ -291,7 +313,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
**(config.get(key) or {}), # type: ignore
}
elif key == "tags":
base[key] = list( # type: ignore
base[key] = sorted( # type: ignore
set(base.get(key, []) + (config.get(key) or [])), # type: ignore
)
elif key == "configurable":
@ -306,7 +328,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
# so merging two callbacks values has 6 cases
if isinstance(these_callbacks, list):
if base_callbacks is None:
base["callbacks"] = these_callbacks
base["callbacks"] = these_callbacks.copy()
elif isinstance(base_callbacks, list):
base["callbacks"] = base_callbacks + these_callbacks
else:
@ -318,7 +340,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
elif these_callbacks is not None:
# these_callbacks is a manager
if base_callbacks is None:
base["callbacks"] = these_callbacks
base["callbacks"] = these_callbacks.copy()
elif isinstance(base_callbacks, list):
mngr = these_callbacks.copy()
for callback in base_callbacks:
@ -361,6 +383,8 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
elif key == "recursion_limit":
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
base["recursion_limit"] = config["recursion_limit"]
elif key in COPIABLE_KEYS and config[key] is not None: # type: ignore[literal-required]
base[key] = config[key].copy() # type: ignore[literal-required]
else:
base[key] = config[key] or base.get(key) # type: ignore
return base

View File

@ -1,4 +1,7 @@
from typing import Any, cast
import json
import uuid
from contextvars import copy_context
from typing import Any, Dict, cast
import pytest
@ -8,12 +11,61 @@ from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHan
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
ensure_config,
merge_configs,
run_in_executor,
)
from langchain_core.tracers.stdout import ConsoleCallbackHandler
def test_ensure_config() -> None:
    """Check that ``ensure_config`` copies its input rather than aliasing it.

    A context-level config is installed with ``_set_config_context`` inside a
    copied context, and ``ensure_config`` is then run in that same context
    with an explicit config. The test asserts that the explicit config is
    left unchanged, that the mutable values of the result are distinct
    objects from the ones passed in, and that the result equals the expected
    config.
    """
    generated_run_id = str(uuid.uuid4())
    user_config: Dict = {
        "something": "else",
        "metadata": {"foo": "bar"},
        "configurable": {"baz": "qux"},
        "callbacks": [StdOutCallbackHandler()],
        "tags": ["tag1", "tag2"],
        "max_concurrency": 1,
        "recursion_limit": 100,
        "run_id": generated_run_id,
        "run_name": "test",
    }

    def _snapshot() -> str:
        # The callbacks entry is swapped for an empty list before dumping.
        return json.dumps({**user_config, "callbacks": []})

    snapshot_before = _snapshot()

    # Config installed in the context before ``ensure_config`` is called.
    context_level_config = {
        "callbacks": [ConsoleCallbackHandler()],
        "metadata": {"a": "b"},
        "configurable": {"c": "d"},
        "tags": ["tag3", "tag4"],
    }
    context = copy_context()
    context.run(_set_config_context, context_level_config)
    resolved = context.run(ensure_config, cast(RunnableConfig, user_config))

    # The caller's config must come back untouched.
    assert (
        len(user_config["callbacks"]) == 1
    ), "ensure_config should not modify the original config"
    assert (
        _snapshot() == snapshot_before
    ), "ensure_config should not modify the original config"

    # The result and its mutable members must be new objects.
    assert resolved is not user_config
    for copied_key in ("callbacks", "metadata", "configurable"):
        assert resolved[copied_key] is not user_config[copied_key]  # type: ignore[literal-required]

    expected = {
        "tags": ["tag1", "tag2"],
        "metadata": {"foo": "bar", "baz": "qux", "something": "else"},
        "callbacks": [user_config["callbacks"][0]],
        "recursion_limit": 100,
        "configurable": {"baz": "qux", "something": "else"},
        "max_concurrency": 1,
        "run_id": generated_run_id,
        "run_name": "test",
    }
    assert resolved == expected
def test_merge_config_callbacks() -> None:
manager: RunnableConfig = {
"callbacks": CallbackManager(handlers=[StdOutCallbackHandler()])

View File

@ -53,6 +53,8 @@ def test_input_messages() -> None:
assert output == "you said: hello"
output = with_history.invoke([HumanMessage(content="good bye")], config)
assert output == "you said: hello\ngood bye"
output = [*with_history.stream([HumanMessage(content="hi again")], config)]
assert output == ["you said: hello\ngood bye\nhi again"]
assert store == {
"1": InMemoryChatMessageHistory(
messages=[
@ -60,6 +62,8 @@ def test_input_messages() -> None:
AIMessage(content="you said: hello"),
HumanMessage(content="good bye"),
AIMessage(content="you said: hello\ngood bye"),
HumanMessage(content="hi again"),
AIMessage(content="you said: hello\ngood bye\nhi again"),
]
)
}
@ -78,6 +82,10 @@ async def test_input_messages_async() -> None:
assert output == "you said: hello"
output = await with_history.ainvoke([HumanMessage(content="good bye")], config) # type: ignore[arg-type]
assert output == "you said: hello\ngood bye"
output = [
c
async for c in with_history.astream([HumanMessage(content="hi again")], config) # type: ignore[arg-type]
] == ["you said: hello\ngood bye\nhi again"]
assert store == {
"1_async": InMemoryChatMessageHistory(
messages=[
@ -85,6 +93,8 @@ async def test_input_messages_async() -> None:
AIMessage(content="you said: hello"),
HumanMessage(content="good bye"),
AIMessage(content="you said: hello\ngood bye"),
HumanMessage(content="hi again"),
AIMessage(content="you said: hello\ngood bye\nhi again"),
]
)
}

View File

@ -5402,7 +5402,7 @@ def test_listeners() -> None:
shared_state = {}
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
value2 = {"inputs": {"name": "two"}, "outputs": {"name": "two"}}
def on_start(run: Run) -> None:
shared_state[run.id] = {"inputs": run.inputs}
@ -5432,7 +5432,7 @@ async def test_listeners_async() -> None:
shared_state = {}
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
value2 = {"inputs": {"name": "two"}, "outputs": {"name": "two"}}
def on_start(run: Run) -> None:
shared_state[run.id] = {"inputs": run.inputs}