diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index f428dcdf1d..fb000ae1e7 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -263,9 +263,9 @@ class Runnable(Generic[Input, Output], ABC): bound=self, kwargs={}, config={}, - retry_if_exception_type=retry_if_exception_type, + retry_exception_types=retry_if_exception_type, wait_exponential_jitter=wait_exponential_jitter, - stop_after_attempt=stop_after_attempt, + max_attempt_number=stop_after_attempt, ) def map(self) -> Runnable[List[Input], List[Output]]: @@ -388,6 +388,7 @@ class Runnable(Generic[Input, Output], ABC): ], input: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, return_exceptions: bool = False, run_type: Optional[str] = None, ) -> List[Output]: @@ -456,6 +457,7 @@ class Runnable(Generic[Input, Output], ABC): ], input: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, return_exceptions: bool = False, run_type: Optional[str] = None, ) -> List[Output]: @@ -1057,9 +1059,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # If an input has failed it will be present in this map, # and the value will be the exception that was raised. failed_inputs_map: Dict[int, Exception] = {} - stepidx = -1 for step in self.steps: - stepidx += 1 # Assemble the original indexes of the remaining inputs # (i.e. the ones that haven't failed yet) remaining_idxs = [ @@ -1176,9 +1176,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # If an input has failed it will be present in this map, # and the value will be the exception that was raised. failed_inputs_map: Dict[int, Exception] = {} - stepidx = -1 for step in self.steps: - stepidx += 1 # Assemble the original indexes of the remaining inputs # (i.e. the ones that haven't failed yet) remaining_idxs = [ diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index cda41605aa..37de03f600 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -24,31 +24,32 @@ U = TypeVar("U") class RunnableRetry(RunnableBinding[Input, Output]): """Retry a Runnable if it fails.""" - retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,) + retry_exception_types: Tuple[Type[BaseException]] = (Exception,) wait_exponential_jitter: bool = True - stop_after_attempt: int = 3 + max_attempt_number: int = 3 + @property def _kwargs_retrying(self) -> Dict[str, Any]: kwargs: Dict[str, Any] = dict() - if self.stop_after_attempt: - kwargs["stop"] = stop_after_attempt(self.stop_after_attempt) + if self.max_attempt_number: + kwargs["stop"] = stop_after_attempt(self.max_attempt_number) if self.wait_exponential_jitter: kwargs["wait"] = wait_exponential_jitter() - if self.retry_if_exception_type: - kwargs["retry"] = retry_if_exception_type(self.retry_if_exception_type) + if self.retry_exception_types: + kwargs["retry"] = retry_if_exception_type(self.retry_exception_types) return kwargs def _sync_retrying(self, **kwargs: Any) -> Retrying: - return Retrying(**self._kwargs_retrying(), **kwargs) + return Retrying(**self._kwargs_retrying, **kwargs) def _async_retrying(self, **kwargs: Any) -> AsyncRetrying: - return AsyncRetrying(**self._kwargs_retrying(), **kwargs) + return AsyncRetrying(**self._kwargs_retrying, **kwargs) def _patch_config( self, @@ -56,15 +57,9 @@ class RunnableRetry(RunnableBinding[Input, Output]): run_manager: T, retry_state: RetryCallState, ) -> RunnableConfig: - config = config or {} - return patch_config( - config, - callbacks=run_manager.get_child( - "retry:attempt:{}".format(retry_state.attempt_number) - if retry_state.attempt_number > 1 - else None - ), - ) + attempt = retry_state.attempt_number + tag = "retry:attempt:{}".format(attempt) if attempt > 1 else None + return patch_config(config, callbacks=run_manager.get_child(tag)) def _patch_config_list( self,