mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add .with_config() method to Runnables which allows binding any config values to a Runnable
This commit is contained in:
parent
324c86acd5
commit
a3c69cf41d
@ -210,7 +210,20 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
Bind arguments to a Runnable, returning a new Runnable.
|
||||
"""
|
||||
return RunnableBinding(bound=self, kwargs=kwargs)
|
||||
return RunnableBinding(bound=self, kwargs=kwargs, config={})
|
||||
|
||||
def with_config(
|
||||
self,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
||||
**kwargs: Any,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind config to a Runnable, returning a new Runnable.
|
||||
"""
|
||||
return RunnableBinding(
|
||||
bound=self, config={**(config or {}), **kwargs}, kwargs={}
|
||||
)
|
||||
|
||||
def map(self) -> Runnable[List[Input], List[Output]]:
|
||||
"""
|
||||
@ -1479,6 +1492,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
|
||||
kwargs: Mapping[str, Any]
|
||||
|
||||
config: Mapping[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@ -1491,7 +1506,21 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
return self.__class__.__module__.split(".")[:-1]
|
||||
|
||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
|
||||
return self.__class__(
|
||||
bound=self.bound, config=self.config, kwargs={**self.kwargs, **kwargs}
|
||||
)
|
||||
|
||||
def with_config(
|
||||
self,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
||||
**kwargs: Any,
|
||||
) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
kwargs=self.kwargs,
|
||||
config={**self.config, **(config or {}), **kwargs},
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
@ -1499,7 +1528,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
return self.bound.invoke(input, config, **{**self.kwargs, **kwargs})
|
||||
return self.bound.invoke(
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
)
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
@ -1507,7 +1538,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Output:
|
||||
return await self.bound.ainvoke(input, config, **{**self.kwargs, **kwargs})
|
||||
return await self.bound.ainvoke(
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
)
|
||||
|
||||
def batch(
|
||||
self,
|
||||
@ -1515,7 +1548,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
return self.bound.batch(inputs, config, **{**self.kwargs, **kwargs})
|
||||
configs = (
|
||||
[{**self.config, **(conf or {})} for conf in config]
|
||||
if isinstance(config, list)
|
||||
else [
|
||||
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
|
||||
for _ in range(len(inputs))
|
||||
]
|
||||
)
|
||||
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
@ -1523,7 +1564,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
return await self.bound.abatch(inputs, config, **{**self.kwargs, **kwargs})
|
||||
configs = (
|
||||
[{**self.config, **(conf or {})} for conf in config]
|
||||
if isinstance(config, list)
|
||||
else [
|
||||
patch_config({**self.config, **(config or {})}, deep_copy_locals=True)
|
||||
for _ in range(len(inputs))
|
||||
]
|
||||
)
|
||||
return await self.bound.abatch(
|
||||
inputs,
|
||||
[{**self.config, **(conf or {})} for conf in configs],
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
@ -1531,7 +1584,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.stream(input, config, **{**self.kwargs, **kwargs})
|
||||
yield from self.bound.stream(
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
@ -1540,7 +1595,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.astream(
|
||||
input, config, **{**self.kwargs, **kwargs}
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
):
|
||||
yield item
|
||||
|
||||
@ -1550,7 +1605,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[Output]:
|
||||
yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs})
|
||||
yield from self.bound.transform(
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
)
|
||||
|
||||
async def atransform(
|
||||
self,
|
||||
@ -1559,11 +1616,14 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[Output]:
|
||||
async for item in self.bound.atransform(
|
||||
input, config, **{**self.kwargs, **kwargs}
|
||||
input, {**self.config, **(config or {})}, **{**self.kwargs, **kwargs}
|
||||
):
|
||||
yield item
|
||||
|
||||
|
||||
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||
|
||||
|
||||
def coerce_to_runnable(
|
||||
thing: Union[
|
||||
Runnable[Input, Output],
|
||||
|
@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
from concurrent.futures import Executor, ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
||||
@ -48,7 +49,7 @@ class RunnableConfig(TypedDict, total=False):
|
||||
"""
|
||||
|
||||
|
||||
def ensure_config(config: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
def ensure_config(config: Optional[RunnableConfig] = None) -> RunnableConfig:
|
||||
empty = RunnableConfig(
|
||||
tags=[],
|
||||
metadata={},
|
||||
|
@ -2081,7 +2081,8 @@
|
||||
"stop": [
|
||||
"Thought:"
|
||||
]
|
||||
}
|
||||
},
|
||||
"config": {}
|
||||
}
|
||||
},
|
||||
"llm": {
|
||||
|
@ -112,6 +112,104 @@ class FakeRetriever(BaseRetriever):
|
||||
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_config(mocker: MockerFixture) -> None:
|
||||
fake = FakeRunnable()
|
||||
spy = mocker.spy(fake, "invoke")
|
||||
|
||||
assert fake.with_config(tags=["a-tag"]).invoke("hello") == 5
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(tags=["a-tag"])),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
assert [
|
||||
*fake.with_config(tags=["a-tag"]).stream(
|
||||
"hello", dict(metadata={"key": "value"})
|
||||
)
|
||||
] == [5]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(tags=["a-tag"], metadata={"key": "value"})),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
assert fake.with_config(recursion_limit=5).batch(
|
||||
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
|
||||
) == [5, 7]
|
||||
|
||||
assert len(spy.call_args_list) == 2
|
||||
for i, call in enumerate(spy.call_args_list):
|
||||
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||
if i == 0:
|
||||
assert call.args[1].get("recursion_limit") == 5
|
||||
assert call.args[1].get("tags") == ["a-tag"]
|
||||
assert call.args[1].get("metadata") == {}
|
||||
else:
|
||||
assert call.args[1].get("recursion_limit") == 5
|
||||
assert call.args[1].get("tags") == []
|
||||
assert call.args[1].get("metadata") == {"key": "value"}
|
||||
|
||||
spy.reset_mock()
|
||||
|
||||
assert fake.with_config(metadata={"a": "b"}).batch(
|
||||
["hello", "wooorld"], dict(tags=["a-tag"])
|
||||
) == [5, 7]
|
||||
assert len(spy.call_args_list) == 2
|
||||
for i, call in enumerate(spy.call_args_list):
|
||||
assert call.args[0] == ("hello" if i == 0 else "wooorld")
|
||||
assert call.args[1].get("tags") == ["a-tag"]
|
||||
assert call.args[1].get("metadata") == {"a": "b"}
|
||||
spy.reset_mock()
|
||||
|
||||
assert (
|
||||
await fake.with_config(metadata={"a": "b"}).ainvoke(
|
||||
"hello", config={"callbacks": []}
|
||||
)
|
||||
== 5
|
||||
)
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(callbacks=[], metadata={"a": "b"})),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
assert [
|
||||
part async for part in fake.with_config(metadata={"a": "b"}).astream("hello")
|
||||
] == [5]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call("hello", dict(metadata={"a": "b"})),
|
||||
]
|
||||
spy.reset_mock()
|
||||
|
||||
assert await fake.with_config(recursion_limit=5, tags=["c"]).abatch(
|
||||
["hello", "wooorld"], dict(metadata={"key": "value"})
|
||||
) == [
|
||||
5,
|
||||
7,
|
||||
]
|
||||
assert spy.call_args_list == [
|
||||
mocker.call(
|
||||
"hello",
|
||||
dict(
|
||||
metadata={"key": "value"},
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
_locals={},
|
||||
recursion_limit=5,
|
||||
),
|
||||
),
|
||||
mocker.call(
|
||||
"wooorld",
|
||||
dict(
|
||||
metadata={"key": "value"},
|
||||
tags=["c"],
|
||||
callbacks=None,
|
||||
_locals={},
|
||||
recursion_limit=5,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
fake = FakeRunnable()
|
||||
@ -1125,6 +1223,14 @@ async def test_map_astream_iterator_input() -> None:
|
||||
assert final_value.get("passthrough") == llm_res
|
||||
|
||||
|
||||
def test_with_config_with_config() -> None:
|
||||
llm = FakeListLLM(responses=["i'm a textbot"])
|
||||
|
||||
assert dumpd(
|
||||
llm.with_config({"metadata": {"a": "b"}}).with_config(tags=["a-tag"])
|
||||
) == dumpd(llm.with_config({"metadata": {"a": "b"}, "tags": ["a-tag"]}))
|
||||
|
||||
|
||||
def test_bind_bind() -> None:
|
||||
llm = FakeListLLM(responses=["i'm a textbot"])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user