diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py index 80ff8a88a0..fa60a58f16 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -1,7 +1,9 @@ from typing import ( Any, + AsyncIterator, Awaitable, Callable, + Iterator, List, Mapping, Optional, @@ -23,6 +25,7 @@ from langchain_core.runnables.base import ( from langchain_core.runnables.config import ( RunnableConfig, ensure_config, + get_async_callback_manager_for_config, get_callback_manager_for_config, patch_config, ) @@ -212,7 +215,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): ), **kwargs, ) - except Exception as e: + except BaseException as e: run_manager.on_chain_error(e) raise run_manager.on_chain_end(dumpd(output)) @@ -223,8 +226,8 @@ class RunnableBranch(RunnableSerializable[Input, Output]): ) -> Output: """Async version of invoke.""" config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( + callback_manager = get_async_callback_manager_for_config(config) + run_manager = await callback_manager.on_chain_start( dumpd(self), input, name=config.get("run_name"), @@ -259,8 +262,156 @@ class RunnableBranch(RunnableSerializable[Input, Output]): ), **kwargs, ) - except Exception as e: + except BaseException as e: + await run_manager.on_chain_error(e) + raise + await run_manager.on_chain_end(dumpd(output)) + return output + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> Iterator[Output]: + """First evaluates the condition, + then delegate to true or false branch.""" + config = ensure_config(config) + callback_manager = get_callback_manager_for_config(config) + run_manager = callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + final_output: Optional[Output] = None + final_output_supported = True + + try: + for idx, branch in enumerate(self.branches): + condition, runnable = branch + + expression_value = condition.invoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), + ), + ) + + if expression_value: + for chunk in runnable.stream( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), + ), + **kwargs, + ): + yield chunk + if final_output_supported: + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = None + final_output_supported = False + break + else: + for chunk in self.default.stream( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag="branch:default"), + ), + **kwargs, + ): + yield chunk + if final_output_supported: + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = None + final_output_supported = False + except BaseException as e: run_manager.on_chain_error(e) raise - run_manager.on_chain_end(dumpd(output)) - return output + run_manager.on_chain_end(final_output) + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[Output]: + """First evaluates the condition, + then delegate to true or false branch.""" + config = ensure_config(config) + callback_manager = get_async_callback_manager_for_config(config) + run_manager = await callback_manager.on_chain_start( + dumpd(self), + input, + name=config.get("run_name"), + ) + final_output: Optional[Output] = None + final_output_supported = True + + try: + for idx, branch in enumerate(self.branches): + condition, runnable = branch + + expression_value = await condition.ainvoke( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"condition:{idx + 1}"), + ), + ) + + if expression_value: + async for chunk in runnable.astream( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), + ), + **kwargs, + ): + yield chunk + if final_output_supported: + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = None + final_output_supported = False + break + else: + async for chunk in self.default.astream( + input, + config=patch_config( + config, + callbacks=run_manager.get_child(tag="branch:default"), + ), + **kwargs, + ): + yield chunk + if final_output_supported: + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = None + final_output_supported = False + except BaseException as e: + await run_manager.on_chain_error(e) + raise + await run_manager.on_chain_end(final_output) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 7125b5a37b..345a212920 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -3981,6 +3981,140 @@ async def test_runnable_branch_abatch() -> None: assert await branch.abatch([1, 10, 0]) == [2, 100, -1] +def test_runnable_branch_stream() -> None: + """Verify that stream works for RunnableBranch.""" + + llm_res = "i'm a textbot" + # sleep to better simulate a real stream + llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) + + branch = RunnableBranch[str, Any]( + (lambda x: x == "hello", llm), + lambda x: x, + ) + + assert list(branch.stream("hello")) == list(llm_res) + assert list(branch.stream("bye")) == ["bye"] + + +def test_runnable_branch_stream_with_callbacks() -> None: + """Verify that stream works for RunnableBranch when using callbacks.""" + tracer = FakeTracer() + + def raise_value_error(x: str) -> Any: + """Raise a value error.""" + raise ValueError(f"x is {x}") + + llm_res = "i'm a textbot" + # sleep to better simulate a real stream + llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) + + branch = RunnableBranch[str, Any]( + (lambda x: x == "error", raise_value_error), + (lambda x: x == "hello", llm), + lambda x: x, + ) + config: RunnableConfig = {"callbacks": [tracer]} + + assert list(branch.stream("hello", config=config)) == list(llm_res) + + assert len(tracer.runs) == 1 + assert tracer.runs[0].error is None + assert tracer.runs[0].outputs == {"output": llm_res} + + # Verify that the chain on error is invoked + with pytest.raises(ValueError): + for _ in branch.stream("error", config=config): + pass + + assert len(tracer.runs) == 2 + assert "ValueError('x is error')" in str(tracer.runs[1].error) + assert tracer.runs[1].outputs is None + + assert list(branch.stream("bye", config=config)) == ["bye"] + + assert len(tracer.runs) == 3 + assert tracer.runs[2].error is None + assert tracer.runs[2].outputs == {"output": "bye"} + + +async def test_runnable_branch_astream() -> None: + """Verify that astream works for RunnableBranch.""" + + llm_res = "i'm a textbot" + # sleep to better simulate a real stream + llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) + + branch = RunnableBranch[str, Any]( + (lambda x: x == "hello", llm), + lambda x: x, + ) + + assert [_ async for _ in branch.astream("hello")] == list(llm_res) + assert [_ async for _ in branch.astream("bye")] == ["bye"] + + # Verify that the async variant is used if available + async def condition(x: str) -> bool: + return x == "hello" + + async def repeat(x: str) -> str: + return x + x + + async def reverse(x: str) -> str: + return x[::-1] + + branch = RunnableBranch[str, Any]((condition, repeat), llm) + + assert [_ async for _ in branch.astream("hello")] == ["hello" * 2] + assert [_ async for _ in branch.astream("bye")] == list(llm_res) + + branch = RunnableBranch[str, Any]((condition, llm), reverse) + + assert [_ async for _ in branch.astream("hello")] == list(llm_res) + assert [_ async for _ in branch.astream("bye")] == ["eyb"] + + +async def test_runnable_branch_astream_with_callbacks() -> None: + """Verify that astream works for RunnableBranch when using callbacks.""" + tracer = FakeTracer() + + def raise_value_error(x: str) -> Any: + """Raise a value error.""" + raise ValueError(f"x is {x}") + + llm_res = "i'm a textbot" + # sleep to better simulate a real stream + llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) + + branch = RunnableBranch[str, Any]( + (lambda x: x == "error", raise_value_error), + (lambda x: x == "hello", llm), + lambda x: x, + ) + config: RunnableConfig = {"callbacks": [tracer]} + + assert [_ async for _ in branch.astream("hello", config=config)] == list(llm_res) + + assert len(tracer.runs) == 1 + assert tracer.runs[0].error is None + assert tracer.runs[0].outputs == {"output": llm_res} + + # Verify that the chain on error is invoked + with pytest.raises(ValueError): + async for _ in branch.astream("error", config=config): + pass + + assert len(tracer.runs) == 2 + assert "ValueError('x is error')" in str(tracer.runs[1].error) + assert tracer.runs[1].outputs is None + + assert [_ async for _ in branch.astream("bye", config=config)] == ["bye"] + + assert len(tracer.runs) == 3 + assert tracer.runs[2].error is None + assert tracer.runs[2].outputs == {"output": "bye"} + + @pytest.mark.skipif( sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." )