mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Adjust merge logic
This commit is contained in:
parent
f95bd0bcd9
commit
5426712311
@ -1527,6 +1527,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
def lc_namespace(self) -> List[str]:
|
||||
return self.__class__.__module__.split(".")[:-1]
|
||||
|
||||
def _merge_config(self, config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
copy = cast(RunnableConfig, dict(self.config))
|
||||
if config:
|
||||
for key in config:
|
||||
copy[key] = config[key] or copy.get(key)
|
||||
return copy
|
||||
|
||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
|
||||
@ -1552,7 +1559,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
) -> Output:
|
||||
return self.bound.invoke(
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
self._merge_config(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -1564,7 +1571,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
) -> Output:
|
||||
return await self.bound.ainvoke(
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
self._merge_config(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -1576,13 +1583,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
) -> List[Output]:
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
[{**self.config, **(conf or {})} for conf in config]
|
||||
[self._merge_config(conf) for conf in config]
|
||||
if isinstance(config, list)
|
||||
else [
|
||||
patch_config(
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
deep_copy_locals=True,
|
||||
)
|
||||
patch_config(self._merge_config(config), deep_copy_locals=True)
|
||||
for _ in range(len(inputs))
|
||||
],
|
||||
)
|
||||
@ -1596,13 +1600,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
) -> List[Output]:
|
||||
configs = cast(
|
||||
List[RunnableConfig],
|
||||
[{**self.config, **(conf or {})} for conf in config]
|
||||
[self._merge_config(conf) for conf in config]
|
||||
if isinstance(config, list)
|
||||
else [
|
||||
patch_config(
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
deep_copy_locals=True,
|
||||
)
|
||||
patch_config(self._merge_config(config), deep_copy_locals=True)
|
||||
for _ in range(len(inputs))
|
||||
],
|
||||
)
|
||||
@ -1616,7 +1617,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.stream(
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
self._merge_config(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -1628,7 +1629,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.astream(
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
self._merge_config(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
):
|
||||
yield item
|
||||
@ -1641,7 +1642,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.transform(
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
self._merge_config(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
@ -1653,7 +1654,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.atransform(
|
||||
input,
|
||||
cast(RunnableConfig, {**self.config, **(config or {})}),
|
||||
self._merge_config(config),
|
||||
**{**self.kwargs, **kwargs},
|
||||
):
|
||||
yield item
|
||||
|
@ -1,6 +1,7 @@
|
||||
from operator import itemgetter
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from uuid import UUID
|
||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
@ -123,6 +124,25 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
fake_1 = RunnablePassthrough()
|
||||
fake_2 = RunnablePassthrough()
|
||||
spy_seq_step = mocker.spy(fake_1.__class__, "invoke")
|
||||
|
||||
sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
|
||||
tags=["b-tag"], max_concurrency=5
|
||||
)
|
||||
assert sequence.invoke("hello") == "hello"
|
||||
assert len(spy_seq_step.call_args_list) == 2
|
||||
for i, call in enumerate(spy_seq_step.call_args_list):
|
||||
assert call.args[1] == "hello"
|
||||
if i == 0:
|
||||
assert call.args[2].get("tags") == ["a-tag"]
|
||||
assert call.args[2].get("max_concurrency") is None
|
||||
else:
|
||||
assert call.args[2].get("tags") == ["b-tag"]
|
||||
assert call.args[2].get("max_concurrency") == 5
|
||||
spy_seq_step.reset_mock()
|
||||
|
||||
assert [
|
||||
*fake.with_config(tags=["a-tag"]).stream(
|
||||
"hello", dict(metadata={"key": "value"})
|
||||
@ -161,14 +181,15 @@ async def test_with_config(mocker: MockerFixture) -> None:
|
||||
assert call.args[1].get("metadata") == {"a": "b"}
|
||||
spy.reset_mock()
|
||||
|
||||
handler = ConsoleCallbackHandler()
|
||||
assert (
|
||||
await fake.with_config(metadata={"a": "b"}).ainvoke(
|
||||
"hello", config={"callbacks": []}
|
||||
"hello", config={"callbacks": [handler]}
|
||||
)
|
||||
== 5
|
||||
)
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})),
|
||||
mocker.call("hello", dict(callbacks=[handler], metadata={"a": "b"})),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user