Use shallow copy on runnable locals (#10825)

- deep copy prevents storing complex objects in locals
pull/10842/head
Nuno Campos 11 months ago committed by GitHub
parent ebe08412ad
commit 276125a33b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -48,10 +48,10 @@ class PutLocalVar(RunnablePassthrough):
"therefore always receive a non-null config."
)
if isinstance(self.key, str):
if self.key not in config["_locals"] or replace:
config["_locals"][self.key] = input
if self.key not in config["locals"] or replace:
config["locals"][self.key] = input
else:
config["_locals"][self.key] += input
config["locals"][self.key] += input
elif isinstance(self.key, Mapping):
if not isinstance(input, Mapping):
raise TypeError(
@ -59,10 +59,10 @@ class PutLocalVar(RunnablePassthrough):
f"input is expected to be of type Mapping when key is Mapping."
)
for input_key, put_key in self.key.items():
if put_key not in config["_locals"] or replace:
config["_locals"][put_key] = input[input_key]
if put_key not in config["locals"] or replace:
config["locals"][put_key] = input[input_key]
else:
config["_locals"][put_key] += input[input_key]
config["locals"][put_key] += input[input_key]
else:
raise TypeError(
f"`key` should be a string or Mapping[str, str], received type "
@ -127,11 +127,11 @@ class GetLocalVar(
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if self.passthrough_key:
return {
self.key: config["_locals"][self.key],
self.key: config["locals"][self.key],
self.passthrough_key: input,
}
else:
return config["_locals"][self.key]
return config["locals"][self.key]
async def _aget(
self,

@ -1628,7 +1628,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
# mark each step as a child run
patch_config(
config,
deep_copy_locals=True,
copy_locals=True,
callbacks=run_manager.get_child(f"map:key:{key}"),
),
)
@ -2111,7 +2111,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
)
else:
configs = [
patch_config(self._merge_config(config), deep_copy_locals=True)
patch_config(self._merge_config(config), copy_locals=True)
for _ in range(len(inputs))
]
return self.bound.batch(
@ -2135,7 +2135,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
)
else:
configs = [
patch_config(self._merge_config(config), deep_copy_locals=True)
patch_config(self._merge_config(config), copy_locals=True)
for _ in range(len(inputs))
]
return await self.bound.abatch(

@ -2,7 +2,6 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
@ -13,6 +12,7 @@ from typing import (
List,
Optional,
Union,
cast,
)
from typing_extensions import TypedDict
@ -60,9 +60,11 @@ class RunnableConfig(TypedDict, total=False):
Name for the tracer run for this call. Defaults to the name of the class.
"""
_locals: Dict[str, Any]
locals: Dict[str, Any]
"""
Local variables
Variables scoped to this call and any sub-calls. Usually used with
GetLocalVar() and PutLocalVar(). Care should be taken when placing mutable
objects in locals, as they will be shared between parallel sub-calls.
"""
max_concurrency: Optional[int]
@ -82,11 +84,13 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
tags=[],
metadata={},
callbacks=None,
_locals={},
locals={},
recursion_limit=10,
)
if config is not None:
empty.update(config)
empty.update(
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
)
return empty
@ -108,22 +112,22 @@ def get_config_list(
return (
list(map(ensure_config, config))
if isinstance(config, list)
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
else [patch_config(config, copy_locals=True) for _ in range(length)]
)
def patch_config(
config: Optional[RunnableConfig],
*,
deep_copy_locals: bool = False,
copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None,
recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None,
run_name: Optional[str] = None,
) -> RunnableConfig:
config = ensure_config(config)
if deep_copy_locals:
config["_locals"] = deepcopy(config["_locals"])
if copy_locals:
config["locals"] = config["locals"].copy()
if callbacks is not None:
# If we're replacing callbacks we need to unset run_name
# As that should apply only to the same run as the original callbacks

@ -215,7 +215,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
metadata={"key": "value"},
tags=["c"],
callbacks=None,
_locals={},
locals={},
recursion_limit=5,
),
),
@ -225,7 +225,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
metadata={"key": "value"},
tags=["c"],
callbacks=None,
_locals={},
locals={},
recursion_limit=5,
),
),
@ -296,7 +296,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
metadata={"key": "value"},
tags=[],
callbacks=None,
_locals={},
locals={},
recursion_limit=10,
),
),
@ -306,7 +306,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
metadata={"key": "value"},
tags=[],
callbacks=None,
_locals={},
locals={},
recursion_limit=10,
),
),

Loading…
Cancel
Save