diff --git a/libs/langchain/langchain/schema/runnable/_locals.py b/libs/langchain/langchain/schema/runnable/_locals.py index 5b2f8e758a..839bb5fbc8 100644 --- a/libs/langchain/langchain/schema/runnable/_locals.py +++ b/libs/langchain/langchain/schema/runnable/_locals.py @@ -48,10 +48,10 @@ class PutLocalVar(RunnablePassthrough): "therefore always receive a non-null config." ) if isinstance(self.key, str): - if self.key not in config["_locals"] or replace: - config["_locals"][self.key] = input + if self.key not in config["locals"] or replace: + config["locals"][self.key] = input else: - config["_locals"][self.key] += input + config["locals"][self.key] += input elif isinstance(self.key, Mapping): if not isinstance(input, Mapping): raise TypeError( @@ -59,10 +59,10 @@ class PutLocalVar(RunnablePassthrough): f"input is expected to be of type Mapping when key is Mapping." ) for input_key, put_key in self.key.items(): - if put_key not in config["_locals"] or replace: - config["_locals"][put_key] = input[input_key] + if put_key not in config["locals"] or replace: + config["locals"][put_key] = input[input_key] else: - config["_locals"][put_key] += input[input_key] + config["locals"][put_key] += input[input_key] else: raise TypeError( f"`key` should be a string or Mapping[str, str], received type " @@ -127,11 +127,11 @@ class GetLocalVar( ) -> Union[Output, Dict[str, Union[Input, Output]]]: if self.passthrough_key: return { - self.key: config["_locals"][self.key], + self.key: config["locals"][self.key], self.passthrough_key: input, } else: - return config["_locals"][self.key] + return config["locals"][self.key] async def _aget( self, diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 6a850ddb4e..1473e58b79 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1628,7 +1628,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): # mark each step as a child run patch_config( config, - deep_copy_locals=True, + copy_locals=True, callbacks=run_manager.get_child(f"map:key:{key}"), ), ) @@ -2111,7 +2111,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) else: configs = [ - patch_config(self._merge_config(config), deep_copy_locals=True) + patch_config(self._merge_config(config), copy_locals=True) for _ in range(len(inputs)) ] return self.bound.batch( @@ -2135,7 +2135,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) else: configs = [ - patch_config(self._merge_config(config), deep_copy_locals=True) + patch_config(self._merge_config(config), copy_locals=True) for _ in range(len(inputs)) ] return await self.bound.abatch( diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 5c89ff150a..6ae120ad7f 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -2,7 +2,6 @@ from __future__ import annotations from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import contextmanager -from copy import deepcopy from typing import ( TYPE_CHECKING, Any, @@ -13,6 +12,7 @@ from typing import ( List, Optional, Union, + cast, ) from typing_extensions import TypedDict @@ -60,9 +60,11 @@ class RunnableConfig(TypedDict, total=False): Name for the tracer run for this call. Defaults to the name of the class. """ - _locals: Dict[str, Any] + locals: Dict[str, Any] """ - Local variables + Variables scoped to this call and any sub-calls. Usually used with + GetLocalVar() and PutLocalVar(). Care should be taken when placing mutable + objects in locals, as they will be shared between parallel sub-calls. """ max_concurrency: Optional[int] @@ -82,11 +84,13 @@ def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: tags=[], metadata={}, callbacks=None, - _locals={}, + locals={}, recursion_limit=10, ) if config is not None: - empty.update(config) + empty.update( + cast(RunnableConfig, {k: v for k, v in config.items() if v is not None}) + ) return empty @@ -108,22 +112,22 @@ def get_config_list( return ( list(map(ensure_config, config)) if isinstance(config, list) - else [patch_config(config, deep_copy_locals=True) for _ in range(length)] + else [patch_config(config, copy_locals=True) for _ in range(length)] ) def patch_config( config: Optional[RunnableConfig], *, - deep_copy_locals: bool = False, + copy_locals: bool = False, callbacks: Optional[BaseCallbackManager] = None, recursion_limit: Optional[int] = None, max_concurrency: Optional[int] = None, run_name: Optional[str] = None, ) -> RunnableConfig: config = ensure_config(config) - if deep_copy_locals: - config["_locals"] = deepcopy(config["_locals"]) + if copy_locals: + config["locals"] = config["locals"].copy() if callbacks is not None: # If we're replacing callbacks we need to unset run_name # As that should apply only to the same run as the original callbacks diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 2a4bd2074e..dc59e65cfd 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -215,7 +215,7 @@ async def test_with_config(mocker: MockerFixture) -> None: metadata={"key": "value"}, tags=["c"], callbacks=None, - _locals={}, + locals={}, recursion_limit=5, ), ), @@ -225,7 +225,7 @@ async def test_with_config(mocker: MockerFixture) -> None: metadata={"key": "value"}, tags=["c"], callbacks=None, - _locals={}, + locals={}, recursion_limit=5, ), ), @@ -296,7 +296,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: metadata={"key": "value"}, tags=[], callbacks=None, - _locals={}, + locals={}, recursion_limit=10, ), ), @@ -306,7 +306,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: metadata={"key": "value"}, tags=[], callbacks=None, - _locals={}, + locals={}, recursion_limit=10, ), ),