diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 88572bfee1..7b18062846 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -27,6 +27,8 @@ from typing import ( cast, ) +from tenacity import BaseRetrying + if TYPE_CHECKING: from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, @@ -226,6 +228,14 @@ class Runnable(Generic[Input, Output], ABC): bound=self, config={**(config or {}), **kwargs}, kwargs={} ) + def with_retry( + self, + retry: BaseRetrying, + ) -> Runnable[Input, Output]: + from langchain.schema.runnable.retry import RunnableRetry + + return RunnableRetry(bound=self, retry=retry, kwargs={}, config={}) + def map(self) -> Runnable[List[Input], List[Output]]: """ Return a new Runnable that maps a list of inputs to a list of outputs, diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 3f87f04403..5752b09bf2 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -98,6 +98,7 @@ def patch_config( recursion_limit: Optional[int] = None, max_concurrency: Optional[int] = None, run_name: Optional[str] = None, + tags: Optional[List[str]] = None, ) -> RunnableConfig: config = ensure_config(config) if deep_copy_locals: @@ -114,6 +115,8 @@ def patch_config( config["max_concurrency"] = max_concurrency if run_name is not None: config["run_name"] = run_name + if tags is not None: + config["tags"] = tags return config diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py new file mode 100644 index 0000000000..b899f33d81 --- /dev/null +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -0,0 +1,111 @@ +from typing import Any, List, Optional, Union +from langchain.schema.runnable.base import Input, Output, Runnable, RunnableBinding +from langchain.schema.runnable.config import RunnableConfig, patch_config +from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying + + +class RunnableRetry(RunnableBinding[Input, Output]): + """Retry a Runnable if it fails.""" + + retry: BaseRetrying + + def _sync_retrying(self) -> Retrying: + return Retrying( + sleep=self.retry.sleep, + stop=self.retry.stop, + wait=self.retry.wait, + retry=self.retry.retry, + before=self.retry.before, + after=self.retry.after, + before_sleep=self.retry.before_sleep, + reraise=self.retry.reraise, + retry_error_cls=self.retry.retry_error_cls, + retry_error_callback=self.retry.retry_error_callback, + ) + + def _async_retrying(self) -> AsyncRetrying: + return AsyncRetrying( + sleep=self.retry.sleep, + stop=self.retry.stop, + wait=self.retry.wait, + retry=self.retry.retry, + before=self.retry.before, + after=self.retry.after, + before_sleep=self.retry.before_sleep, + reraise=self.retry.reraise, + retry_error_cls=self.retry.retry_error_cls, + retry_error_callback=self.retry.retry_error_callback, + ) + + def _patch_config( + self, + config: Optional[Union[RunnableConfig, List[RunnableConfig]]], + retry_state: RetryCallState, + ) -> RunnableConfig: + if isinstance(config, list): + return [self._patch_config(c, retry_state) for c in config] + + config = config or {} + original_tags = config.get("tags") or [] + return patch_config( + config, + tags=original_tags + + ["retry:attempt:{}".format(retry_state.attempt_number)], + ) + + def invoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any | None + ) -> Output: + for attempt in self._sync_retrying(): + with attempt: + result = super().invoke( + input, self._patch_config(config, attempt.retry_state), **kwargs + ) + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result + + async def ainvoke( + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any | None + ) -> Output: + async for attempt in self._async_retrying(): + with attempt: + result = await super().ainvoke( + input, self._patch_config(config, attempt.retry_state), **kwargs + ) + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + **kwargs: Any + ) -> List[Output]: + for attempt in self._sync_retrying(): + with attempt: + result = super().batch( + inputs, self._patch_config(config, attempt.retry_state), **kwargs + ) + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + **kwargs: Any + ) -> List[Output]: + async for attempt in self._async_retrying(): + with attempt: + result = await super().abatch( + inputs, self._patch_config(config, attempt.retry_state), **kwargs + ) + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 412fa8e1e7..10a1ac00e3 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -41,6 +41,7 @@ from langchain.schema.runnable import ( RunnableSequence, RunnableWithFallbacks, ) +from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt class FakeTracer(BaseTracer): @@ -141,7 +142,7 @@ async def test_with_config(mocker: MockerFixture) -> None: else: assert call.args[2].get("tags") == ["b-tag"] assert call.args[2].get("max_concurrency") == 5 - spy_seq_step.reset_mock() + mocker.stop(spy_seq_step) assert [ *fake.with_config(tags=["a-tag"]).stream( @@ -1423,3 +1424,42 @@ def test_recursive_lambda() -> None: with pytest.raises(RecursionError): runnable.invoke(0, {"recursion_limit": 9}) + + +def test_retrying(mocker: MockerFixture) -> None: + def _lambda(x: int) -> Union[int, Runnable]: + if x == 1: + raise ValueError("x is 1") + elif x == 2: + raise RuntimeError("x is 2") + else: + return x + + _lambda_mock = mocker.Mock(side_effect=_lambda) + runnable = RunnableLambda(_lambda_mock) + + with pytest.raises(ValueError): + runnable.invoke(1) + + assert _lambda_mock.call_count == 1 + _lambda_mock.reset_mock() + + with pytest.raises(RetryError): + runnable.with_retry( + Retrying( + stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,)) + ) + ).invoke(1) + + assert _lambda_mock.call_count == 2 + _lambda_mock.reset_mock() + + with pytest.raises(RuntimeError): + runnable.with_retry( + Retrying( + stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,)) + ) + ).invoke(2) + + assert _lambda_mock.call_count == 1 + _lambda_mock.reset_mock()