mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[patch]: runnable config ensure_config deep copy from var_child_runnable… (#24862)
**issue**: #24660 RunnableWithMessageHistory.stream result in error because the [evaluation](https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/runnables/branch.py#L220) of the branch [condition](99eb31ec41/libs/core/langchain_core/runnables/history.py (L328C1-L329C1)
) unexpectedly trigger the "[on_end](99eb31ec41/libs/core/langchain_core/runnables/history.py (L332)
)" (exit_history) callback of the default branch **descriptions** After a lot of investigation I'm convinced that the root cause is that 1. during the execution of the runnable, the [var_child_runnable_config](99eb31ec41/libs/core/langchain_core/runnables/config.py (L122)
) is shared between the branch [condition](99eb31ec41/libs/core/langchain_core/runnables/history.py (L328C1-L329C1)
) runnable and the [default branch runnable](99eb31ec41/libs/core/langchain_core/runnables/history.py (L332)
) within the same context 2. when the default branch runnable runs, it gets the [var_child_runnable_config](99eb31ec41/libs/core/langchain_core/runnables/config.py (L163)
) and may unintentionally [add more handlers ](99eb31ec41/libs/core/langchain_core/runnables/config.py (L325)
)to the callback manager of this config 3. when it is again the turn for the [condition](99eb31ec41/libs/core/langchain_core/runnables/history.py (L328C1-L329C1)
) to run, it gets the `var_child_runnable_config` whose callback manager has the handlers added by the default branch. When it runs that handler (`exit_history`) it leads to the error with the assumption that, the `ensure_config` function actually does want to create a immutable copy from `var_child_runnable_config` because it starts with an [`empty` variable ](99eb31ec41/libs/core/langchain_core/runnables/config.py (L156)
), i go ahead to do a deepcopy to ensure that future modification to the returned value won't affect the `var_child_runnable_config` variable Having said that I actually 1. don't know if this is a proper fix 2. don't know whether it will lead to other unintended consequence 3. don't know why only "stream" runs into this issue while "invoke" runs without problem so @nfcampos @hwchase17 please help review, thanks! --------- Co-authored-by: Lifu Wu <lifu@nextbillion.ai> Co-authored-by: Nuno Campos <nuno@langchain.dev> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
3ab09d87d6
commit
ad16eed119
@ -116,6 +116,13 @@ CONFIG_KEYS = [
|
||||
"run_id",
|
||||
]
|
||||
|
||||
COPIABLE_KEYS = [
|
||||
"tags",
|
||||
"metadata",
|
||||
"callbacks",
|
||||
"configurable",
|
||||
]
|
||||
|
||||
DEFAULT_RECURSION_LIMIT = 25
|
||||
|
||||
|
||||
@ -162,15 +169,30 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
)
|
||||
if var_config := var_child_runnable_config.get():
|
||||
empty.update(
|
||||
cast(RunnableConfig, {k: v for k, v in var_config.items() if v is not None})
|
||||
cast(
|
||||
RunnableConfig,
|
||||
{
|
||||
k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined]
|
||||
for k, v in var_config.items()
|
||||
if v is not None
|
||||
},
|
||||
)
|
||||
)
|
||||
if config is not None:
|
||||
empty.update(
|
||||
cast(
|
||||
RunnableConfig,
|
||||
{
|
||||
k: v.copy() if k in COPIABLE_KEYS else v # type: ignore[attr-defined]
|
||||
for k, v in config.items()
|
||||
if v is not None and k in CONFIG_KEYS
|
||||
},
|
||||
)
|
||||
)
|
||||
if config is not None:
|
||||
for k, v in config.items():
|
||||
if v is not None:
|
||||
if k in CONFIG_KEYS:
|
||||
empty[k] = v # type: ignore[literal-required]
|
||||
else:
|
||||
empty["configurable"][k] = v
|
||||
if k not in CONFIG_KEYS and v is not None:
|
||||
empty["configurable"][k] = v
|
||||
for key, value in empty.get("configurable", {}).items():
|
||||
if (
|
||||
not key.startswith("__")
|
||||
@ -291,7 +313,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
**(config.get(key) or {}), # type: ignore
|
||||
}
|
||||
elif key == "tags":
|
||||
base[key] = list( # type: ignore
|
||||
base[key] = sorted( # type: ignore
|
||||
set(base.get(key, []) + (config.get(key) or [])), # type: ignore
|
||||
)
|
||||
elif key == "configurable":
|
||||
@ -306,7 +328,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
# so merging two callbacks values has 6 cases
|
||||
if isinstance(these_callbacks, list):
|
||||
if base_callbacks is None:
|
||||
base["callbacks"] = these_callbacks
|
||||
base["callbacks"] = these_callbacks.copy()
|
||||
elif isinstance(base_callbacks, list):
|
||||
base["callbacks"] = base_callbacks + these_callbacks
|
||||
else:
|
||||
@ -318,7 +340,7 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
elif these_callbacks is not None:
|
||||
# these_callbacks is a manager
|
||||
if base_callbacks is None:
|
||||
base["callbacks"] = these_callbacks
|
||||
base["callbacks"] = these_callbacks.copy()
|
||||
elif isinstance(base_callbacks, list):
|
||||
mngr = these_callbacks.copy()
|
||||
for callback in base_callbacks:
|
||||
@ -361,6 +383,8 @@ def merge_configs(*configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
elif key == "recursion_limit":
|
||||
if config["recursion_limit"] != DEFAULT_RECURSION_LIMIT:
|
||||
base["recursion_limit"] = config["recursion_limit"]
|
||||
elif key in COPIABLE_KEYS and config[key] is not None: # type: ignore[literal-required]
|
||||
base[key] = config[key].copy() # type: ignore[literal-required]
|
||||
else:
|
||||
base[key] = config[key] or base.get(key) # type: ignore
|
||||
return base
|
||||
|
@ -1,4 +1,7 @@
|
||||
from typing import Any, cast
|
||||
import json
|
||||
import uuid
|
||||
from contextvars import copy_context
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
import pytest
|
||||
|
||||
@ -8,12 +11,61 @@ from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHan
|
||||
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
_set_config_context,
|
||||
ensure_config,
|
||||
merge_configs,
|
||||
run_in_executor,
|
||||
)
|
||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler
|
||||
|
||||
|
||||
def test_ensure_config() -> None:
|
||||
run_id = str(uuid.uuid4())
|
||||
arg: Dict = {
|
||||
"something": "else",
|
||||
"metadata": {"foo": "bar"},
|
||||
"configurable": {"baz": "qux"},
|
||||
"callbacks": [StdOutCallbackHandler()],
|
||||
"tags": ["tag1", "tag2"],
|
||||
"max_concurrency": 1,
|
||||
"recursion_limit": 100,
|
||||
"run_id": run_id,
|
||||
"run_name": "test",
|
||||
}
|
||||
arg_str = json.dumps({**arg, "callbacks": []})
|
||||
ctx = copy_context()
|
||||
ctx.run(
|
||||
_set_config_context,
|
||||
{
|
||||
"callbacks": [ConsoleCallbackHandler()],
|
||||
"metadata": {"a": "b"},
|
||||
"configurable": {"c": "d"},
|
||||
"tags": ["tag3", "tag4"],
|
||||
},
|
||||
)
|
||||
config = ctx.run(ensure_config, cast(RunnableConfig, arg))
|
||||
assert (
|
||||
len(arg["callbacks"]) == 1
|
||||
), "ensure_config should not modify the original config"
|
||||
assert (
|
||||
json.dumps({**arg, "callbacks": []}) == arg_str
|
||||
), "ensure_config should not modify the original config"
|
||||
assert config is not arg
|
||||
assert config["callbacks"] is not arg["callbacks"]
|
||||
assert config["metadata"] is not arg["metadata"]
|
||||
assert config["configurable"] is not arg["configurable"]
|
||||
assert config == {
|
||||
"tags": ["tag1", "tag2"],
|
||||
"metadata": {"foo": "bar", "baz": "qux", "something": "else"},
|
||||
"callbacks": [arg["callbacks"][0]],
|
||||
"recursion_limit": 100,
|
||||
"configurable": {"baz": "qux", "something": "else"},
|
||||
"max_concurrency": 1,
|
||||
"run_id": run_id,
|
||||
"run_name": "test",
|
||||
}
|
||||
|
||||
|
||||
def test_merge_config_callbacks() -> None:
|
||||
manager: RunnableConfig = {
|
||||
"callbacks": CallbackManager(handlers=[StdOutCallbackHandler()])
|
||||
|
@ -53,6 +53,8 @@ def test_input_messages() -> None:
|
||||
assert output == "you said: hello"
|
||||
output = with_history.invoke([HumanMessage(content="good bye")], config)
|
||||
assert output == "you said: hello\ngood bye"
|
||||
output = [*with_history.stream([HumanMessage(content="hi again")], config)]
|
||||
assert output == ["you said: hello\ngood bye\nhi again"]
|
||||
assert store == {
|
||||
"1": InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
@ -60,6 +62,8 @@ def test_input_messages() -> None:
|
||||
AIMessage(content="you said: hello"),
|
||||
HumanMessage(content="good bye"),
|
||||
AIMessage(content="you said: hello\ngood bye"),
|
||||
HumanMessage(content="hi again"),
|
||||
AIMessage(content="you said: hello\ngood bye\nhi again"),
|
||||
]
|
||||
)
|
||||
}
|
||||
@ -78,6 +82,10 @@ async def test_input_messages_async() -> None:
|
||||
assert output == "you said: hello"
|
||||
output = await with_history.ainvoke([HumanMessage(content="good bye")], config) # type: ignore[arg-type]
|
||||
assert output == "you said: hello\ngood bye"
|
||||
output = [
|
||||
c
|
||||
async for c in with_history.astream([HumanMessage(content="hi again")], config) # type: ignore[arg-type]
|
||||
] == ["you said: hello\ngood bye\nhi again"]
|
||||
assert store == {
|
||||
"1_async": InMemoryChatMessageHistory(
|
||||
messages=[
|
||||
@ -85,6 +93,8 @@ async def test_input_messages_async() -> None:
|
||||
AIMessage(content="you said: hello"),
|
||||
HumanMessage(content="good bye"),
|
||||
AIMessage(content="you said: hello\ngood bye"),
|
||||
HumanMessage(content="hi again"),
|
||||
AIMessage(content="you said: hello\ngood bye\nhi again"),
|
||||
]
|
||||
)
|
||||
}
|
||||
|
@ -5402,7 +5402,7 @@ def test_listeners() -> None:
|
||||
|
||||
shared_state = {}
|
||||
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||
value2 = {"inputs": {"name": "two"}, "outputs": {"name": "two"}}
|
||||
|
||||
def on_start(run: Run) -> None:
|
||||
shared_state[run.id] = {"inputs": run.inputs}
|
||||
@ -5432,7 +5432,7 @@ async def test_listeners_async() -> None:
|
||||
|
||||
shared_state = {}
|
||||
value1 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||
value2 = {"inputs": {"name": "one"}, "outputs": {"name": "one"}}
|
||||
value2 = {"inputs": {"name": "two"}, "outputs": {"name": "two"}}
|
||||
|
||||
def on_start(run: Run) -> None:
|
||||
shared_state[run.id] = {"inputs": run.inputs}
|
||||
|
Loading…
Reference in New Issue
Block a user