mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Use shallow copy on runnable locals (#10825)
- deep copy prevents storing complex objects in locals
This commit is contained in:
parent
ebe08412ad
commit
276125a33b
@ -48,10 +48,10 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
"therefore always receive a non-null config."
|
"therefore always receive a non-null config."
|
||||||
)
|
)
|
||||||
if isinstance(self.key, str):
|
if isinstance(self.key, str):
|
||||||
if self.key not in config["_locals"] or replace:
|
if self.key not in config["locals"] or replace:
|
||||||
config["_locals"][self.key] = input
|
config["locals"][self.key] = input
|
||||||
else:
|
else:
|
||||||
config["_locals"][self.key] += input
|
config["locals"][self.key] += input
|
||||||
elif isinstance(self.key, Mapping):
|
elif isinstance(self.key, Mapping):
|
||||||
if not isinstance(input, Mapping):
|
if not isinstance(input, Mapping):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -59,10 +59,10 @@ class PutLocalVar(RunnablePassthrough):
|
|||||||
f"input is expected to be of type Mapping when key is Mapping."
|
f"input is expected to be of type Mapping when key is Mapping."
|
||||||
)
|
)
|
||||||
for input_key, put_key in self.key.items():
|
for input_key, put_key in self.key.items():
|
||||||
if put_key not in config["_locals"] or replace:
|
if put_key not in config["locals"] or replace:
|
||||||
config["_locals"][put_key] = input[input_key]
|
config["locals"][put_key] = input[input_key]
|
||||||
else:
|
else:
|
||||||
config["_locals"][put_key] += input[input_key]
|
config["locals"][put_key] += input[input_key]
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"`key` should be a string or Mapping[str, str], received type "
|
f"`key` should be a string or Mapping[str, str], received type "
|
||||||
@ -127,11 +127,11 @@ class GetLocalVar(
|
|||||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||||
if self.passthrough_key:
|
if self.passthrough_key:
|
||||||
return {
|
return {
|
||||||
self.key: config["_locals"][self.key],
|
self.key: config["locals"][self.key],
|
||||||
self.passthrough_key: input,
|
self.passthrough_key: input,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return config["_locals"][self.key]
|
return config["locals"][self.key]
|
||||||
|
|
||||||
async def _aget(
|
async def _aget(
|
||||||
self,
|
self,
|
||||||
|
@ -1628,7 +1628,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
# mark each step as a child run
|
# mark each step as a child run
|
||||||
patch_config(
|
patch_config(
|
||||||
config,
|
config,
|
||||||
deep_copy_locals=True,
|
copy_locals=True,
|
||||||
callbacks=run_manager.get_child(f"map:key:{key}"),
|
callbacks=run_manager.get_child(f"map:key:{key}"),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -2111,7 +2111,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
configs = [
|
configs = [
|
||||||
patch_config(self._merge_config(config), deep_copy_locals=True)
|
patch_config(self._merge_config(config), copy_locals=True)
|
||||||
for _ in range(len(inputs))
|
for _ in range(len(inputs))
|
||||||
]
|
]
|
||||||
return self.bound.batch(
|
return self.bound.batch(
|
||||||
@ -2135,7 +2135,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
configs = [
|
configs = [
|
||||||
patch_config(self._merge_config(config), deep_copy_locals=True)
|
patch_config(self._merge_config(config), copy_locals=True)
|
||||||
for _ in range(len(inputs))
|
for _ in range(len(inputs))
|
||||||
]
|
]
|
||||||
return await self.bound.abatch(
|
return await self.bound.abatch(
|
||||||
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -13,6 +12,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
@ -60,9 +60,11 @@ class RunnableConfig(TypedDict, total=False):
|
|||||||
Name for the tracer run for this call. Defaults to the name of the class.
|
Name for the tracer run for this call. Defaults to the name of the class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_locals: Dict[str, Any]
|
locals: Dict[str, Any]
|
||||||
"""
|
"""
|
||||||
Local variables
|
Variables scoped to this call and any sub-calls. Usually used with
|
||||||
|
GetLocalVar() and PutLocalVar(). Care should be taken when placing mutable
|
||||||
|
objects in locals, as they will be shared between parallel sub-calls.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_concurrency: Optional[int]
|
max_concurrency: Optional[int]
|
||||||
@ -82,11 +84,13 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
|||||||
tags=[],
|
tags=[],
|
||||||
metadata={},
|
metadata={},
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
locals={},
|
||||||
recursion_limit=10,
|
recursion_limit=10,
|
||||||
)
|
)
|
||||||
if config is not None:
|
if config is not None:
|
||||||
empty.update(config)
|
empty.update(
|
||||||
|
cast(RunnableConfig, {k: v for k, v in config.items() if v is not None})
|
||||||
|
)
|
||||||
return empty
|
return empty
|
||||||
|
|
||||||
|
|
||||||
@ -108,22 +112,22 @@ def get_config_list(
|
|||||||
return (
|
return (
|
||||||
list(map(ensure_config, config))
|
list(map(ensure_config, config))
|
||||||
if isinstance(config, list)
|
if isinstance(config, list)
|
||||||
else [patch_config(config, deep_copy_locals=True) for _ in range(length)]
|
else [patch_config(config, copy_locals=True) for _ in range(length)]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def patch_config(
|
def patch_config(
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
*,
|
*,
|
||||||
deep_copy_locals: bool = False,
|
copy_locals: bool = False,
|
||||||
callbacks: Optional[BaseCallbackManager] = None,
|
callbacks: Optional[BaseCallbackManager] = None,
|
||||||
recursion_limit: Optional[int] = None,
|
recursion_limit: Optional[int] = None,
|
||||||
max_concurrency: Optional[int] = None,
|
max_concurrency: Optional[int] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
) -> RunnableConfig:
|
) -> RunnableConfig:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
if deep_copy_locals:
|
if copy_locals:
|
||||||
config["_locals"] = deepcopy(config["_locals"])
|
config["locals"] = config["locals"].copy()
|
||||||
if callbacks is not None:
|
if callbacks is not None:
|
||||||
# If we're replacing callbacks we need to unset run_name
|
# If we're replacing callbacks we need to unset run_name
|
||||||
# As that should apply only to the same run as the original callbacks
|
# As that should apply only to the same run as the original callbacks
|
||||||
|
@ -215,7 +215,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
|||||||
metadata={"key": "value"},
|
metadata={"key": "value"},
|
||||||
tags=["c"],
|
tags=["c"],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
locals={},
|
||||||
recursion_limit=5,
|
recursion_limit=5,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@ -225,7 +225,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
|||||||
metadata={"key": "value"},
|
metadata={"key": "value"},
|
||||||
tags=["c"],
|
tags=["c"],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
locals={},
|
||||||
recursion_limit=5,
|
recursion_limit=5,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@ -296,7 +296,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
metadata={"key": "value"},
|
metadata={"key": "value"},
|
||||||
tags=[],
|
tags=[],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
locals={},
|
||||||
recursion_limit=10,
|
recursion_limit=10,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@ -306,7 +306,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
metadata={"key": "value"},
|
metadata={"key": "value"},
|
||||||
tags=[],
|
tags=[],
|
||||||
callbacks=None,
|
callbacks=None,
|
||||||
_locals={},
|
locals={},
|
||||||
recursion_limit=10,
|
recursion_limit=10,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
Loading…
Reference in New Issue
Block a user