diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index b899f33d81..1c0a2b7882 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -1,50 +1,51 @@ -from typing import Any, List, Optional, Union -from langchain.schema.runnable.base import Input, Output, Runnable, RunnableBinding -from langchain.schema.runnable.config import RunnableConfig, patch_config +from typing import Any, Dict, List, Optional, Union + from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying +from langchain.schema.runnable.base import Input, Output, RunnableBinding +from langchain.schema.runnable.config import RunnableConfig, patch_config + class RunnableRetry(RunnableBinding[Input, Output]): """Retry a Runnable if it fails.""" retry: BaseRetrying + def _kwargs_retrying(self) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {} + if self.retry.sleep is not None: + kwargs["sleep"] = self.retry.sleep + if self.retry.stop is not None: + kwargs["stop"] = self.retry.stop + if self.retry.wait is not None: + kwargs["wait"] = self.retry.wait + if self.retry.retry is not None: + kwargs["retry"] = self.retry.retry + if self.retry.before is not None: + kwargs["before"] = self.retry.before + if self.retry.after is not None: + kwargs["after"] = self.retry.after + if self.retry.before_sleep is not None: + kwargs["before_sleep"] = self.retry.before_sleep + if self.retry.reraise is not None: + kwargs["reraise"] = self.retry.reraise + if self.retry.retry_error_cls is not None: + kwargs["retry_error_cls"] = self.retry.retry_error_cls + if self.retry.retry_error_callback is not None: + kwargs["retry_error_callback"] = self.retry.retry_error_callback + return kwargs + def _sync_retrying(self) -> Retrying: - return Retrying( - sleep=self.retry.sleep, - stop=self.retry.stop, - wait=self.retry.wait, - retry=self.retry.retry, - before=self.retry.before, - after=self.retry.after, - before_sleep=self.retry.before_sleep, - reraise=self.retry.reraise, - retry_error_cls=self.retry.retry_error_cls, - retry_error_callback=self.retry.retry_error_callback, - ) + return Retrying(**self._kwargs_retrying()) def _async_retrying(self) -> AsyncRetrying: - return AsyncRetrying( - sleep=self.retry.sleep, - stop=self.retry.stop, - wait=self.retry.wait, - retry=self.retry.retry, - before=self.retry.before, - after=self.retry.after, - before_sleep=self.retry.before_sleep, - reraise=self.retry.reraise, - retry_error_cls=self.retry.retry_error_cls, - retry_error_callback=self.retry.retry_error_callback, - ) + return AsyncRetrying(**self._kwargs_retrying()) def _patch_config( self, - config: Optional[Union[RunnableConfig, List[RunnableConfig]]], + config: Optional[RunnableConfig], retry_state: RetryCallState, ) -> RunnableConfig: - if isinstance(config, list): - return [self._patch_config(c, retry_state) for c in config] - config = config or {} original_tags = config.get("tags") or [] return patch_config( @@ -53,6 +54,16 @@ class RunnableRetry(RunnableBinding[Input, Output]): + ["retry:attempt:{}".format(retry_state.attempt_number)], ) + def _patch_config_list( + self, + config: Optional[Union[RunnableConfig, List[RunnableConfig]]], + retry_state: RetryCallState, + ) -> Union[RunnableConfig, List[RunnableConfig]]: + if isinstance(config, list): + return [self._patch_config(c, retry_state) for c in config] + + return self._patch_config(config, retry_state) + def invoke( self, input: Input, @@ -64,7 +75,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): result = super().invoke( input, self._patch_config(config, attempt.retry_state), **kwargs ) - if not attempt.retry_state.outcome.failed: + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) return result @@ -76,7 +87,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): result = await super().ainvoke( input, self._patch_config(config, attempt.retry_state), **kwargs ) - if not attempt.retry_state.outcome.failed: + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) return result @@ -89,9 +100,11 @@ class RunnableRetry(RunnableBinding[Input, Output]): for attempt in self._sync_retrying(): with attempt: result = super().batch( - inputs, self._patch_config(config, attempt.retry_state), **kwargs + inputs, + self._patch_config_list(config, attempt.retry_state), + **kwargs ) - if not attempt.retry_state.outcome.failed: + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) return result @@ -104,8 +117,13 @@ class RunnableRetry(RunnableBinding[Input, Output]): async for attempt in self._async_retrying(): with attempt: result = await super().abatch( - inputs, self._patch_config(config, attempt.retry_state), **kwargs + inputs, + self._patch_config_list(config, attempt.retry_state), + **kwargs ) - if not attempt.retry_state.outcome.failed: + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) return result + + # stream() and transform() are not retried because retrying a stream + # is not very intuitive. diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 10a1ac00e3..eca512d0f6 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -6,6 +6,7 @@ import pytest from freezegun import freeze_time from pytest_mock import MockerFixture from syrupy import SnapshotAssertion +from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt from langchain import PromptTemplate from langchain.callbacks.manager import Callbacks