diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index e8a18eb219..51ccc58f9f 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -254,7 +254,7 @@ class Runnable(Generic[Input, Output], ABC): def with_retry( self, *, - retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,), + retry_if_exception_type: Tuple[Type[BaseException], ...] = (Exception,), wait_exponential_jitter: bool = True, stop_after_attempt: int = 3, ) -> Runnable[Input, Output]: @@ -280,7 +280,7 @@ class Runnable(Generic[Input, Output], ABC): self, fallbacks: Sequence[Runnable[Input, Output]], *, - exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,), + exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,), ) -> RunnableWithFallbacks[Input, Output]: return RunnableWithFallbacks( runnable=self, @@ -653,7 +653,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): runnable: Runnable[Input, Output] fallbacks: Sequence[Runnable[Input, Output]] - exceptions_to_handle: Tuple[Type[BaseException]] = (Exception,) + exceptions_to_handle: Tuple[Type[BaseException], ...] = (Exception,) class Config: arbitrary_types_allowed = True diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index 37de03f600..b41f74583b 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -24,7 +24,7 @@ U = TypeVar("U") class RunnableRetry(RunnableBinding[Input, Output]): """Retry a Runnable if it fails.""" - retry_exception_types: Tuple[Type[BaseException]] = (Exception,) + retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,) wait_exponential_jitter: bool = True diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 8dd871ee4e..98bf284fd0 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1507,7 +1507,7 @@ async def test_async_retrying(mocker: MockerFixture) -> None: with pytest.raises(ValueError): await runnable.with_retry( stop_after_attempt=2, - retry_if_exception_type=(ValueError,), + retry_if_exception_type=(ValueError, KeyError), ).ainvoke(1) assert _lambda_mock.call_count == 2 # retried