Adjust merge logic

This commit is contained in:
Nuno Campos 2023-08-24 16:49:14 +02:00
parent f95bd0bcd9
commit 5426712311
2 changed files with 40 additions and 18 deletions

View File

@ -1527,6 +1527,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
copy = cast(RunnableConfig, dict(self.config))
if config:
for key in config:
copy[key] = config[key] or copy.get(key)
return copy
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
@ -1552,7 +1559,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> Output:
return self.bound.invoke(
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
self._merge_config(config),
**{**self.kwargs, **kwargs},
)
@ -1564,7 +1571,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> Output:
return await self.bound.ainvoke(
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
self._merge_config(config),
**{**self.kwargs, **kwargs},
)
@ -1576,13 +1583,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> List[Output]:
configs = cast(
List[RunnableConfig],
[{**self.config, **(conf or {})} for conf in config]
[self._merge_config(conf) for conf in config]
if isinstance(config, list)
else [
patch_config(
cast(RunnableConfig, {**self.config, **(config or {})}),
deep_copy_locals=True,
)
patch_config(self._merge_config(config), deep_copy_locals=True)
for _ in range(len(inputs))
],
)
@ -1596,13 +1600,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> List[Output]:
configs = cast(
List[RunnableConfig],
[{**self.config, **(conf or {})} for conf in config]
[self._merge_config(conf) for conf in config]
if isinstance(config, list)
else [
patch_config(
cast(RunnableConfig, {**self.config, **(config or {})}),
deep_copy_locals=True,
)
patch_config(self._merge_config(config), deep_copy_locals=True)
for _ in range(len(inputs))
],
)
@ -1616,7 +1617,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> Iterator[Output]:
yield from self.bound.stream(
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
self._merge_config(config),
**{**self.kwargs, **kwargs},
)
@ -1628,7 +1629,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> AsyncIterator[Output]:
async for item in self.bound.astream(
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
self._merge_config(config),
**{**self.kwargs, **kwargs},
):
yield item
@ -1641,7 +1642,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> Iterator[Output]:
yield from self.bound.transform(
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
self._merge_config(config),
**{**self.kwargs, **kwargs},
)
@ -1653,7 +1654,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(
input,
cast(RunnableConfig, {**self.config, **(config or {})}),
self._merge_config(config),
**{**self.kwargs, **kwargs},
):
yield item

View File

@ -1,6 +1,7 @@
from operator import itemgetter
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
import pytest
from freezegun import freeze_time
@ -123,6 +124,25 @@ async def test_with_config(mocker: MockerFixture) -> None:
]
spy.reset_mock()
fake_1 = RunnablePassthrough()
fake_2 = RunnablePassthrough()
spy_seq_step = mocker.spy(fake_1.__class__, "invoke")
sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
tags=["b-tag"], max_concurrency=5
)
assert sequence.invoke("hello") == "hello"
assert len(spy_seq_step.call_args_list) == 2
for i, call in enumerate(spy_seq_step.call_args_list):
assert call.args[1] == "hello"
if i == 0:
assert call.args[2].get("tags") == ["a-tag"]
assert call.args[2].get("max_concurrency") is None
else:
assert call.args[2].get("tags") == ["b-tag"]
assert call.args[2].get("max_concurrency") == 5
spy_seq_step.reset_mock()
assert [
*fake.with_config(tags=["a-tag"]).stream(
"hello", dict(metadata={"key": "value"})
@ -161,14 +181,15 @@ async def test_with_config(mocker: MockerFixture) -> None:
assert call.args[1].get("metadata") == {"a": "b"}
spy.reset_mock()
handler = ConsoleCallbackHandler()
assert (
await fake.with_config(metadata={"a": "b"}).ainvoke(
"hello", config={"callbacks": []}
"hello", config={"callbacks": [handler]}
)
== 5
)
assert spy.call_args_list == [
mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})),
mocker.call("hello", dict(callbacks=[handler], metadata={"a": "b"})),
]
spy.reset_mock()