Fix runnable branch callbacks (#11091)

We aren't calling on_chain_end here unless we use the default option
pull/11064/head^2
William FH 11 months ago committed by GitHub
parent 5514ebe859
commit 75b3893daf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -867,20 +867,21 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
)
if expression_value:
return runnable.invoke(
output = runnable.invoke(
input,
config=patch_config(
config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
),
)
output = self.default.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
)
break
else:
output = self.default.invoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
)
except Exception as e:
run_manager.on_chain_error(e)
raise
@ -911,7 +912,7 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
)
if expression_value:
return await runnable.ainvoke(
output = await runnable.ainvoke(
input,
config=patch_config(
config,
@ -919,14 +920,15 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
),
**kwargs,
)
output = await self.default.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
break
else:
output = await self.default.ainvoke(
input,
config=patch_config(
config, callbacks=run_manager.get_child(tag="branch:default")
),
**kwargs,
)
except Exception as e:
run_manager.on_chain_error(e)
raise

Loading…
Cancel
Save