Use shallow copy on runnable locals (#10825)

- deep copy prevents storing complex objects in locals
This commit is contained in:
Nuno Campos 2023-09-20 16:13:06 +01:00 committed by GitHub
parent ebe08412ad
commit 276125a33b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 24 deletions

View File

@ -48,10 +48,10 @@ class PutLocalVar(RunnablePassthrough):
"therefore always receive a non-null config." "therefore always receive a non-null config."
) )
if isinstance(self.key, str): if isinstance(self.key, str):
if self.key not in config["_locals"] or replace: if self.key not in config["locals"] or replace:
config["_locals"][self.key] = input config["locals"][self.key] = input
else: else:
config["_locals"][self.key] += input config["locals"][self.key] += input
elif isinstance(self.key, Mapping): elif isinstance(self.key, Mapping):
if not isinstance(input, Mapping): if not isinstance(input, Mapping):
raise TypeError( raise TypeError(
@ -59,10 +59,10 @@ class PutLocalVar(RunnablePassthrough):
f"input is expected to be of type Mapping when key is Mapping." f"input is expected to be of type Mapping when key is Mapping."
) )
for input_key, put_key in self.key.items(): for input_key, put_key in self.key.items():
if put_key not in config["_locals"] or replace: if put_key not in config["locals"] or replace:
config["_locals"][put_key] = input[input_key] config["locals"][put_key] = input[input_key]
else: else:
config["_locals"][put_key] += input[input_key] config["locals"][put_key] += input[input_key]
else: else:
raise TypeError( raise TypeError(
f"`key` should be a string or Mapping[str, str], received type " f"`key` should be a string or Mapping[str, str], received type "
@ -127,11 +127,11 @@ class GetLocalVar(
) -> Union[Output, Dict[str, Union[Input, Output]]]: ) -> Union[Output, Dict[str, Union[Input, Output]]]:
if self.passthrough_key: if self.passthrough_key:
return { return {
self.key: config["_locals"][self.key], self.key: config["locals"][self.key],
self.passthrough_key: input, self.passthrough_key: input,
} }
else: else:
return config["_locals"][self.key] return config["locals"][self.key]
async def _aget( async def _aget(
self, self,

View File

@ -1628,7 +1628,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
# mark each step as a child run # mark each step as a child run
patch_config( patch_config(
config, config,
deep_copy_locals=True, copy_locals=True,
callbacks=run_manager.get_child(f"map:key:{key}"), callbacks=run_manager.get_child(f"map:key:{key}"),
), ),
) )
@ -2111,7 +2111,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) )
else: else:
configs = [ configs = [
patch_config(self._merge_config(config), deep_copy_locals=True) patch_config(self._merge_config(config), copy_locals=True)
for _ in range(len(inputs)) for _ in range(len(inputs))
] ]
return self.bound.batch( return self.bound.batch(
@ -2135,7 +2135,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) )
else: else:
configs = [ configs = [
patch_config(self._merge_config(config), deep_copy_locals=True) patch_config(self._merge_config(config), copy_locals=True)
for _ in range(len(inputs)) for _ in range(len(inputs))
] ]
return await self.bound.abatch( return await self.bound.abatch(

View File

@ -2,7 +2,6 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -13,6 +12,7 @@ from typing import (
List, List,
Optional, Optional,
Union, Union,
cast,
) )
from typing_extensions import TypedDict from typing_extensions import TypedDict
@ -60,9 +60,11 @@ class RunnableConfig(TypedDict, total=False):
Name for the tracer run for this call. Defaults to the name of the class. Name for the tracer run for this call. Defaults to the name of the class.
""" """
_locals: Dict[str, Any] locals: Dict[str, Any]
""" """
Local variables Variables scoped to this call and any sub-calls. Usually used with
GetLocalVar() and PutLocalVar(). Care should be taken when placing mutable
objects in locals, as they will be shared between parallel sub-calls.
""" """
max_concurrency: Optional[int] max_concurrency: Optional[int]
@ -82,11 +84,13 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
tags=[], tags=[],
metadata={}, metadata={},
callbacks=None, callbacks=None,
_locals={}, locals={},
recursion_limit=10, recursion_limit=10,
) )
if config is not None: if config is not None:
empty.update(config) empty.update(
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
)
return empty return empty
@ -108,22 +112,22 @@ def get_config_list(
return ( return (
list(map(ensure_config, config)) list(map(ensure_config, config))
if isinstance(config, list) if isinstance(config, list)
else [patch_config(config, deep_copy_locals=True) for _ in range(length)] else [patch_config(config, copy_locals=True) for _ in range(length)]
) )
def patch_config( def patch_config(
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
*, *,
deep_copy_locals: bool = False, copy_locals: bool = False,
callbacks: Optional[BaseCallbackManager] = None, callbacks: Optional[BaseCallbackManager] = None,
recursion_limit: Optional[int] = None, recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None, max_concurrency: Optional[int] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
) -> RunnableConfig: ) -> RunnableConfig:
config = ensure_config(config) config = ensure_config(config)
if deep_copy_locals: if copy_locals:
config["_locals"] = deepcopy(config["_locals"]) config["locals"] = config["locals"].copy()
if callbacks is not None: if callbacks is not None:
# If we're replacing callbacks we need to unset run_name # If we're replacing callbacks we need to unset run_name
# As that should apply only to the same run as the original callbacks # As that should apply only to the same run as the original callbacks

View File

@ -215,7 +215,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
metadata={"key": "value"}, metadata={"key": "value"},
tags=["c"], tags=["c"],
callbacks=None, callbacks=None,
_locals={}, locals={},
recursion_limit=5, recursion_limit=5,
), ),
), ),
@ -225,7 +225,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
metadata={"key": "value"}, metadata={"key": "value"},
tags=["c"], tags=["c"],
callbacks=None, callbacks=None,
_locals={}, locals={},
recursion_limit=5, recursion_limit=5,
), ),
), ),
@ -296,7 +296,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
metadata={"key": "value"}, metadata={"key": "value"},
tags=[], tags=[],
callbacks=None, callbacks=None,
_locals={}, locals={},
recursion_limit=10, recursion_limit=10,
), ),
), ),
@ -306,7 +306,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
metadata={"key": "value"}, metadata={"key": "value"},
tags=[], tags=[],
callbacks=None, callbacks=None,
_locals={}, locals={},
recursion_limit=10, recursion_limit=10,
), ),
), ),