mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
RunnableBranch (#10594)
Runnable Branch implementation, no optimization for streaming logic yet
This commit is contained in:
parent
287c81db89
commit
1eefb9052b
@ -2,6 +2,7 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
|
|||||||
from langchain.schema.runnable.base import (
|
from langchain.schema.runnable.base import (
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableBinding,
|
RunnableBinding,
|
||||||
|
RunnableBranch,
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
RunnableMap,
|
RunnableMap,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
@ -12,16 +13,17 @@ from langchain.schema.runnable.passthrough import RunnablePassthrough
|
|||||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"patch_config",
|
|
||||||
"GetLocalVar",
|
"GetLocalVar",
|
||||||
|
"patch_config",
|
||||||
"PutLocalVar",
|
"PutLocalVar",
|
||||||
"RouterInput",
|
"RouterInput",
|
||||||
"RouterRunnable",
|
"RouterRunnable",
|
||||||
"Runnable",
|
"Runnable",
|
||||||
"RunnableBinding",
|
"RunnableBinding",
|
||||||
|
"RunnableBranch",
|
||||||
"RunnableConfig",
|
"RunnableConfig",
|
||||||
"RunnableMap",
|
|
||||||
"RunnableLambda",
|
"RunnableLambda",
|
||||||
|
"RunnableMap",
|
||||||
"RunnablePassthrough",
|
"RunnablePassthrough",
|
||||||
"RunnableSequence",
|
"RunnableSequence",
|
||||||
"RunnableWithFallbacks",
|
"RunnableWithFallbacks",
|
||||||
|
@ -658,6 +658,188 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
await run_manager.on_chain_end(final_output, inputs=final_input)
|
await run_manager.on_chain_end(final_output, inputs=final_input)
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableBranch(Serializable, Runnable[Input, Output]):
|
||||||
|
"""A Runnable that selects which branch to run based on a condition.
|
||||||
|
|
||||||
|
The runnable is initialized with a list of (condition, runnable) pairs and
|
||||||
|
a default branch.
|
||||||
|
|
||||||
|
When operating on an input, the first condition that evaluates to True is
|
||||||
|
selected, and the corresponding runnable is run on the input.
|
||||||
|
|
||||||
|
If no condition evaluates to True, the default branch is run on the input.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain.schema.runnable import RunnableBranch
|
||||||
|
|
||||||
|
branch = RunnableBranch(
|
||||||
|
(lambda x: isinstance(x, str), lambda x: x.upper()),
|
||||||
|
(lambda x: isinstance(x, int), lambda x: x + 1),
|
||||||
|
(lambda x: isinstance(x, float), lambda x: x * 2),
|
||||||
|
lambda x: "goodbye",
|
||||||
|
)
|
||||||
|
|
||||||
|
branch.invoke("hello") # "HELLO"
|
||||||
|
branch.invoke(None) # "goodbye"
|
||||||
|
"""
|
||||||
|
|
||||||
|
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
|
||||||
|
default: Runnable[Input, Output]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*branches: Union[
|
||||||
|
Tuple[
|
||||||
|
Union[
|
||||||
|
Runnable[Input, bool],
|
||||||
|
Callable[[Input], bool],
|
||||||
|
Callable[[Input], Awaitable[bool]],
|
||||||
|
],
|
||||||
|
RunnableLike,
|
||||||
|
],
|
||||||
|
RunnableLike, # To accommodate the default branch
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""A Runnable that runs one of two branches based on a condition."""
|
||||||
|
if len(branches) < 2:
|
||||||
|
raise ValueError("RunnableBranch requires at least two branches")
|
||||||
|
|
||||||
|
default = branches[-1]
|
||||||
|
|
||||||
|
if not isinstance(
|
||||||
|
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
|
||||||
|
):
|
||||||
|
raise TypeError(
|
||||||
|
"RunnableBranch default must be runnable, callable or mapping."
|
||||||
|
)
|
||||||
|
|
||||||
|
default_ = cast(
|
||||||
|
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
|
||||||
|
)
|
||||||
|
|
||||||
|
_branches = []
|
||||||
|
|
||||||
|
for branch in branches[:-1]:
|
||||||
|
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
|
||||||
|
raise TypeError(
|
||||||
|
f"RunnableBranch branches must be "
|
||||||
|
f"tuples or lists, not {type(branch)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not len(branch) == 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"RunnableBranch branches must be "
|
||||||
|
f"tuples or lists of length 2, not {len(branch)}"
|
||||||
|
)
|
||||||
|
condition, runnable = branch
|
||||||
|
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
|
||||||
|
runnable = coerce_to_runnable(runnable)
|
||||||
|
_branches.append((condition, runnable))
|
||||||
|
|
||||||
|
super().__init__(branches=_branches, default=default_)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
"""RunnableBranch is serializable if all its branches are serializable."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_namespace(self) -> List[str]:
|
||||||
|
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
||||||
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
|
"""First evaluates the condition, then delegate to true or false branch."""
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
for idx, branch in enumerate(self.branches):
|
||||||
|
condition, runnable = branch
|
||||||
|
|
||||||
|
expression_value = condition.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag=f"condition:{idx}")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if expression_value:
|
||||||
|
return runnable.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag=f"branch:{idx}")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.default.invoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
run_manager.on_chain_end(dumpd(output))
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> Output:
|
||||||
|
"""Async version of invoke."""
|
||||||
|
config = ensure_config(config)
|
||||||
|
callback_manager = get_callback_manager_for_config(config)
|
||||||
|
run_manager = callback_manager.on_chain_start(
|
||||||
|
dumpd(self),
|
||||||
|
input,
|
||||||
|
name=config.get("run_name"),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
for idx, branch in enumerate(self.branches):
|
||||||
|
condition, runnable = branch
|
||||||
|
|
||||||
|
expression_value = await condition.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag=f"condition:{idx}")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if expression_value:
|
||||||
|
return await runnable.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag=f"branch:{idx}")
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = await self.default.ainvoke(
|
||||||
|
input,
|
||||||
|
config=patch_config(
|
||||||
|
config, callbacks=run_manager.get_child(tag="branch:default")
|
||||||
|
),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
run_manager.on_chain_error(e)
|
||||||
|
raise
|
||||||
|
run_manager.on_chain_end(dumpd(output))
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
A Runnable that can fallback to other Runnables if it fails.
|
A Runnable that can fallback to other Runnables if it fails.
|
||||||
@ -2007,14 +2189,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
|
|
||||||
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
|
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||||
|
|
||||||
|
RunnableLike = Union[
|
||||||
|
Runnable[Input, Output],
|
||||||
|
Callable[[Input], Output],
|
||||||
|
Callable[[Input], Awaitable[Output]],
|
||||||
|
Mapping[str, Any],
|
||||||
|
]
|
||||||
|
|
||||||
def coerce_to_runnable(
|
|
||||||
thing: Union[
|
def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
|
||||||
Runnable[Input, Output],
|
|
||||||
Callable[[Input], Output],
|
|
||||||
Mapping[str, Any],
|
|
||||||
]
|
|
||||||
) -> Runnable[Input, Output]:
|
|
||||||
if isinstance(thing, Runnable):
|
if isinstance(thing, Runnable):
|
||||||
return thing
|
return thing
|
||||||
elif callable(thing):
|
elif callable(thing):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -34,6 +34,7 @@ from langchain.schema.retriever import BaseRetriever
|
|||||||
from langchain.schema.runnable import (
|
from langchain.schema.runnable import (
|
||||||
RouterRunnable,
|
RouterRunnable,
|
||||||
Runnable,
|
Runnable,
|
||||||
|
RunnableBranch,
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
RunnableMap,
|
RunnableMap,
|
||||||
@ -541,7 +542,7 @@ async def test_prompt_with_llm(
|
|||||||
mocker.stop(prompt_spy)
|
mocker.stop(prompt_spy)
|
||||||
mocker.stop(llm_spy)
|
mocker.stop(llm_spy)
|
||||||
|
|
||||||
# Test stream
|
# Test stream#
|
||||||
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
|
||||||
llm_spy = mocker.spy(llm.__class__, "astream")
|
llm_spy = mocker.spy(llm.__class__, "astream")
|
||||||
tracer = FakeTracer()
|
tracer = FakeTracer()
|
||||||
@ -1816,3 +1817,205 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
|||||||
assert parent_run_qux.outputs["output"] == "quxaaaa"
|
assert parent_run_qux.outputs["output"] == "quxaaaa"
|
||||||
assert len(parent_run_qux.child_runs) == 4
|
assert len(parent_run_qux.child_runs) == 4
|
||||||
assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
|
assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
|
||||||
|
|
||||||
|
|
||||||
|
def test_runnable_branch_init() -> None:
|
||||||
|
"""Verify that runnable branch gets initialized properly."""
|
||||||
|
add = RunnableLambda(lambda x: x + 1)
|
||||||
|
condition = RunnableLambda(lambda x: x > 0)
|
||||||
|
|
||||||
|
# Test failure with less than 2 branches
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
RunnableBranch((condition, add))
|
||||||
|
|
||||||
|
# Test failure with less than 2 branches
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
RunnableBranch(condition)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"branches",
|
||||||
|
[
|
||||||
|
[
|
||||||
|
(RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)),
|
||||||
|
RunnableLambda(lambda x: x - 1),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
(RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)),
|
||||||
|
(RunnableLambda(lambda x: x > 5), RunnableLambda(lambda x: x + 1)),
|
||||||
|
RunnableLambda(lambda x: x - 1),
|
||||||
|
],
|
||||||
|
[
|
||||||
|
(lambda x: x > 0, lambda x: x + 1),
|
||||||
|
(lambda x: x > 5, lambda x: x + 1),
|
||||||
|
lambda x: x - 1,
|
||||||
|
],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
|
||||||
|
"""Verify that runnable branch gets initialized properly."""
|
||||||
|
runnable = RunnableBranch[int, int](*branches)
|
||||||
|
for branch in runnable.branches:
|
||||||
|
condition, body = branch
|
||||||
|
assert isinstance(condition, Runnable)
|
||||||
|
assert isinstance(body, Runnable)
|
||||||
|
|
||||||
|
assert isinstance(runnable.default, Runnable)
|
||||||
|
|
||||||
|
|
||||||
|
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
|
||||||
|
"""Verify that runnables are invoked only when necessary."""
|
||||||
|
# Test with single branch
|
||||||
|
add = RunnableLambda(lambda x: x + 1)
|
||||||
|
sub = RunnableLambda(lambda x: x - 1)
|
||||||
|
condition = RunnableLambda(lambda x: x > 0)
|
||||||
|
spy = mocker.spy(condition, "invoke")
|
||||||
|
add_spy = mocker.spy(add, "invoke")
|
||||||
|
|
||||||
|
branch = RunnableBranch[int, int]((condition, add), (condition, add), sub)
|
||||||
|
assert spy.call_count == 0
|
||||||
|
assert add_spy.call_count == 0
|
||||||
|
|
||||||
|
assert branch.invoke(1) == 2
|
||||||
|
assert add_spy.call_count == 1
|
||||||
|
assert spy.call_count == 1
|
||||||
|
|
||||||
|
assert branch.invoke(2) == 3
|
||||||
|
assert spy.call_count == 2
|
||||||
|
assert add_spy.call_count == 2
|
||||||
|
|
||||||
|
assert branch.invoke(-3) == -4
|
||||||
|
# Should fall through to default branch with condition being evaluated twice!
|
||||||
|
assert spy.call_count == 4
|
||||||
|
# Add should not be invoked
|
||||||
|
assert add_spy.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_runnable_branch_invoke() -> None:
|
||||||
|
# Test with single branch
|
||||||
|
def raise_value_error(x: int) -> int:
|
||||||
|
"""Raise a value error."""
|
||||||
|
raise ValueError("x is too large")
|
||||||
|
|
||||||
|
branch = RunnableBranch[int, int](
|
||||||
|
(lambda x: x > 100, raise_value_error),
|
||||||
|
# mypy cannot infer types from the lambda
|
||||||
|
(lambda x: x > 0 and x < 5, lambda x: x + 1), # type: ignore[misc]
|
||||||
|
(lambda x: x > 5, lambda x: x * 10),
|
||||||
|
lambda x: x - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert branch.invoke(1) == 2
|
||||||
|
assert branch.invoke(10) == 100
|
||||||
|
assert branch.invoke(0) == -1
|
||||||
|
# Should raise an exception
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
branch.invoke(1000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_runnable_branch_batch() -> None:
|
||||||
|
"""Test batch variant."""
|
||||||
|
# Test with single branch
|
||||||
|
branch = RunnableBranch[int, int](
|
||||||
|
(lambda x: x > 0 and x < 5, lambda x: x + 1),
|
||||||
|
(lambda x: x > 5, lambda x: x * 10),
|
||||||
|
lambda x: x - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert branch.batch([1, 10, 0]) == [2, 100, -1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runnable_branch_ainvoke() -> None:
|
||||||
|
"""Test async variant of invoke."""
|
||||||
|
branch = RunnableBranch[int, int](
|
||||||
|
(lambda x: x > 0 and x < 5, lambda x: x + 1),
|
||||||
|
(lambda x: x > 5, lambda x: x * 10),
|
||||||
|
lambda x: x - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await branch.ainvoke(1) == 2
|
||||||
|
assert await branch.ainvoke(10) == 100
|
||||||
|
assert await branch.ainvoke(0) == -1
|
||||||
|
|
||||||
|
# Verify that the async variant is used if available
|
||||||
|
async def condition(x: int) -> bool:
|
||||||
|
return x > 0
|
||||||
|
|
||||||
|
async def add(x: int) -> int:
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
async def sub(x: int) -> int:
|
||||||
|
return x - 1
|
||||||
|
|
||||||
|
branch = RunnableBranch[int, int]((condition, add), sub)
|
||||||
|
|
||||||
|
assert await branch.ainvoke(1) == 2
|
||||||
|
assert await branch.ainvoke(-10) == -11
|
||||||
|
|
||||||
|
|
||||||
|
def test_runnable_branch_invoke_callbacks() -> None:
|
||||||
|
"""Verify that callbacks are correctly used in invoke."""
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
def raise_value_error(x: int) -> int:
|
||||||
|
"""Raise a value error."""
|
||||||
|
raise ValueError("x is too large")
|
||||||
|
|
||||||
|
branch = RunnableBranch[int, int](
|
||||||
|
(lambda x: x > 100, raise_value_error),
|
||||||
|
lambda x: x - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert branch.invoke(1, config={"callbacks": [tracer]}) == 0
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].error is None
|
||||||
|
assert tracer.runs[0].outputs == {"output": 0}
|
||||||
|
|
||||||
|
# Check that the chain on end is invoked
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
branch.invoke(1000, config={"callbacks": [tracer]})
|
||||||
|
|
||||||
|
assert len(tracer.runs) == 2
|
||||||
|
assert tracer.runs[1].error == "ValueError('x is too large')"
|
||||||
|
assert tracer.runs[1].outputs is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runnable_branch_ainvoke_callbacks() -> None:
|
||||||
|
"""Verify that callbacks are invoked correctly in ainvoke."""
|
||||||
|
tracer = FakeTracer()
|
||||||
|
|
||||||
|
async def raise_value_error(x: int) -> int:
|
||||||
|
"""Raise a value error."""
|
||||||
|
raise ValueError("x is too large")
|
||||||
|
|
||||||
|
branch = RunnableBranch[int, int](
|
||||||
|
(lambda x: x > 100, raise_value_error),
|
||||||
|
lambda x: x - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await branch.ainvoke(1, config={"callbacks": [tracer]}) == 0
|
||||||
|
assert len(tracer.runs) == 1
|
||||||
|
assert tracer.runs[0].error is None
|
||||||
|
assert tracer.runs[0].outputs == {"output": 0}
|
||||||
|
|
||||||
|
# Check that the chain on end is invoked
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await branch.ainvoke(1000, config={"callbacks": [tracer]})
|
||||||
|
|
||||||
|
assert len(tracer.runs) == 2
|
||||||
|
assert tracer.runs[1].error == "ValueError('x is too large')"
|
||||||
|
assert tracer.runs[1].outputs is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runnable_branch_abatch() -> None:
|
||||||
|
"""Test async variant of invoke."""
|
||||||
|
branch = RunnableBranch[int, int](
|
||||||
|
(lambda x: x > 0 and x < 5, lambda x: x + 1),
|
||||||
|
(lambda x: x > 5, lambda x: x * 10),
|
||||||
|
lambda x: x - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await branch.abatch([1, 10, 0]) == [2, 100, -1]
|
||||||
|
Loading…
Reference in New Issue
Block a user