diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 701da6c375..848d0940a6 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -68,6 +68,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_name=config.get("run_name"), **kwargs, ) @@ -89,6 +90,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): callbacks=config.get("callbacks"), tags=config.get("tags"), metadata=config.get("metadata"), + run_name=config.get("run_name"), **kwargs, ) @@ -235,6 +237,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """Execute the chain. @@ -276,6 +279,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): run_manager = callback_manager.on_chain_start( dumpd(self), inputs, + name=run_name, ) try: outputs = ( @@ -302,6 +306,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): *, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + run_name: Optional[str] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """Asynchronously execute the chain. @@ -343,6 +348,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): run_manager = await callback_manager.on_chain_start( dumpd(self), inputs, + name=run_name, ) try: outputs = ( diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index a91ecd9f2a..6f7dcc2008 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -60,6 +60,7 @@ from langchain.schema.language_model import BaseLanguageModel, LanguageModelInpu from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string from langchain.schema.output import GenerationChunk from langchain.schema.runnable import RunnableConfig +from langchain.schema.runnable.config import get_config_list logger = logging.getLogger(__name__) @@ -265,7 +266,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): max_concurrency: Optional[int] = None, **kwargs: Any, ) -> List[str]: - config = self._get_config_list(config, len(inputs)) + config = get_config_list(config, len(inputs)) if max_concurrency is None: llm_result = self.generate_prompt( @@ -300,7 +301,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): None, self.batch, inputs, config, max_concurrency ) - config = self._get_config_list(config, len(inputs)) + config = get_config_list(config, len(inputs)) if max_concurrency is None: llm_result = await self.agenerate_prompt( diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index bdbd7fc699..88572bfee1 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -42,6 +42,7 @@ from langchain.schema.runnable.config import ( ensure_config, get_async_callback_manager_for_config, get_callback_manager_for_config, + get_config_list, get_executor_for_config, patch_config, ) @@ -110,7 +111,7 @@ class Runnable(Generic[Input, Output], ABC): Default implementation of batch, which calls invoke N times. Subclasses should override this method if they can batch more efficiently. """ - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) # If there's only one input, don't bother with the executor if len(inputs) == 1: @@ -129,7 +130,7 @@ class Runnable(Generic[Input, Output], ABC): Default implementation of abatch, which calls ainvoke N times. Subclasses should override this method if they can batch more efficiently. """ - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) coros = map(partial(self.ainvoke, **kwargs), inputs, configs) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) @@ -210,7 +211,20 @@ class Runnable(Generic[Input, Output], ABC): """ Bind arguments to a Runnable, returning a new Runnable. """ - return RunnableBinding(bound=self, kwargs=kwargs) + return RunnableBinding(bound=self, kwargs=kwargs, config={}) + + def with_config( + self, + config: Optional[RunnableConfig] = None, + # Sadly Unpack is not well supported by mypy so this will have to be untyped + **kwargs: Any, + ) -> Runnable[Input, Output]: + """ + Bind config to a Runnable, returning a new Runnable. + """ + return RunnableBinding( + bound=self, config={**(config or {}), **kwargs}, kwargs={} + ) def map(self) -> Runnable[List[Input], List[Output]]: """ @@ -233,27 +247,6 @@ class Runnable(Generic[Input, Output], ABC): """ --- Helper methods for Subclasses --- """ - def _get_config_list( - self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int - ) -> List[RunnableConfig]: - """ - Helper method to get a list of configs from a single config or a list of - configs, useful for subclasses overriding batch() or abatch(). - """ - if length < 1: - raise ValueError(f"length must be >= 1, but got {length}") - if isinstance(config, list) and len(config) != length: - raise ValueError( - f"config must be a list of the same length as inputs, " - f"but got {len(config)} configs for {length} inputs" - ) - - return ( - list(map(ensure_config, config)) - if isinstance(config, list) - else [patch_config(config, deep_copy_locals=True) for _ in range(length)] - ) - def _call_with_config( self, func: Union[ @@ -273,6 +266,7 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), input, run_type=run_type, + name=config.get("run_name"), ) try: if accepts_run_manager_and_config(func): @@ -314,6 +308,7 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), input, run_type=run_type, + name=config.get("run_name"), ) try: if accepts_run_manager_and_config(func): @@ -371,6 +366,7 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), {"input": ""}, run_type=run_type, + name=config.get("run_name"), ) try: if accepts_run_manager_and_config(transformer): @@ -451,6 +447,7 @@ class Runnable(Generic[Input, Output], ABC): dumpd(self), {"input": ""}, run_type=run_type, + name=config.get("run_name"), ) try: # mypy can't quite work out thew type guard here, but this is safe, @@ -526,7 +523,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_callback_manager_for_config(config) # start the root run - run_manager = callback_manager.on_chain_start(dumpd(self), input) + run_manager = callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) first_error = None for runnable in self.runnables: try: @@ -558,7 +557,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) # start the root run - run_manager = await callback_manager.on_chain_start(dumpd(self), input) + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) first_error = None for runnable in self.runnables: @@ -590,7 +591,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): from langchain.callbacks.manager import CallbackManager # setup callbacks - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) callback_managers = [ CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -606,9 +607,11 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): # start the root runs, one per input run_managers = [ cm.on_chain_start( - dumpd(self), input if isinstance(input, dict) else {"input": input} + dumpd(self), + input if isinstance(input, dict) else {"input": input}, + name=config.get("run_name"), ) - for cm, input in zip(callback_managers, inputs) + for cm, input, config in zip(callback_managers, inputs, configs) ] first_error = None @@ -648,7 +651,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): from langchain.callbacks.manager import AsyncCallbackManager # setup callbacks - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) callback_managers = [ AsyncCallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -664,8 +667,12 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): # start the root runs, one per input run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( *( - cm.on_chain_start(dumpd(self), input) - for cm, input in zip(callback_managers, inputs) + cm.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) ) ) @@ -770,7 +777,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_callback_manager_for_config(config) # start the root run - run_manager = callback_manager.on_chain_start(dumpd(self), input) + run_manager = callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) # invoke all steps in sequence try: @@ -798,7 +807,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) # start the root run - run_manager = await callback_manager.on_chain_start(dumpd(self), input) + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) # invoke all steps in sequence try: @@ -825,7 +836,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): from langchain.callbacks.manager import CallbackManager # setup callbacks - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) callback_managers = [ CallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -840,8 +851,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ] # start the root runs, one per input run_managers = [ - cm.on_chain_start(dumpd(self), input) - for cm, input in zip(callback_managers, inputs) + cm.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) ] # invoke @@ -876,7 +891,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) # setup callbacks - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) callback_managers = [ AsyncCallbackManager.configure( inheritable_callbacks=config.get("callbacks"), @@ -892,8 +907,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # start the root runs, one per input run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( *( - cm.on_chain_start(dumpd(self), input) - for cm, input in zip(callback_managers, inputs) + cm.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + for cm, input, config in zip(callback_managers, inputs, configs) ) ) @@ -929,7 +948,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_callback_manager_for_config(config) # start the root run - run_manager = callback_manager.on_chain_start(dumpd(self), input) + run_manager = callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) steps = [self.first] + self.middle + [self.last] streaming_start_index = 0 @@ -996,7 +1017,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) # start the root run - run_manager = await callback_manager.on_chain_start(dumpd(self), input) + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) steps = [self.first] + self.middle + [self.last] streaming_start_index = len(steps) - 1 @@ -1127,7 +1150,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): local_metadata=None, ) # start the root run - run_manager = callback_manager.on_chain_start(dumpd(self), input) + run_manager = callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) # gather results from all steps try: @@ -1166,7 +1191,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): config = ensure_config(config) callback_manager = get_async_callback_manager_for_config(config) # start the root run - run_manager = await callback_manager.on_chain_start(dumpd(self), input) + run_manager = await callback_manager.on_chain_start( + dumpd(self), input, name=config.get("run_name") + ) # gather results from all steps try: @@ -1479,6 +1506,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): kwargs: Mapping[str, Any] + config: Mapping[str, Any] = Field(default_factory=dict) + class Config: arbitrary_types_allowed = True @@ -1490,8 +1519,31 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): def lc_namespace(self) -> List[str]: return self.__class__.__module__.split(".")[:-1] + def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig: + copy = cast(RunnableConfig, dict(self.config)) + if config: + for key in config: + # Even though the keys aren't literals this is correct + # because both dicts are same type + copy[key] = config[key] or copy.get(key) # type: ignore + return copy + def bind(self, **kwargs: Any) -> Runnable[Input, Output]: - return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs}) + return self.__class__( + bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs} + ) + + def with_config( + self, + config: Optional[RunnableConfig] = None, + # Sadly Unpack is not well supported by mypy so this will have to be untyped + **kwargs: Any, + ) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config={**self.config, **(config or {}), **kwargs}, + ) def invoke( self, @@ -1499,7 +1551,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - return self.bound.invoke(input, config, **{**self.kwargs, **kwargs}) + return self.bound.invoke( + input, + self._merge_config(config), + **{**self.kwargs, **kwargs}, + ) async def ainvoke( self, @@ -1507,7 +1563,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Output: - return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs}) + return await self.bound.ainvoke( + input, + self._merge_config(config), + **{**self.kwargs, **kwargs}, + ) def batch( self, @@ -1515,7 +1575,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Output]: - return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs}) + if isinstance(config, list): + configs = cast( + List[RunnableConfig], [self._merge_config(conf) for conf in config] + ) + else: + configs = [ + patch_config(self._merge_config(config), deep_copy_locals=True) + for _ in range(len(inputs)) + ] + return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs}) async def abatch( self, @@ -1523,7 +1592,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Optional[Any], ) -> List[Output]: - return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs}) + if isinstance(config, list): + configs = cast( + List[RunnableConfig], [self._merge_config(conf) for conf in config] + ) + else: + configs = [ + patch_config(self._merge_config(config), deep_copy_locals=True) + for _ in range(len(inputs)) + ] + return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs}) def stream( self, @@ -1531,7 +1609,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Optional[Any], ) -> Iterator[Output]: - yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs}) + yield from self.bound.stream( + input, + self._merge_config(config), + **{**self.kwargs, **kwargs}, + ) async def astream( self, @@ -1540,7 +1622,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Optional[Any], ) -> AsyncIterator[Output]: async for item in self.bound.astream( - input, config, **{**self.kwargs, **kwargs} + input, + self._merge_config(config), + **{**self.kwargs, **kwargs}, ): yield item @@ -1550,7 +1634,11 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config: Optional[RunnableConfig] = None, **kwargs: Any, ) -> Iterator[Output]: - yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs}) + yield from self.bound.transform( + input, + self._merge_config(config), + **{**self.kwargs, **kwargs}, + ) async def atransform( self, @@ -1559,11 +1647,16 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): **kwargs: Any, ) -> AsyncIterator[Output]: async for item in self.bound.atransform( - input, config, **{**self.kwargs, **kwargs} + input, + self._merge_config(config), + **{**self.kwargs, **kwargs}, ): yield item +RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig) + + def coerce_to_runnable( thing: Union[ Runnable[Input, Output], diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index b97d904414..3f87f04403 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -3,7 +3,9 @@ from __future__ import annotations from concurrent.futures import Executor, ThreadPoolExecutor from contextlib import contextmanager from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union + +from typing_extensions import TypedDict if TYPE_CHECKING: from langchain.callbacks.base import BaseCallbackManager, Callbacks @@ -31,6 +33,11 @@ class RunnableConfig(TypedDict, total=False): Tags are passed to all callbacks, metadata is passed to handle*Start callbacks. """ + run_name: str + """ + Name for the tracer run for this call. Defaults to the name of the class. + """ + _locals: Dict[str, Any] """ Local variables @@ -48,7 +55,7 @@ class RunnableConfig(TypedDict, total=False): """ -def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: +def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig: empty = RunnableConfig( tags=[], metadata={}, @@ -61,20 +68,52 @@ def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig: return empty +def get_config_list( + config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int +) -> List[RunnableConfig]: + """ + Helper method to get a list of configs from a single config or a list of + configs, useful for subclasses overriding batch() or abatch(). + """ + if length < 1: + raise ValueError(f"length must be >= 1, but got {length}") + if isinstance(config, list) and len(config) != length: + raise ValueError( + f"config must be a list of the same length as inputs, " + f"but got {len(config)} configs for {length} inputs" + ) + + return ( + list(map(ensure_config, config)) + if isinstance(config, list) + else [patch_config(config, deep_copy_locals=True) for _ in range(length)] + ) + + def patch_config( config: Optional[RunnableConfig], *, deep_copy_locals: bool = False, callbacks: Optional[BaseCallbackManager] = None, recursion_limit: Optional[int] = None, + max_concurrency: Optional[int] = None, + run_name: Optional[str] = None, ) -> RunnableConfig: config = ensure_config(config) if deep_copy_locals: config["_locals"] = deepcopy(config["_locals"]) if callbacks is not None: + # If we're replacing callbacks we need to unset run_name + # As that should apply only to the same run as the original callbacks config["callbacks"] = callbacks + if "run_name" in config: + del config["run_name"] if recursion_limit is not None: config["recursion_limit"] = recursion_limit + if max_concurrency is not None: + config["max_concurrency"] = max_concurrency + if run_name is not None: + config["run_name"] = run_name return config diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index 68989bfa7d..5277932543 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -23,7 +23,7 @@ from langchain.schema.runnable.base import ( RunnableSequence, coerce_to_runnable, ) -from langchain.schema.runnable.config import RunnableConfig +from langchain.schema.runnable.config import RunnableConfig, get_config_list from langchain.schema.runnable.utils import gather_with_concurrency @@ -131,7 +131,7 @@ class RouterRunnable( raise ValueError("One or more keys do not have a corresponding runnable") runnables = [self.runnables[key] for key in keys] - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) with ThreadPoolExecutor(max_workers=max_concurrency) as executor: return list( executor.map( @@ -156,7 +156,7 @@ class RouterRunnable( raise ValueError("One or more keys do not have a corresponding runnable") runnables = [self.runnables[key] for key in keys] - configs = self._get_config_list(config, len(inputs)) + configs = get_config_list(config, len(inputs)) return await gather_with_concurrency( max_concurrency, *( diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index c48d4edbd4..fcb621fe8c 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -2081,7 +2081,8 @@ "stop": [ "Thought:" ] - } + }, + "config": {} } }, "llm": { diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index f244753310..412fa8e1e7 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -11,6 +11,7 @@ from langchain import PromptTemplate from langchain.callbacks.manager import Callbacks from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.schemas import Run +from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler from langchain.chat_models.fake import FakeListChatModel from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM from langchain.load.dump import dumpd, dumps @@ -112,6 +113,124 @@ class FakeRetriever(BaseRetriever): return [Document(page_content="foo"), Document(page_content="bar")] +@pytest.mark.asyncio +async def test_with_config(mocker: MockerFixture) -> None: + fake = FakeRunnable() + spy = mocker.spy(fake, "invoke") + + assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5 + assert spy.call_args_list == [ + mocker.call("hello", dict(tags=["a-tag"])), + ] + spy.reset_mock() + + fake_1: Runnable = RunnablePassthrough() + fake_2: Runnable = RunnablePassthrough() + spy_seq_step = mocker.spy(fake_1.__class__, "invoke") + + sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config( + tags=["b-tag"], max_concurrency=5 + ) + assert sequence.invoke("hello") == "hello" + assert len(spy_seq_step.call_args_list) == 2 + for i, call in enumerate(spy_seq_step.call_args_list): + assert call.args[1] == "hello" + if i == 0: + assert call.args[2].get("tags") == ["a-tag"] + assert call.args[2].get("max_concurrency") is None + else: + assert call.args[2].get("tags") == ["b-tag"] + assert call.args[2].get("max_concurrency") == 5 + spy_seq_step.reset_mock() + + assert [ + *fake.with_config(tags=["a-tag"]).stream( + "hello", dict(metadata={"key": "value"}) + ) + ] == [5] + assert spy.call_args_list == [ + mocker.call("hello", dict(tags=["a-tag"], metadata={"key": "value"})), + ] + spy.reset_mock() + + assert fake.with_config(recursion_limit=5).batch( + ["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})] + ) == [5, 7] + + assert len(spy.call_args_list) == 2 + for i, call in enumerate(spy.call_args_list): + assert call.args[0] == ("hello" if i == 0 else "wooorld") + if i == 0: + assert call.args[1].get("recursion_limit") == 5 + assert call.args[1].get("tags") == ["a-tag"] + assert call.args[1].get("metadata") == {} + else: + assert call.args[1].get("recursion_limit") == 5 + assert call.args[1].get("tags") == [] + assert call.args[1].get("metadata") == {"key": "value"} + + spy.reset_mock() + + assert fake.with_config(metadata={"a": "b"}).batch( + ["hello", "wooorld"], dict(tags=["a-tag"]) + ) == [5, 7] + assert len(spy.call_args_list) == 2 + for i, call in enumerate(spy.call_args_list): + assert call.args[0] == ("hello" if i == 0 else "wooorld") + assert call.args[1].get("tags") == ["a-tag"] + assert call.args[1].get("metadata") == {"a": "b"} + spy.reset_mock() + + handler = ConsoleCallbackHandler() + assert ( + await fake.with_config(metadata={"a": "b"}).ainvoke( + "hello", config={"callbacks": [handler]} + ) + == 5 + ) + assert spy.call_args_list == [ + mocker.call("hello", dict(callbacks=[handler], metadata={"a": "b"})), + ] + spy.reset_mock() + + assert [ + part async for part in fake.with_config(metadata={"a": "b"}).astream("hello") + ] == [5] + assert spy.call_args_list == [ + mocker.call("hello", dict(metadata={"a": "b"})), + ] + spy.reset_mock() + + assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch( + ["hello", "wooorld"], dict(metadata={"key": "value"}) + ) == [ + 5, + 7, + ] + assert spy.call_args_list == [ + mocker.call( + "hello", + dict( + metadata={"key": "value"}, + tags=["c"], + callbacks=None, + _locals={}, + recursion_limit=5, + ), + ), + mocker.call( + "wooorld", + dict( + metadata={"key": "value"}, + tags=["c"], + callbacks=None, + _locals={}, + recursion_limit=5, + ), + ), + ] + + @pytest.mark.asyncio async def test_default_method_implementations(mocker: MockerFixture) -> None: fake = FakeRunnable() @@ -1125,6 +1244,14 @@ async def test_map_astream_iterator_input() -> None: assert final_value.get("passthrough") == llm_res +def test_with_config_with_config() -> None: + llm = FakeListLLM(responses=["i'm a textbot"]) + + assert dumpd( + llm.with_config({"metadata": {"a": "b"}}).with_config(tags=["a-tag"]) + ) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]})) + + def test_bind_bind() -> None: llm = FakeListLLM(responses=["i'm a textbot"])