From 1b0eebe1e3fa792f4187b9e0f5aebf03cea89f95 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 8 Sep 2023 09:07:15 -0700 Subject: [PATCH] Support multiple errors (#10376) in on_retry --- libs/langchain/langchain/schema/runnable/base.py | 6 +++--- libs/langchain/langchain/schema/runnable/retry.py | 2 +- .../tests/unit_tests/schema/runnable/test_runnable.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index e8a18eb219..51ccc58f9f 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -254,7 +254,7 @@ class Runnable(Generic[Input, Output], ABC): def with_retry( self, *, - retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,), + retry_if_exception_type: Tuple[Type[BaseException], ...] = (Exception,), wait_exponential_jitter: bool = True, stop_after_attempt: int = 3, ) -> Runnable[Input, Output]: @@ -280,7 +280,7 @@ class Runnable(Generic[Input, Output], ABC): self, fallbacks: Sequence[Runnable[Input, Output]], *, - exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,), + exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,), ) -> RunnableWithFallbacks[Input, Output]: return RunnableWithFallbacks( runnable=self, @@ -653,7 +653,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): runnable: Runnable[Input, Output] fallbacks: Sequence[Runnable[Input, Output]] - exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,) + exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,) class Config: arbitrary_types_allowed = True diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index 37de03f600..b41f74583b 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -24,7 +24,7 @@ U = TypeVar("U") class RunnableRetry(RunnableBinding[Input, Output]): """Retry a Runnable if it fails.""" - retry_exception_types: Tuple[Type[BaseException]] = (Exception,) + retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,) wait_exponential_jitter: bool = True diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 8dd871ee4e..98bf284fd0 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1507,7 +1507,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None: with pytest.raises(ValueError): await runnable.with_retry( stop_after_attempt=2, - retry_if_exception_type=(ValueError,), + retry_if_exception_type=(ValueError, KeyError), ).ainvoke(1) assert _lambda_mock.call_count == 2 # retried