RunnableBranch (#10594)

Runnable Branch implementation, no optimization for streaming logic yet
pull/10495/head
Eugene Yurtsev 11 months ago committed by GitHub
parent 287c81db89
commit 1eefb9052b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,6 +2,7 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar
from langchain.schema.runnable.base import (
Runnable,
RunnableBinding,
RunnableBranch,
RunnableLambda,
RunnableMap,
RunnableSequence,
@ -12,16 +13,17 @@ from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable
__all__ = [
"patch_config",
"GetLocalVar",
"patch_config",
"PutLocalVar",
"RouterInput",
"RouterRunnable",
"Runnable",
"RunnableBinding",
"RunnableBranch",
"RunnableConfig",
"RunnableMap",
"RunnableLambda",
"RunnableMap",
"RunnablePassthrough",
"RunnableSequence",
"RunnableWithFallbacks",

@ -658,6 +658,188 @@ class Runnable(Generic[Input, Output], ABC):
await run_manager.on_chain_end(final_output, inputs=final_input)
class RunnableBranch(Serializable, Runnable[Input, Output]):
"""A Runnable that selects which branch to run based on a condition.
The runnable is initialized with a list of (condition, runnable) pairs and
a default branch.
When operating on an input, the first condition that evaluates to True is
selected, and the corresponding runnable is run on the input.
If no condition evaluates to True, the default branch is run on the input.
Examples:
.. code-block:: python
from langchain.schema.runnable import RunnableBranch
branch = RunnableBranch(
(lambda x: isinstance(x, str), lambda x: x.upper()),
(lambda x: isinstance(x, int), lambda x: x + 1),
(lambda x: isinstance(x, float), lambda x: x * 2),
lambda x: "goodbye",
)
branch.invoke("hello") # "HELLO"
branch.invoke(None) # "goodbye"
"""
branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]]
default: Runnable[Input, Output]
def __init__(
self,
*branches: Union[
Tuple[
Union[
Runnable[Input, bool],
Callable[[Input], bool],
Callable[[Input], Awaitable[bool]],
],
RunnableLike,
],
RunnableLike, # To accommodate the default branch
],
) -> None:
"""A Runnable that runs one of two branches based on a condition."""
if len(branches) < 2:
raise ValueError("RunnableBranch requires at least two branches")
default = branches[-1]
if not isinstance(
default, (Runnable, Callable, Mapping) # type: ignore[arg-type]
):
raise TypeError(
"RunnableBranch default must be runnable, callable or mapping."
)
default_ = cast(
Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default))
)
_branches = []
for branch in branches[:-1]:
if not isinstance(branch, (tuple, list)): # type: ignore[arg-type]
raise TypeError(
f"RunnableBranch branches must be "
f"tuples or lists, not {type(branch)}"
)
if not len(branch) == 2:
raise ValueError(
f"RunnableBranch branches must be "
f"tuples or lists of length 2, not {len(branch)}"
)
condition, runnable = branch
condition = cast(Runnable[Input, bool], coerce_to_runnable(condition))
runnable = coerce_to_runnable(runnable)
_branches.append((condition, runnable))
super().__init__(branches=_branches, default=default_)
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
"""RunnableBranch is serializable if all its branches are serializable."""
return True
@property
def lc_namespace(self) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return self.__class__.__module__.split(".")[:-1]
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
"""First evaluates the condition, then delegate to true or false branch."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = condition.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag=f"condition:{idx}")
),
)
if expression_value:
return runnable.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag=f"branch:{idx}")
),
)
output = self.default.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Async version of invoke."""
config = ensure_config(config)
callback_manager = get_callback_manager_for_config(config)
run_manager = callback_manager.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
)
try:
for idx, branch in enumerate(self.branches):
condition, runnable = branch
expression_value = await condition.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag=f"condition:{idx}")
),
)
if expression_value:
return await runnable.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag=f"branch:{idx}")
),
**kwargs,
)
output = await self.default.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise
run_manager.on_chain_end(dumpd(output))
return output
class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
"""
A Runnable that can fallback to other Runnables if it fails.
@ -2007,14 +2189,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
RunnableLike = Union[
Runnable[Input, Output],
Callable[[Input], Output],
Callable[[Input], Awaitable[Output]],
Mapping[str, Any],
]
def coerce_to_runnable(
thing: Union[
Runnable[Input, Output],
Callable[[Input], Output],
Mapping[str, Any],
]
) -> Runnable[Input, Output]:
def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
if isinstance(thing, Runnable):
return thing
elif callable(thing):

@ -1,5 +1,5 @@
from operator import itemgetter
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union
from uuid import UUID
import pytest
@ -34,6 +34,7 @@ from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
RouterRunnable,
Runnable,
RunnableBranch,
RunnableConfig,
RunnableLambda,
RunnableMap,
@ -541,7 +542,7 @@ async def test_prompt_with_llm(
mocker.stop(prompt_spy)
mocker.stop(llm_spy)
# Test stream
# Test stream#
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
llm_spy = mocker.spy(llm.__class__, "astream")
tracer = FakeTracer()
@ -1816,3 +1817,205 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
assert parent_run_qux.outputs["output"] == "quxaaaa"
assert len(parent_run_qux.child_runs) == 4
assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
def test_runnable_branch_init() -> None:
"""Verify that runnable branch gets initialized properly."""
add = RunnableLambda(lambda x: x + 1)
condition = RunnableLambda(lambda x: x > 0)
# Test failure with less than 2 branches
with pytest.raises(ValueError):
RunnableBranch((condition, add))
# Test failure with less than 2 branches
with pytest.raises(ValueError):
RunnableBranch(condition)
@pytest.mark.parametrize(
"branches",
[
[
(RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)),
RunnableLambda(lambda x: x - 1),
],
[
(RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)),
(RunnableLambda(lambda x: x > 5), RunnableLambda(lambda x: x + 1)),
RunnableLambda(lambda x: x - 1),
],
[
(lambda x: x > 0, lambda x: x + 1),
(lambda x: x > 5, lambda x: x + 1),
lambda x: x - 1,
],
],
)
def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
"""Verify that runnable branch gets initialized properly."""
runnable = RunnableBranch[int, int](*branches)
for branch in runnable.branches:
condition, body = branch
assert isinstance(condition, Runnable)
assert isinstance(body, Runnable)
assert isinstance(runnable.default, Runnable)
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
"""Verify that runnables are invoked only when necessary."""
# Test with single branch
add = RunnableLambda(lambda x: x + 1)
sub = RunnableLambda(lambda x: x - 1)
condition = RunnableLambda(lambda x: x > 0)
spy = mocker.spy(condition, "invoke")
add_spy = mocker.spy(add, "invoke")
branch = RunnableBranch[int, int]((condition, add), (condition, add), sub)
assert spy.call_count == 0
assert add_spy.call_count == 0
assert branch.invoke(1) == 2
assert add_spy.call_count == 1
assert spy.call_count == 1
assert branch.invoke(2) == 3
assert spy.call_count == 2
assert add_spy.call_count == 2
assert branch.invoke(-3) == -4
# Should fall through to default branch with condition being evaluated twice!
assert spy.call_count == 4
# Add should not be invoked
assert add_spy.call_count == 2
def test_runnable_branch_invoke() -> None:
# Test with single branch
def raise_value_error(x: int) -> int:
"""Raise a value error."""
raise ValueError("x is too large")
branch = RunnableBranch[int, int](
(lambda x: x > 100, raise_value_error),
# mypy cannot infer types from the lambda
(lambda x: x > 0 and x < 5, lambda x: x + 1), # type: ignore[misc]
(lambda x: x > 5, lambda x: x * 10),
lambda x: x - 1,
)
assert branch.invoke(1) == 2
assert branch.invoke(10) == 100
assert branch.invoke(0) == -1
# Should raise an exception
with pytest.raises(ValueError):
branch.invoke(1000)
def test_runnable_branch_batch() -> None:
"""Test batch variant."""
# Test with single branch
branch = RunnableBranch[int, int](
(lambda x: x > 0 and x < 5, lambda x: x + 1),
(lambda x: x > 5, lambda x: x * 10),
lambda x: x - 1,
)
assert branch.batch([1, 10, 0]) == [2, 100, -1]
@pytest.mark.asyncio
async def test_runnable_branch_ainvoke() -> None:
"""Test async variant of invoke."""
branch = RunnableBranch[int, int](
(lambda x: x > 0 and x < 5, lambda x: x + 1),
(lambda x: x > 5, lambda x: x * 10),
lambda x: x - 1,
)
assert await branch.ainvoke(1) == 2
assert await branch.ainvoke(10) == 100
assert await branch.ainvoke(0) == -1
# Verify that the async variant is used if available
async def condition(x: int) -> bool:
return x > 0
async def add(x: int) -> int:
return x + 1
async def sub(x: int) -> int:
return x - 1
branch = RunnableBranch[int, int]((condition, add), sub)
assert await branch.ainvoke(1) == 2
assert await branch.ainvoke(-10) == -11
def test_runnable_branch_invoke_callbacks() -> None:
"""Verify that callbacks are correctly used in invoke."""
tracer = FakeTracer()
def raise_value_error(x: int) -> int:
"""Raise a value error."""
raise ValueError("x is too large")
branch = RunnableBranch[int, int](
(lambda x: x > 100, raise_value_error),
lambda x: x - 1,
)
assert branch.invoke(1, config={"callbacks": [tracer]}) == 0
assert len(tracer.runs) == 1
assert tracer.runs[0].error is None
assert tracer.runs[0].outputs == {"output": 0}
# Check that the chain on end is invoked
with pytest.raises(ValueError):
branch.invoke(1000, config={"callbacks": [tracer]})
assert len(tracer.runs) == 2
assert tracer.runs[1].error == "ValueError('x is too large')"
assert tracer.runs[1].outputs is None
@pytest.mark.asyncio
async def test_runnable_branch_ainvoke_callbacks() -> None:
"""Verify that callbacks are invoked correctly in ainvoke."""
tracer = FakeTracer()
async def raise_value_error(x: int) -> int:
"""Raise a value error."""
raise ValueError("x is too large")
branch = RunnableBranch[int, int](
(lambda x: x > 100, raise_value_error),
lambda x: x - 1,
)
assert await branch.ainvoke(1, config={"callbacks": [tracer]}) == 0
assert len(tracer.runs) == 1
assert tracer.runs[0].error is None
assert tracer.runs[0].outputs == {"output": 0}
# Check that the chain on end is invoked
with pytest.raises(ValueError):
await branch.ainvoke(1000, config={"callbacks": [tracer]})
assert len(tracer.runs) == 2
assert tracer.runs[1].error == "ValueError('x is too large')"
assert tracer.runs[1].outputs is None
@pytest.mark.asyncio
async def test_runnable_branch_abatch() -> None:
"""Test async variant of invoke."""
branch = RunnableBranch[int, int](
(lambda x: x > 0 and x < 5, lambda x: x + 1),
(lambda x: x > 5, lambda x: x * 10),
lambda x: x - 1,
)
assert await branch.abatch([1, 10, 0]) == [2, 100, -1]

Loading…
Cancel
Save