Add .with_config() method to Runnables which allows binding any config values to a Runnable

This commit is contained in:
Nuno Campos 2023-08-24 11:53:29 +02:00
parent 324c86acd5
commit a3c69cf41d
4 changed files with 181 additions and 13 deletions

View File

@ -210,7 +210,20 @@ class Runnable(Generic[Input, Output], ABC):
"""
Bind arguments to a Runnable, returning a new Runnable.
"""
return RunnableBinding(bound=self, kwargs=kwargs)
return RunnableBinding(bound=self, kwargs=kwargs, config={})
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
"""
Bind config to a Runnable, returning a new Runnable.
"""
return RunnableBinding(
bound=self, config={**(config or {}), **kwargs}, kwargs={}
)
def map(self) -> Runnable[List[Input], List[Output]]:
"""
@ -1479,6 +1492,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
kwargs: Mapping[str, Any]
config: Mapping[str, Any] = Field(default_factory=dict)
class Config:
arbitrary_types_allowed = True
@ -1491,7 +1506,21 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
return self.__class__.__module__.split(".")[:-1]
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
return self.__class__(
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
)
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config={**self.config, **(config or {}), **kwargs},
)
def invoke(
self,
@ -1499,7 +1528,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
return self.bound.invoke(input, config, **{**self.kwargs, **kwargs})
return self.bound.invoke(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
)
async def ainvoke(
self,
@ -1507,7 +1538,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Output:
return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs})
return await self.bound.ainvoke(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
)
def batch(
self,
@ -1515,7 +1548,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any],
) -> List[Output]:
return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs})
configs = (
[{**self.config, **(conf or {})} for conf in config]
if isinstance(config, list)
else [
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
for _ in range(len(inputs))
]
)
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
async def abatch(
self,
@ -1523,7 +1564,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Optional[Any],
) -> List[Output]:
return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs})
configs = (
[{**self.config, **(conf or {})} for conf in config]
if isinstance(config, list)
else [
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
for _ in range(len(inputs))
]
)
return await self.bound.abatch(
inputs,
[{**self.config, **(conf or {})} for conf in configs],
**{**self.kwargs, **kwargs},
)
def stream(
self,
@ -1531,7 +1584,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs})
yield from self.bound.stream(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
)
async def astream(
self,
@ -1540,7 +1595,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async for item in self.bound.astream(
input, config, **{**self.kwargs, **kwargs}
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
):
yield item
@ -1550,7 +1605,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> Iterator[Output]:
yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs})
yield from self.bound.transform(
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
)
async def atransform(
self,
@ -1559,11 +1616,14 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
**kwargs: Any,
) -> AsyncIterator[Output]:
async for item in self.bound.atransform(
input, config, **{**self.kwargs, **kwargs}
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
):
yield item
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
def coerce_to_runnable(
thing: Union[
Runnable[Input, Output],

View File

@ -3,7 +3,8 @@ from __future__ import annotations
from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional
from typing_extensions import TypedDict
if TYPE_CHECKING:
from langchain.callbacks.base import BaseCallbackManager, Callbacks
@ -48,7 +49,7 @@ class RunnableConfig(TypedDict, total=False):
"""
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
empty = RunnableConfig(
tags=[],
metadata={},

View File

@ -2081,7 +2081,8 @@
"stop": [
"Thought:"
]
}
},
"config": {}
}
},
"llm": {

View File

@ -112,6 +112,104 @@ class FakeRetriever(BaseRetriever):
return [Document(page_content="foo"), Document(page_content="bar")]
@pytest.mark.asyncio
async def test_with_config(mocker: MockerFixture) -> None:
fake = FakeRunnable()
spy = mocker.spy(fake, "invoke")
assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
]
spy.reset_mock()
assert [
*fake.with_config(tags=["a-tag"]).stream(
"hello", dict(metadata={"key": "value"})
)
] == [5]
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"], metadata={"key": "value"})),
]
spy.reset_mock()
assert fake.with_config(recursion_limit=5).batch(
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
) == [5, 7]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
if i == 0:
assert call.args[1].get("recursion_limit") == 5
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
else:
assert call.args[1].get("recursion_limit") == 5
assert call.args[1].get("tags") == []
assert call.args[1].get("metadata") == {"key": "value"}
spy.reset_mock()
assert fake.with_config(metadata={"a": "b"}).batch(
["hello", "wooorld"], dict(tags=["a-tag"])
) == [5, 7]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {"a": "b"}
spy.reset_mock()
assert (
await fake.with_config(metadata={"a": "b"}).ainvoke(
"hello", config={"callbacks": []}
)
== 5
)
assert spy.call_args_list == [
mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})),
]
spy.reset_mock()
assert [
part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")
] == [5]
assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"a": "b"})),
]
spy.reset_mock()
assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(
["hello", "wooorld"], dict(metadata={"key": "value"})
) == [
5,
7,
]
assert spy.call_args_list == [
mocker.call(
"hello",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
_locals={},
recursion_limit=5,
),
),
mocker.call(
"wooorld",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
_locals={},
recursion_limit=5,
),
),
]
@pytest.mark.asyncio
async def test_default_method_implementations(mocker: MockerFixture) -> None:
fake = FakeRunnable()
@ -1125,6 +1223,14 @@ async def test_map_astream_iterator_input() -> None:
assert final_value.get("passthrough") == llm_res
def test_with_config_with_config() -> None:
llm = FakeListLLM(responses=["i'm a textbot"])
assert dumpd(
llm.with_config({"metadata": {"a": "b"}}).with_config(tags=["a-tag"])
) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]}))
def test_bind_bind() -> None:
llm = FakeListLLM(responses=["i'm a textbot"])