From 542671231172ca66ead99506d1cdd7a952a5cfb4 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 16:49:14 +0200 Subject: [PATCH] Adjust merge logic --- .../langchain/schema/runnable/base.py | 33 ++++++++++--------- .../schema/runnable/test_runnable.py | 25 ++++++++++++-- 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 96878c6969..abd7ae81c6 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1527,6 +1527,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): def lc_namespace(self) -> List[str]: return self.__class__.__module__.split(".")[:-1] + def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig: + copy = cast(RunnableConfig, dict(self.config)) + if config: + for key in config: + copy[key] = config[key] or copy.get(key) + return copy + def bind(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__( bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs} @@ -1552,7 +1559,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> Output: return self.bound.invoke( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ) @@ -1564,7 +1571,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> Output: return await self.bound.ainvoke( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ) @@ -1576,13 +1583,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> List[Output]: configs = cast( List[RunnableConfig], - [{**self.config, **(conf or {})} for conf in config] + [self._merge_config(conf) for conf in config] if isinstance(config, list) else [ - patch_config( - cast(RunnableConfig, {**self.config, **(config or {})}), - deep_copy_locals=True, - ) + patch_config(self._merge_config(config), deep_copy_locals=True) for _ in range(len(inputs)) ], ) @@ -1596,13 +1600,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> List[Output]: configs = cast( List[RunnableConfig], - [{**self.config, **(conf or {})} for conf in config] + [self._merge_config(conf) for conf in config] if isinstance(config, list) else [ - patch_config( - cast(RunnableConfig, {**self.config, **(config or {})}), - deep_copy_locals=True, - ) + patch_config(self._merge_config(config), deep_copy_locals=True) for _ in range(len(inputs)) ], ) @@ -1616,7 +1617,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> Iterator[Output]: yield from self.bound.stream( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ) @@ -1628,7 +1629,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> AsyncIterator[Output]: async for item in self.bound.astream( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ): yield item @@ -1641,7 +1642,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> Iterator[Output]: yield from self.bound.transform( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ) @@ -1653,7 +1654,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): ) -> AsyncIterator[Output]: async for item in self.bound.atransform( input, - cast(RunnableConfig, {**self.config, **(config or {})}), + self._merge_config(config), **{**self.kwargs, **kwargs}, ): yield item diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 32f488e10c..6dde96529c 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,6 +1,7 @@ from operator import itemgetter from typing import Any, Dict, List, Optional, Union from uuid import UUID +from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler import pytest from freezegun import freeze_time @@ -123,6 +124,25 @@ async def test_with_config(mocker: MockerFixture) -> None: ] spy.reset_mock() + fake_1 = RunnablePassthrough() + fake_2 = RunnablePassthrough() + spy_seq_step = mocker.spy(fake_1.__class__, "invoke") + + sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config( + tags=["b-tag"], max_concurrency=5 + ) + assert sequence.invoke("hello") == "hello" + assert len(spy_seq_step.call_args_list) == 2 + for i, call in enumerate(spy_seq_step.call_args_list): + assert call.args[1] == "hello" + if i == 0: + assert call.args[2].get("tags") == ["a-tag"] + assert call.args[2].get("max_concurrency") is None + else: + assert call.args[2].get("tags") == ["b-tag"] + assert call.args[2].get("max_concurrency") == 5 + spy_seq_step.reset_mock() + assert [ *fake.with_config(tags=["a-tag"]).stream( "hello", dict(metadata={"key": "value"}) @@ -161,14 +181,15 @@ async def test_with_config(mocker: MockerFixture) -> None: assert call.args[1].get("metadata") == {"a": "b"} spy.reset_mock() + handler = ConsoleCallbackHandler() assert ( await fake.with_config(metadata={"a": "b"}).ainvoke( - "hello", config={"callbacks": []} + "hello", config={"callbacks": [handler]} ) == 5 ) assert spy.call_args_list == [ - mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})), + mocker.call("hello", dict(callbacks=[handler], metadata={"a": "b"})), ] spy.reset_mock()