Fix runnable branch callbacks (#11091)

We aren't calling on_chain_end here unless we use the default option
pull/11064/head^2
William FH 12 months ago committed by GitHub
parent 5514ebe859
commit 75b3893daf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -867,20 +867,21 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
) )
if expression_value: if expression_value:
return runnable.invoke( output = runnable.invoke(
input, input,
config=patch_config( config=patch_config(
config, config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
), ),
) )
break
output = self.default.invoke( else:
input, output = self.default.invoke(
config=patch_config( input,
config, callbacks=run_manager.get_child(tag="branch:default") config=patch_config(
), config, callbacks=run_manager.get_child(tag="branch:default")
) ),
)
except Exception as e: except Exception as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise raise
@ -911,7 +912,7 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
) )
if expression_value: if expression_value:
return await runnable.ainvoke( output = await runnable.ainvoke(
input, input,
config=patch_config( config=patch_config(
config, config,
@ -919,14 +920,15 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
), ),
**kwargs, **kwargs,
) )
break
output = await self.default.ainvoke( else:
input, output = await self.default.ainvoke(
config=patch_config( input,
config, callbacks=run_manager.get_child(tag="branch:default") config=patch_config(
), config, callbacks=run_manager.get_child(tag="branch:default")
**kwargs, ),
) **kwargs,
)
except Exception as e: except Exception as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise raise

Loading…
Cancel
Save