diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 301b0143e7..a490c58315 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -62,7 +62,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Dict[str, Any]: - return self(input, **(config or {}), **kwargs) + _config: Dict[str, Any] = dict(config) if config else {} + _config.pop("_locals", None) + return self(input, **_config, **kwargs) async def ainvoke( self, @@ -76,7 +78,9 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): None, partial(self.invoke, input, config, **kwargs) ) - return await self.acall(input, **(config or {}), **kwargs) + _config: Dict[str, Any] = dict(config) if config else {} + _config.pop("_locals", None) + return await self.acall(input, **_config, **kwargs) memory: Optional[BaseMemory] = None """Optional memory object. Defaults to None. diff --git a/libs/langchain/langchain/schema/runnable.py b/libs/langchain/langchain/schema/runnable.py index 8edafe4599..eebd5a96aa 100644 --- a/libs/langchain/langchain/schema/runnable.py +++ b/libs/langchain/langchain/schema/runnable.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy from itertools import tee from typing import ( Any, @@ -66,6 +67,35 @@ class RunnableConfig(TypedDict, total=False): Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. """ + _locals: Dict[str, Any] + """ + Local variables + """ + + +def _empty_config() -> RunnableConfig: + return RunnableConfig(tags=[], metadata={}, callbacks=None, _locals={}) + + +def _get_callback_manager(config: Mapping) -> Any: + from langchain.callbacks.manager import CallbackManager + + return CallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + + +def _get_async_callback_manager(config: Mapping) -> Any: + from langchain.callbacks.manager import AsyncCallbackManager + + return AsyncCallbackManager.configure( + inheritable_callbacks=config.get("callbacks"), + inheritable_tags=config.get("tags"), + inheritable_metadata=config.get("metadata"), + ) + Input = TypeVar("Input") # Output type should implement __concat__, as eg str, list, dict do @@ -243,7 +273,7 @@ class Runnable(Generic[Input, Output], ABC): return ( config if isinstance(config, list) - else [config.copy() if config is not None else {} for _ in range(length)] + else [deepcopy(config) if config is not None else {} for _ in range(length)] ) def _call_with_config( @@ -255,14 +285,8 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" - from langchain.callbacks.manager import CallbackManager - config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input}, @@ -288,14 +312,8 @@ class Runnable(Generic[Input, Output], ABC): ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" - from langchain.callbacks.manager import AsyncCallbackManager - config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input}, @@ -322,8 +340,6 @@ class Runnable(Generic[Input, Output], ABC): """Helper method to transform an Iterator of Input values into an Iterator of Output values, with callbacks. Use this to implement `stream()` or `transform()` in Runnable subclasses.""" - from langchain.callbacks.manager import CallbackManager - # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = tee(input, 2) # Start the input iterator to ensure the input runnable starts before this one @@ -333,11 +349,7 @@ class Runnable(Generic[Input, Output], ABC): final_output_supported = True config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + callback_manager = _get_callback_manager(config) run_manager = callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -393,8 +405,6 @@ class Runnable(Generic[Input, Output], ABC): """Helper method to transform an Async Iterator of Input values into an Async Iterator of Output values, with callbacks. Use this to implement `astream()` or `atransform()` in Runnable subclasses.""" - from langchain.callbacks.manager import AsyncCallbackManager - # tee the input so we can iterate over it twice input_for_tracing, input_for_transform = atee(input, 2) # Start the input iterator to ensure the input runnable starts before this one @@ -404,11 +414,7 @@ class Runnable(Generic[Input, Output], ABC): final_output_supported = True config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - inheritable_tags=config.get("tags"), - inheritable_metadata=config.get("metadata"), - ) + callback_manager = _get_async_callback_manager(config) run_manager = await callback_manager.on_chain_start( dumpd(self), {"input": ""}, @@ -473,19 +479,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): yield from self.fallbacks def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - from langchain.callbacks.manager import CallbackManager - # setup callbacks config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -516,19 +512,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -751,19 +737,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: - from langchain.callbacks.manager import CallbackManager - # setup callbacks - config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + config = config or _empty_config() + callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -771,11 +747,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # invoke all steps in sequence try: + callbacks = run_manager.get_child() for step in self.steps: input = step.invoke( input, # mark each step as a child run - _patch_config(config, run_manager.get_child()), + _patch_config(config, callbacks), ) # finish the root run except (KeyboardInterrupt, Exception) as e: @@ -790,19 +767,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None ) -> Output: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -946,19 +913,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): def stream( self, input: Input, config: Optional[RunnableConfig] = None ) -> Iterator[Output]: - from langchain.callbacks.manager import CallbackManager - # setup callbacks config = config or {} - callback_manager = CallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_callback_manager(config) # start the root run run_manager = callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -1023,19 +980,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): async def astream( self, input: Input, config: Optional[RunnableConfig] = None ) -> AsyncIterator[Output]: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), input if isinstance(input, dict) else {"input": input} @@ -1173,19 +1120,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None ) -> Dict[str, Any]: - from langchain.callbacks.manager import AsyncCallbackManager - # setup callbacks config = config or {} - callback_manager = AsyncCallbackManager.configure( - inheritable_callbacks=config.get("callbacks"), - local_callbacks=None, - verbose=False, - inheritable_tags=config.get("tags"), - local_tags=None, - inheritable_metadata=config.get("metadata"), - local_metadata=None, - ) + callback_manager = _get_async_callback_manager(config) # start the root run run_manager = await callback_manager.on_chain_start( dumpd(self), {"input": input} @@ -1464,10 +1401,11 @@ class RouterRunnable( def _patch_config( - config: RunnableConfig, callback_manager: BaseCallbackManager + config: RunnableConfig, callback_manager: BaseCallbackManager, _locals: Optional[Dict[str, Any]] = None ) -> RunnableConfig: - config = config.copy() + config = deepcopy(config) config["callbacks"] = callback_manager + config["_locals"] = _locals or {} return config diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py index 5b3d5775c4..be55f6f99a 100644 --- a/libs/langchain/langchain/smith/evaluation/runner_utils.py +++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py @@ -636,7 +636,9 @@ async def _arun_chain( else: output = await chain.acall(inputs_, callbacks=callbacks, tags=tags) else: - runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) + runnable_config = RunnableConfig( + tags=tags or [], callbacks=callbacks, _locals={} + ) output = await chain.ainvoke(inputs_, config=runnable_config) return output @@ -957,7 +959,9 @@ def _run_chain( else: output = chain(inputs_, callbacks=callbacks, tags=tags) else: - runnable_config = RunnableConfig(tags=tags or [], callbacks=callbacks) + runnable_config = RunnableConfig( + tags=tags or [], callbacks=callbacks, _locals={} + ) output = chain.invoke(inputs_, config=runnable_config) return output