Pass kwargs in runnable retry (#11324)

pull/11326/head
Nuno Campos 11 months ago committed by GitHub
parent 8a507154ca
commit 0aedbcf7b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -89,12 +89,14 @@ class RunnableRetry(RunnableBinding[Input, Output]):
input: Input, input: Input,
run_manager: "CallbackManagerForChainRun", run_manager: "CallbackManagerForChainRun",
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any
) -> Output: ) -> Output:
for attempt in self._sync_retrying(reraise=True): for attempt in self._sync_retrying(reraise=True):
with attempt: with attempt:
result = super().invoke( result = super().invoke(
input, input,
self._patch_config(config, run_manager, attempt.retry_state), self._patch_config(config, run_manager, attempt.retry_state),
**kwargs,
) )
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result) attempt.retry_state.set_result(result)
@ -110,12 +112,14 @@ class RunnableRetry(RunnableBinding[Input, Output]):
input: Input, input: Input,
run_manager: "AsyncCallbackManagerForChainRun", run_manager: "AsyncCallbackManagerForChainRun",
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any
) -> Output: ) -> Output:
async for attempt in self._async_retrying(reraise=True): async for attempt in self._async_retrying(reraise=True):
with attempt: with attempt:
result = await super().ainvoke( result = await super().ainvoke(
input, input,
self._patch_config(config, run_manager, attempt.retry_state), self._patch_config(config, run_manager, attempt.retry_state),
**kwargs,
) )
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result) attempt.retry_state.set_result(result)
@ -131,6 +135,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
inputs: List[Input], inputs: List[Input],
run_manager: List["CallbackManagerForChainRun"], run_manager: List["CallbackManagerForChainRun"],
config: List[RunnableConfig], config: List[RunnableConfig],
**kwargs: Any
) -> List[Union[Output, Exception]]: ) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {} results_map: Dict[int, Output] = {}
@ -147,6 +152,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
pending(config), pending(run_manager), attempt.retry_state pending(config), pending(run_manager), attempt.retry_state
), ),
return_exceptions=True, return_exceptions=True,
**kwargs,
) )
# Register the results of the inputs that have succeeded. # Register the results of the inputs that have succeeded.
first_exception = None first_exception = None
@ -195,6 +201,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
inputs: List[Input], inputs: List[Input],
run_manager: List["AsyncCallbackManagerForChainRun"], run_manager: List["AsyncCallbackManagerForChainRun"],
config: List[RunnableConfig], config: List[RunnableConfig],
**kwargs: Any
) -> List[Union[Output, Exception]]: ) -> List[Union[Output, Exception]]:
results_map: Dict[int, Output] = {} results_map: Dict[int, Output] = {}
@ -211,6 +218,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
pending(config), pending(run_manager), attempt.retry_state pending(config), pending(run_manager), attempt.retry_state
), ),
return_exceptions=True, return_exceptions=True,
**kwargs,
) )
# Register the results of the inputs that have succeeded. # Register the results of the inputs that have succeeded.
first_exception = None first_exception = None

Loading…
Cancel
Save