diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index b981affa0b..e76e16892a 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -867,20 +867,21 @@ class RunnableBranch(Serializable, Runnable[Input, Output]): ) if expression_value: - return runnable.invoke( + output = runnable.invoke( input, config=patch_config( config, callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), ), ) - - output = self.default.invoke( - input, - config=patch_config( - config, callbacks=run_manager.get_child(tag="branch:default") - ), - ) + break + else: + output = self.default.invoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag="branch:default") + ), + ) except Exception as e: run_manager.on_chain_error(e) raise @@ -911,7 +912,7 @@ class RunnableBranch(Serializable, Runnable[Input, Output]): ) if expression_value: - return await runnable.ainvoke( + output = await runnable.ainvoke( input, config=patch_config( config, @@ -919,14 +920,15 @@ class RunnableBranch(Serializable, Runnable[Input, Output]): ), **kwargs, ) - - output = await self.default.ainvoke( - input, - config=patch_config( - config, callbacks=run_manager.get_child(tag="branch:default") - ), - **kwargs, - ) + break + else: + output = await self.default.ainvoke( + input, + config=patch_config( + config, callbacks=run_manager.get_child(tag="branch:default") + ), + **kwargs, + ) except Exception as e: run_manager.on_chain_error(e) raise