diff --git a/libs/langchain/langchain/schema/runnable/__init__.py b/libs/langchain/langchain/schema/runnable/__init__.py index 24b235d4d8..2b068d5eba 100644 --- a/libs/langchain/langchain/schema/runnable/__init__.py +++ b/libs/langchain/langchain/schema/runnable/__init__.py @@ -2,6 +2,7 @@ from langchain.schema.runnable._locals import GetLocalVar, PutLocalVar from langchain.schema.runnable.base import ( Runnable, RunnableBinding, + RunnableBranch, RunnableLambda, RunnableMap, RunnableSequence, @@ -12,16 +13,17 @@ from langchain.schema.runnable.passthrough import RunnablePassthrough from langchain.schema.runnable.router import RouterInput, RouterRunnable __all__ = [ - "patch_config", "GetLocalVar", + "patch_config", "PutLocalVar", "RouterInput", "RouterRunnable", "Runnable", "RunnableBinding", + "RunnableBranch", "RunnableConfig", - "RunnableMap", "RunnableLambda", + "RunnableMap", "RunnablePassthrough", "RunnableSequence", "RunnableWithFallbacks", diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 3c879ef8c5..4650c2dab6 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -658,6 +658,188 @@ class Runnable(Generic[Input, Output], ABC): await run_manager.on_chain_end(final_output, inputs=final_input) +class RunnableBranch(Serializable, Runnable[Input, Output]): + """A Runnable that selects which branch to run based on a condition. + + The runnable is initialized with a list of (condition, runnable) pairs and + a default branch. + + When operating on an input, the first condition that evaluates to True is + selected, and the corresponding runnable is run on the input. + + If no condition evaluates to True, the default branch is run on the input. + + Examples: + + .. code-block:: python + + from langchain.schema.runnable import RunnableBranch + + branch = RunnableBranch( + (lambda x: isinstance(x, str), lambda x: x.upper()), + (lambda x: isinstance(x, int), lambda x: x + 1), + (lambda x: isinstance(x, float), lambda x: x * 2), + lambda x: "goodbye", + ) + + branch.invoke("hello") # "HELLO" + branch.invoke(None) # "goodbye" + """ + + branches: Sequence[Tuple[Runnable[Input, bool], Runnable[Input, Output]]] + default: Runnable[Input, Output] + + def __init__( + self, + *branches: Union[ + Tuple[ + Union[ + Runnable[Input, bool], + Callable[[Input], bool], + Callable[[Input], Awaitable[bool]], + ], + RunnableLike, + ], + RunnableLike, # To accommodate the default branch + ], + ) -> None: + """A Runnable that runs one of two branches based on a condition.""" + if len(branches) < 2: + raise ValueError("RunnableBranch requires at least two branches") + + default = branches[-1] + + if not isinstance( + default, (Runnable, Callable, Mapping) # type: ignore[arg-type] + ): + raise TypeError( + "RunnableBranch default must be runnable, callable or mapping." + ) + + default_ = cast( + Runnable[Input, Output], coerce_to_runnable(cast(RunnableLike, default)) + ) + + _branches = [] + + for branch in branches[:-1]: + if not isinstance(branch, (tuple, list)): # type: ignore[arg-type] + raise TypeError( + f"RunnableBranch branches must be " + f"tuples or lists, not {type(branch)}" + ) + + if not len(branch) == 2: + raise ValueError( + f"RunnableBranch branches must be " + f"tuples or lists of length 2, not {len(branch)}" + ) + condition, runnable = branch + condition = cast(Runnable[Input, bool], coerce_to_runnable(condition)) + runnable = coerce_to_runnable(runnable) + _branches.append((condition, runnable)) + + super().__init__(branches=_branches, default=default_) + + class Config: + arbitrary_types_allowed = True + + @property + def lc_serializable(self) -> bool: + """RunnableBranch is serializable if all its branches are serializable.""" + return True + + @property + def lc_namespace(self) -> List[str]: + """The namespace of a RunnableBranch is the namespace of its default branch.""" + return self.__class__.__module__.split(".")[:-1] + + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + """First evaluates the condition, then delegate to true or false branch.""" + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + + try: + for idx, branch in enumerate(self.branches): + condition, runnable = branch + + expression_value = condition.invoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag=f"condition:{idx}") + ), + ) + + if expression_value: + return runnable.invoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag=f"branch:{idx}") + ), + ) + + output = self.default.invoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag="branch:default") + ), + ) + except Exception as e: + run_manager.on_chain_error(e) + raise + run_manager.on_chain_end(dumpd(output)) + return output + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + """Async version of invoke.""" + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + try: + for idx, branch in enumerate(self.branches): + condition, runnable = branch + + expression_value = await condition.ainvoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag=f"condition:{idx}") + ), + ) + + if expression_value: + return await runnable.ainvoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag=f"branch:{idx}") + ), + **kwargs, + ) + + output = await self.default.ainvoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag="branch:default") + ), + **kwargs, + ) + except Exception as e: + run_manager.on_chain_error(e) + raise + run_manager.on_chain_end(dumpd(output)) + return output + + class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): """ A Runnable that can fallback to other Runnables if it fails. @@ -2007,14 +2189,15 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig) +RunnableLike = Union[ + Runnable[Input, Output], + Callable[[Input], Output], + Callable[[Input], Awaitable[Output]], + Mapping[str, Any], +] + -def coerce_to_runnable( - thing: Union[ - Runnable[Input, Output], - Callable[[Input], Output], - Mapping[str, Any], - ] -) -> Runnable[Input, Output]: +def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: if isinstance(thing, Runnable): return thing elif callable(thing): diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index a6b8db9ff2..45fbe1ca40 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,5 +1,5 @@ from operator import itemgetter -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Sequence, Union from uuid import UUID import pytest @@ -34,6 +34,7 @@ from langchain.schema.retriever import BaseRetriever from langchain.schema.runnable import ( RouterRunnable, Runnable, + RunnableBranch, RunnableConfig, RunnableLambda, RunnableMap, @@ -541,7 +542,7 @@ async def test_prompt_with_llm( mocker.stop(prompt_spy) mocker.stop(llm_spy) - # Test stream + # Test stream# prompt_spy = mocker.spy(prompt.__class__, "ainvoke") llm_spy = mocker.spy(llm.__class__, "astream") tracer = FakeTracer() @@ -1816,3 +1817,205 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None: assert parent_run_qux.outputs["output"] == "quxaaaa" assert len(parent_run_qux.child_runs) == 4 assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None] + + +def test_runnable_branch_init() -> None: + """Verify that runnable branch gets initialized properly.""" + add = RunnableLambda(lambda x: x + 1) + condition = RunnableLambda(lambda x: x > 0) + + # Test failure with less than 2 branches + with pytest.raises(ValueError): + RunnableBranch((condition, add)) + + # Test failure with less than 2 branches + with pytest.raises(ValueError): + RunnableBranch(condition) + + +@pytest.mark.parametrize( + "branches", + [ + [ + (RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)), + RunnableLambda(lambda x: x - 1), + ], + [ + (RunnableLambda(lambda x: x > 0), RunnableLambda(lambda x: x + 1)), + (RunnableLambda(lambda x: x > 5), RunnableLambda(lambda x: x + 1)), + RunnableLambda(lambda x: x - 1), + ], + [ + (lambda x: x > 0, lambda x: x + 1), + (lambda x: x > 5, lambda x: x + 1), + lambda x: x - 1, + ], + ], +) +def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None: + """Verify that runnable branch gets initialized properly.""" + runnable = RunnableBranch[int, int](*branches) + for branch in runnable.branches: + condition, body = branch + assert isinstance(condition, Runnable) + assert isinstance(body, Runnable) + + assert isinstance(runnable.default, Runnable) + + +def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None: + """Verify that runnables are invoked only when necessary.""" + # Test with single branch + add = RunnableLambda(lambda x: x + 1) + sub = RunnableLambda(lambda x: x - 1) + condition = RunnableLambda(lambda x: x > 0) + spy = mocker.spy(condition, "invoke") + add_spy = mocker.spy(add, "invoke") + + branch = RunnableBranch[int, int]((condition, add), (condition, add), sub) + assert spy.call_count == 0 + assert add_spy.call_count == 0 + + assert branch.invoke(1) == 2 + assert add_spy.call_count == 1 + assert spy.call_count == 1 + + assert branch.invoke(2) == 3 + assert spy.call_count == 2 + assert add_spy.call_count == 2 + + assert branch.invoke(-3) == -4 + # Should fall through to default branch with condition being evaluated twice! + assert spy.call_count == 4 + # Add should not be invoked + assert add_spy.call_count == 2 + + +def test_runnable_branch_invoke() -> None: + # Test with single branch + def raise_value_error(x: int) -> int: + """Raise a value error.""" + raise ValueError("x is too large") + + branch = RunnableBranch[int, int]( + (lambda x: x > 100, raise_value_error), + # mypy cannot infer types from the lambda + (lambda x: x > 0 and x < 5, lambda x: x + 1), # type: ignore[misc] + (lambda x: x > 5, lambda x: x * 10), + lambda x: x - 1, + ) + + assert branch.invoke(1) == 2 + assert branch.invoke(10) == 100 + assert branch.invoke(0) == -1 + # Should raise an exception + with pytest.raises(ValueError): + branch.invoke(1000) + + +def test_runnable_branch_batch() -> None: + """Test batch variant.""" + # Test with single branch + branch = RunnableBranch[int, int]( + (lambda x: x > 0 and x < 5, lambda x: x + 1), + (lambda x: x > 5, lambda x: x * 10), + lambda x: x - 1, + ) + + assert branch.batch([1, 10, 0]) == [2, 100, -1] + + +@pytest.mark.asyncio +async def test_runnable_branch_ainvoke() -> None: + """Test async variant of invoke.""" + branch = RunnableBranch[int, int]( + (lambda x: x > 0 and x < 5, lambda x: x + 1), + (lambda x: x > 5, lambda x: x * 10), + lambda x: x - 1, + ) + + assert await branch.ainvoke(1) == 2 + assert await branch.ainvoke(10) == 100 + assert await branch.ainvoke(0) == -1 + + # Verify that the async variant is used if available + async def condition(x: int) -> bool: + return x > 0 + + async def add(x: int) -> int: + return x + 1 + + async def sub(x: int) -> int: + return x - 1 + + branch = RunnableBranch[int, int]((condition, add), sub) + + assert await branch.ainvoke(1) == 2 + assert await branch.ainvoke(-10) == -11 + + +def test_runnable_branch_invoke_callbacks() -> None: + """Verify that callbacks are correctly used in invoke.""" + tracer = FakeTracer() + + def raise_value_error(x: int) -> int: + """Raise a value error.""" + raise ValueError("x is too large") + + branch = RunnableBranch[int, int]( + (lambda x: x > 100, raise_value_error), + lambda x: x - 1, + ) + + assert branch.invoke(1, config={"callbacks": [tracer]}) == 0 + assert len(tracer.runs) == 1 + assert tracer.runs[0].error is None + assert tracer.runs[0].outputs == {"output": 0} + + # Check that the chain on end is invoked + with pytest.raises(ValueError): + branch.invoke(1000, config={"callbacks": [tracer]}) + + assert len(tracer.runs) == 2 + assert tracer.runs[1].error == "ValueError('x is too large')" + assert tracer.runs[1].outputs is None + + +@pytest.mark.asyncio +async def test_runnable_branch_ainvoke_callbacks() -> None: + """Verify that callbacks are invoked correctly in ainvoke.""" + tracer = FakeTracer() + + async def raise_value_error(x: int) -> int: + """Raise a value error.""" + raise ValueError("x is too large") + + branch = RunnableBranch[int, int]( + (lambda x: x > 100, raise_value_error), + lambda x: x - 1, + ) + + assert await branch.ainvoke(1, config={"callbacks": [tracer]}) == 0 + assert len(tracer.runs) == 1 + assert tracer.runs[0].error is None + assert tracer.runs[0].outputs == {"output": 0} + + # Check that the chain on end is invoked + with pytest.raises(ValueError): + await branch.ainvoke(1000, config={"callbacks": [tracer]}) + + assert len(tracer.runs) == 2 + assert tracer.runs[1].error == "ValueError('x is too large')" + assert tracer.runs[1].outputs is None + + +@pytest.mark.asyncio +async def test_runnable_branch_abatch() -> None: + """Test async variant of invoke.""" + branch = RunnableBranch[int, int]( + (lambda x: x > 0 and x < 5, lambda x: x + 1), + (lambda x: x > 5, lambda x: x * 10), + lambda x: x - 1, + ) + + assert await branch.abatch([1, 10, 0]) == [2, 100, -1]