From b2ac8354669fadda26662133f13a03943f7caacd Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 18:05:23 +0200 Subject: [PATCH 1/9] Add .with_retry() to Runnables --- .../langchain/schema/runnable/base.py | 10 ++ .../langchain/schema/runnable/config.py | 3 + .../langchain/schema/runnable/retry.py | 111 ++++++++++++++++++ .../schema/runnable/test_runnable.py | 42 ++++++- 4 files changed, 165 insertions(+), 1 deletion(-) create mode 100644 libs/langchain/langchain/schema/runnable/retry.py diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 88572bfee1..7b18062846 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -27,6 +27,8 @@ from typing import ( cast, ) +from tenacity import BaseRetrying + if TYPE_CHECKING: from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, @@ -226,6 +228,14 @@ class Runnable(Generic[Input, Output], ABC): bound=self, config={**(config or {}), **kwargs}, kwargs={} ) + def with_retry( + self, + retry: BaseRetrying, + ) -> Runnable[Input, Output]: + from langchain.schema.runnable.retry import RunnableRetry + + return RunnableRetry(bound=self, retry=retry, kwargs={}, config={}) + def map(self) -> Runnable[List[Input], List[Output]]: """ Return a new Runnable that maps a list of inputs to a list of outputs, diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 3f87f04403..5752b09bf2 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -98,6 +98,7 @@ def patch_config( recursion_limit: Optional[int] = None, max_concurrency: Optional[int] = None, run_name: Optional[str] = None, + tags: Optional[List[str]] = None, ) -> RunnableConfig: config = ensure_config(config) if deep_copy_locals: @@ -114,6 +115,8 @@ def patch_config( config["max_concurrency"] = max_concurrency if run_name is not None: config["run_name"] = run_name + if tags is not None: + config["tags"] = tags return config diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py new file mode 100644 index 0000000000..b899f33d81 --- /dev/null +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -0,0 +1,111 @@ +from typing import Any, List, Optional, Union +from langchain.schema.runnable.base import Input, Output, Runnable, RunnableBinding +from langchain.schema.runnable.config import RunnableConfig, patch_config +from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying + + +class RunnableRetry(RunnableBinding[Input, Output]): + """Retry a Runnable if it fails.""" + + retry: BaseRetrying + + def _sync_retrying(self) -> Retrying: + return Retrying( + sleep=self.retry.sleep, + stop=self.retry.stop, + wait=self.retry.wait, + retry=self.retry.retry, + before=self.retry.before, + after=self.retry.after, + before_sleep=self.retry.before_sleep, + reraise=self.retry.reraise, + retry_error_cls=self.retry.retry_error_cls, + retry_error_callback=self.retry.retry_error_callback, + ) + + def _async_retrying(self) -> AsyncRetrying: + return AsyncRetrying( + sleep=self.retry.sleep, + stop=self.retry.stop, + wait=self.retry.wait, + retry=self.retry.retry, + before=self.retry.before, + after=self.retry.after, + before_sleep=self.retry.before_sleep, + reraise=self.retry.reraise, + retry_error_cls=self.retry.retry_error_cls, + retry_error_callback=self.retry.retry_error_callback, + ) + + def _patch_config( + self, + config: Optional[Union[RunnableConfig, List[RunnableConfig]]], + retry_state: RetryCallState, + ) -> RunnableConfig: + if isinstance(config, list): + return [self._patch_config(c, retry_state) for c in config] + + config = config or {} + original_tags = config.get("tags") or [] + return patch_config( + config, + tags=original_tags + + ["retry:attempt:{}".format(retry_state.attempt_number)], + ) + + def invoke( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any | None + ) -> Output: + for attempt in self._sync_retrying(): + with attempt: + result = super().invoke( + input, self._patch_config(config, attempt.retry_state), **kwargs + ) + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result + + async def ainvoke( + self, input: Input, config: RunnableConfig | None = None, **kwargs: Any | None + ) -> Output: + async for attempt in self._async_retrying(): + with attempt: + result = await super().ainvoke( + input, self._patch_config(config, attempt.retry_state), **kwargs + ) + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result + + def batch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + **kwargs: Any + ) -> List[Output]: + for attempt in self._sync_retrying(): + with attempt: + result = super().batch( + inputs, self._patch_config(config, attempt.retry_state), **kwargs + ) + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result + + async def abatch( + self, + inputs: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + **kwargs: Any + ) -> List[Output]: + async for attempt in self._async_retrying(): + with attempt: + result = await super().abatch( + inputs, self._patch_config(config, attempt.retry_state), **kwargs + ) + if not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 412fa8e1e7..10a1ac00e3 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -41,6 +41,7 @@ from langchain.schema.runnable import ( RunnableSequence, RunnableWithFallbacks, ) +from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt class FakeTracer(BaseTracer): @@ -141,7 +142,7 @@ async def test_with_config(mocker: MockerFixture) -> None: else: assert call.args[2].get("tags") == ["b-tag"] assert call.args[2].get("max_concurrency") == 5 - spy_seq_step.reset_mock() + mocker.stop(spy_seq_step) assert [ *fake.with_config(tags=["a-tag"]).stream( @@ -1423,3 +1424,42 @@ def test_recursive_lambda() -> None: with pytest.raises(RecursionError): runnable.invoke(0, {"recursion_limit": 9}) + + +def test_retrying(mocker: MockerFixture) -> None: + def _lambda(x: int) -> Union[int, Runnable]: + if x == 1: + raise ValueError("x is 1") + elif x == 2: + raise RuntimeError("x is 2") + else: + return x + + _lambda_mock = mocker.Mock(side_effect=_lambda) + runnable = RunnableLambda(_lambda_mock) + + with pytest.raises(ValueError): + runnable.invoke(1) + + assert _lambda_mock.call_count == 1 + _lambda_mock.reset_mock() + + with pytest.raises(RetryError): + runnable.with_retry( + Retrying( + stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,)) + ) + ).invoke(1) + + assert _lambda_mock.call_count == 2 + _lambda_mock.reset_mock() + + with pytest.raises(RuntimeError): + runnable.with_retry( + Retrying( + stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,)) + ) + ).invoke(2) + + assert _lambda_mock.call_count == 1 + _lambda_mock.reset_mock() From 2242e2160fd72756da4bd2a9f987f3aa2cae3fe0 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 18:24:13 +0200 Subject: [PATCH 2/9] Lint --- .../langchain/schema/runnable/retry.py | 92 +++++++++++-------- .../schema/runnable/test_runnable.py | 1 + 2 files changed, 56 insertions(+), 37 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index b899f33d81..1c0a2b7882 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -1,50 +1,51 @@ -from typing import Any, List, Optional, Union -from langchain.schema.runnable.base import Input, Output, Runnable, RunnableBinding -from langchain.schema.runnable.config import RunnableConfig, patch_config +from typing import Any, Dict, List, Optional, Union + from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying +from langchain.schema.runnable.base import Input, Output, RunnableBinding +from langchain.schema.runnable.config import RunnableConfig, patch_config + class RunnableRetry(RunnableBinding[Input, Output]): """Retry a Runnable if it fails.""" retry: BaseRetrying + def _kwargs_retrying(self) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {} + if self.retry.sleep is not None: + kwargs["sleep"] = self.retry.sleep + if self.retry.stop is not None: + kwargs["stop"] = self.retry.stop + if self.retry.wait is not None: + kwargs["wait"] = self.retry.wait + if self.retry.retry is not None: + kwargs["retry"] = self.retry.retry + if self.retry.before is not None: + kwargs["before"] = self.retry.before + if self.retry.after is not None: + kwargs["after"] = self.retry.after + if self.retry.before_sleep is not None: + kwargs["before_sleep"] = self.retry.before_sleep + if self.retry.reraise is not None: + kwargs["reraise"] = self.retry.reraise + if self.retry.retry_error_cls is not None: + kwargs["retry_error_cls"] = self.retry.retry_error_cls + if self.retry.retry_error_callback is not None: + kwargs["retry_error_callback"] = self.retry.retry_error_callback + return kwargs + def _sync_retrying(self) -> Retrying: - return Retrying( - sleep=self.retry.sleep, - stop=self.retry.stop, - wait=self.retry.wait, - retry=self.retry.retry, - before=self.retry.before, - after=self.retry.after, - before_sleep=self.retry.before_sleep, - reraise=self.retry.reraise, - retry_error_cls=self.retry.retry_error_cls, - retry_error_callback=self.retry.retry_error_callback, - ) + return Retrying(**self._kwargs_retrying()) def _async_retrying(self) -> AsyncRetrying: - return AsyncRetrying( - sleep=self.retry.sleep, - stop=self.retry.stop, - wait=self.retry.wait, - retry=self.retry.retry, - before=self.retry.before, - after=self.retry.after, - before_sleep=self.retry.before_sleep, - reraise=self.retry.reraise, - retry_error_cls=self.retry.retry_error_cls, - retry_error_callback=self.retry.retry_error_callback, - ) + return AsyncRetrying(**self._kwargs_retrying()) def _patch_config( self, - config: Optional[Union[RunnableConfig, List[RunnableConfig]]], + config: Optional[RunnableConfig], retry_state: RetryCallState, ) -> RunnableConfig: - if isinstance(config, list): - return [self._patch_config(c, retry_state) for c in config] - config = config or {} original_tags = config.get("tags") or [] return patch_config( @@ -53,6 +54,16 @@ class RunnableRetry(RunnableBinding[Input, Output]): + ["retry:attempt:{}".format(retry_state.attempt_number)], ) + def _patch_config_list( + self, + config: Optional[Union[RunnableConfig, List[RunnableConfig]]], + retry_state: RetryCallState, + ) -> Union[RunnableConfig, List[RunnableConfig]]: + if isinstance(config, list): + return [self._patch_config(c, retry_state) for c in config] + + return self._patch_config(config, retry_state) + def invoke( self, input: Input, @@ -64,7 +75,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): result = super().invoke( input, self._patch_config(config, attempt.retry_state), **kwargs ) - if not attempt.retry_state.outcome.failed: + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) return result @@ -76,7 +87,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): result = await super().ainvoke( input, self._patch_config(config, attempt.retry_state), **kwargs ) - if not attempt.retry_state.outcome.failed: + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) return result @@ -89,9 +100,11 @@ class RunnableRetry(RunnableBinding[Input, Output]): for attempt in self._sync_retrying(): with attempt: result = super().batch( - inputs, self._patch_config(config, attempt.retry_state), **kwargs + inputs, + self._patch_config_list(config, attempt.retry_state), + **kwargs ) - if not attempt.retry_state.outcome.failed: + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) return result @@ -104,8 +117,13 @@ class RunnableRetry(RunnableBinding[Input, Output]): async for attempt in self._async_retrying(): with attempt: result = await super().abatch( - inputs, self._patch_config(config, attempt.retry_state), **kwargs + inputs, + self._patch_config_list(config, attempt.retry_state), + **kwargs ) - if not attempt.retry_state.outcome.failed: + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) return result + + # stream() and transform() are not retried because retrying a stream + # is not very intuitive. diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 10a1ac00e3..eca512d0f6 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -6,6 +6,7 @@ import pytest from freezegun import freeze_time from pytest_mock import MockerFixture from syrupy import SnapshotAssertion +from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt from langchain import PromptTemplate from langchain.callbacks.manager import Callbacks From 4eecf90f3346db73d5b52ab20e42699a2fefca79 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 18:30:36 +0200 Subject: [PATCH 3/9] Lint --- libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index eca512d0f6..ca22c6c984 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -42,7 +42,6 @@ from langchain.schema.runnable import ( RunnableSequence, RunnableWithFallbacks, ) -from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt class FakeTracer(BaseTracer): From 85088dc5df12a77339728dbbefa3a151757b97bc Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 24 Aug 2023 18:54:33 +0200 Subject: [PATCH 4/9] Lint --- libs/langchain/langchain/schema/runnable/retry.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index 1c0a2b7882..c53621904f 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -65,10 +65,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): return self._patch_config(config, retry_state) def invoke( - self, - input: Input, - config: Optional[RunnableConfig] = None, - **kwargs: Any | None + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: for attempt in self._sync_retrying(): with attempt: @@ -80,7 +77,7 @@ class RunnableRetry(RunnableBinding[Input, Output]): return result async def ainvoke( - self, input: Input, config: RunnableConfig | None = None, **kwargs: Any | None + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: async for attempt in self._async_retrying(): with attempt: From af2e4ce2cd063c1e8fc485a978dbd2e1ecbfe0c9 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 25 Aug 2023 17:14:40 +0200 Subject: [PATCH 5/9] Use a non-inheritable tag --- .../langchain/schema/runnable/base.py | 7 +++ .../langchain/schema/runnable/config.py | 3 - .../langchain/schema/runnable/retry.py | 59 +++++++++++++++---- 3 files changed, 54 insertions(+), 15 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7b18062846..ad5c8cfe84 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1555,6 +1555,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config={**self.config, **(config or {}), **kwargs}, ) + def with_retry(self, retry: BaseRetrying) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound.with_retry(retry), + kwargs=self.kwargs, + config=self.config, + ) + def invoke( self, input: Input, diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 5752b09bf2..3f87f04403 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -98,7 +98,6 @@ def patch_config( recursion_limit: Optional[int] = None, max_concurrency: Optional[int] = None, run_name: Optional[str] = None, - tags: Optional[List[str]] = None, ) -> RunnableConfig: config = ensure_config(config) if deep_copy_locals: @@ -115,8 +114,6 @@ def patch_config( config["max_concurrency"] = max_concurrency if run_name is not None: config["run_name"] = run_name - if tags is not None: - config["tags"] = tags return config diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index c53621904f..b746d87428 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -1,10 +1,20 @@ -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying from langchain.schema.runnable.base import Input, Output, RunnableBinding from langchain.schema.runnable.config import RunnableConfig, patch_config +if TYPE_CHECKING: + from langchain.callbacks.manager import ( + AsyncCallbackManager as AsyncCallbackManagerT, + CallbackManager as CallbackManagerT, + ) + + T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT) +else: + T = TypeVar("T") + class RunnableRetry(RunnableBinding[Input, Output]): """Retry a Runnable if it fails.""" @@ -45,32 +55,43 @@ class RunnableRetry(RunnableBinding[Input, Output]): self, config: Optional[RunnableConfig], retry_state: RetryCallState, + cm_cls: Type[T], ) -> RunnableConfig: config = config or {} - original_tags = config.get("tags") or [] - return patch_config( - config, - tags=original_tags - + ["retry:attempt:{}".format(retry_state.attempt_number)], + return ( + patch_config( + config, + callbacks=cm_cls.configure( + inheritable_callbacks=config.get("callbacks"), + local_tags=["retry:attempt:{}".format(retry_state.attempt_number)], + ), + ) + if retry_state.attempt_number > 1 + else config ) def _patch_config_list( self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], retry_state: RetryCallState, + cm_cls: Type[T], ) -> Union[RunnableConfig, List[RunnableConfig]]: if isinstance(config, list): - return [self._patch_config(c, retry_state) for c in config] + return [self._patch_config(c, retry_state, cm_cls) for c in config] - return self._patch_config(config, retry_state) + return self._patch_config(config, retry_state, cm_cls) def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: + from langchain.callbacks.manager import CallbackManager + for attempt in self._sync_retrying(): with attempt: result = super().invoke( - input, self._patch_config(config, attempt.retry_state), **kwargs + input, + self._patch_config(config, attempt.retry_state, CallbackManager), + **kwargs ) if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) @@ -79,10 +100,16 @@ class RunnableRetry(RunnableBinding[Input, Output]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: + from langchain.callbacks.manager import AsyncCallbackManager + async for attempt in self._async_retrying(): with attempt: result = await super().ainvoke( - input, self._patch_config(config, attempt.retry_state), **kwargs + input, + self._patch_config( + config, attempt.retry_state, AsyncCallbackManager + ), + **kwargs ) if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) @@ -94,11 +121,15 @@ class RunnableRetry(RunnableBinding[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Any ) -> List[Output]: + from langchain.callbacks.manager import CallbackManager + for attempt in self._sync_retrying(): with attempt: result = super().batch( inputs, - self._patch_config_list(config, attempt.retry_state), + self._patch_config_list( + config, attempt.retry_state, CallbackManager + ), **kwargs ) if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: @@ -111,11 +142,15 @@ class RunnableRetry(RunnableBinding[Input, Output]): config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, **kwargs: Any ) -> List[Output]: + from langchain.callbacks.manager import AsyncCallbackManager + async for attempt in self._async_retrying(): with attempt: result = await super().abatch( inputs, - self._patch_config_list(config, attempt.retry_state), + self._patch_config_list( + config, attempt.retry_state, AsyncCallbackManager + ), **kwargs ) if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: From 0eba80912f8fbcd6420a7119de1b41b97f629d18 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 25 Aug 2023 17:15:08 +0200 Subject: [PATCH 6/9] Lint --- libs/langchain/langchain/schema/runnable/retry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index b746d87428..ce67cca399 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -8,8 +8,8 @@ from langchain.schema.runnable.config import RunnableConfig, patch_config if TYPE_CHECKING: from langchain.callbacks.manager import ( AsyncCallbackManager as AsyncCallbackManagerT, - CallbackManager as CallbackManagerT, ) + from langchain.callbacks.manager import CallbackManager as CallbackManagerT T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT) else: From 4c0e1e501c11d565f5c70f616af6592d48b6df52 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 28 Aug 2023 17:18:36 +0200 Subject: [PATCH 7/9] Re-implement retry, adding a root run, and implement return_exception for batch() and abatch() --- libs/langchain/langchain/llms/base.py | 58 ++- .../langchain/schema/runnable/base.py | 394 ++++++++++++++++-- .../langchain/schema/runnable/retry.py | 289 ++++++++----- .../langchain/schema/runnable/router.py | 50 ++- .../schema/runnable/test_runnable.py | 340 ++++++++++++++- 5 files changed, 950 insertions(+), 181 deletions(-) diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 6f7dcc2008..3724db869f 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -263,20 +263,28 @@ class BaseLLM(BaseLanguageModel[str], ABC): self, inputs: List[LanguageModelInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - max_concurrency: Optional[int] = None, + *, + return_exceptions: bool = False, **kwargs: Any, ) -> List[str]: config = get_config_list(config, len(inputs)) + max_concurrency = config[0].get("max_concurrency") if max_concurrency is None: - llm_result = self.generate_prompt( - [self._convert_input(input) for input in inputs], - callbacks=[c.get("callbacks") for c in config], - tags=[c.get("tags") for c in config], - metadata=[c.get("metadata") for c in config], - **kwargs, - ) - return [g[0].text for g in llm_result.generations] + try: + llm_result = self.generate_prompt( + [self._convert_input(input) for input in inputs], + callbacks=[c.get("callbacks") for c in config], + tags=[c.get("tags") for c in config], + metadata=[c.get("metadata") for c in config], + **kwargs, + ) + return [g[0].text for g in llm_result.generations] + except Exception as e: + if return_exceptions: + return cast(List[str], [e for _ in inputs]) + else: + raise e else: batches = [ inputs[i : i + max_concurrency] @@ -285,33 +293,43 @@ class BaseLLM(BaseLanguageModel[str], ABC): return [ output for batch in batches - for output in self.batch(batch, config=config, **kwargs) + for output in self.batch( + batch, config=config, return_exceptions=return_exceptions, **kwargs + ) ] async def abatch( self, inputs: List[LanguageModelInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - max_concurrency: Optional[int] = None, + *, + return_exceptions: bool = False, **kwargs: Any, ) -> List[str]: if type(self)._agenerate == BaseLLM._agenerate: # model doesn't implement async batch, so use default implementation return await asyncio.get_running_loop().run_in_executor( - None, self.batch, inputs, config, max_concurrency + None, partial(self.batch, **kwargs), inputs, config ) config = get_config_list(config, len(inputs)) + max_concurrency = config[0].get("max_concurrency") if max_concurrency is None: - llm_result = await self.agenerate_prompt( - [self._convert_input(input) for input in inputs], - callbacks=[c.get("callbacks") for c in config], - tags=[c.get("tags") for c in config], - metadata=[c.get("metadata") for c in config], - **kwargs, - ) - return [g[0].text for g in llm_result.generations] + try: + llm_result = await self.agenerate_prompt( + [self._convert_input(input) for input in inputs], + callbacks=[c.get("callbacks") for c in config], + tags=[c.get("tags") for c in config], + metadata=[c.get("metadata") for c in config], + **kwargs, + ) + return [g[0].text for g in llm_result.generations] + except Exception as e: + if return_exceptions: + return cast(List[str], [e for _ in inputs]) + else: + raise e else: batches = [ inputs[i : i + max_concurrency] diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index ad5c8cfe84..7d65f0d461 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -27,8 +27,6 @@ from typing import ( cast, ) -from tenacity import BaseRetrying - if TYPE_CHECKING: from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, @@ -107,6 +105,8 @@ class Runnable(Generic[Input, Output], ABC): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: """ @@ -115,17 +115,28 @@ class Runnable(Generic[Input, Output], ABC): """ configs = get_config_list(config, len(inputs)) + def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]: + if return_exceptions: + try: + return self.invoke(input, config, **kwargs) + except Exception as e: + return e + else: + return self.invoke(input, config, **kwargs) + # If there's only one input, don't bother with the executor if len(inputs) == 1: - return [self.invoke(inputs[0], configs[0], **kwargs)] + return cast(List[Output], [invoke(inputs[0], configs[0])]) with get_executor_for_config(configs[0]) as executor: - return list(executor.map(partial(self.invoke, **kwargs), inputs, configs)) + return cast(List[Output], list(executor.map(invoke, inputs, configs))) async def abatch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: """ @@ -133,8 +144,19 @@ class Runnable(Generic[Input, Output], ABC): Subclasses should override this method if they can batch more efficiently. """ configs = get_config_list(config, len(inputs)) - coros = map(partial(self.ainvoke, **kwargs), inputs, configs) + async def ainvoke( + input: Input, config: RunnableConfig + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return await self.ainvoke(input, config, **kwargs) + except Exception as e: + return e + else: + return await self.ainvoke(input, config, **kwargs) + + coros = map(ainvoke, inputs, configs) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) def stream( @@ -230,11 +252,21 @@ class Runnable(Generic[Input, Output], ABC): def with_retry( self, - retry: BaseRetrying, + *, + retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,), + wait_exponential_jitter: bool = True, + stop_after_attempt: int = 3, ) -> Runnable[Input, Output]: from langchain.schema.runnable.retry import RunnableRetry - return RunnableRetry(bound=self, retry=retry, kwargs={}, config={}) + return RunnableRetry( + bound=self, + kwargs={}, + config={}, + retry_if_exception_type=retry_if_exception_type, + wait_exponential_jitter=wait_exponential_jitter, + stop_after_attempt=stop_after_attempt, + ) def map(self) -> Runnable[List[Input], List[Output]]: """ @@ -341,6 +373,146 @@ class Runnable(Generic[Input, Output], ABC): await run_manager.on_chain_end(dumpd(output)) return output + def _batch_with_config( + self, + func: Union[ + Callable[[List[Input]], List[Union[Exception, Output]]], + Callable[ + [List[Input], List[CallbackManagerForChainRun]], + List[Union[Exception, Output]], + ], + Callable[ + [List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]], + List[Union[Exception, Output]], + ], + ], + input: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + return_exceptions: bool = False, + run_type: Optional[str] = None, + ) -> List[Output]: + """Helper method to transform an Input value to an Output value, + with callbacks. Use this method to implement invoke() in subclasses.""" + configs = get_config_list(config, len(input)) + callback_managers = [get_callback_manager_for_config(c) for c in configs] + run_managers = [ + callback_manager.on_chain_start( + dumpd(self), + input, + run_type=run_type, + name=config.get("run_name"), + ) + for callback_manager, input, config in zip( + callback_managers, input, configs + ) + ] + try: + if accepts_run_manager_and_config(func): + output = func( + input, + run_manager=run_managers, + config=configs, + ) # type: ignore[call-arg] + elif accepts_run_manager(func): + output = func(input, run_manager=run_managers) # type: ignore[call-arg] + else: + output = func(input) # type: ignore[call-arg] + + print("output", output) + except Exception as e: + for run_manager in run_managers: + run_manager.on_chain_error(e) + if return_exceptions: + return cast(List[Output], [e for _ in input]) + else: + raise + else: + first_exception: Optional[Exception] = None + for run_manager, out in zip(run_managers, output): + if isinstance(out, Exception): + first_exception = first_exception or out + run_manager.on_chain_error(out) + else: + run_manager.on_chain_end(dumpd(out)) + if return_exceptions or first_exception is None: + return cast(List[Output], output) + else: + raise first_exception + + async def _abatch_with_config( + self, + func: Union[ + Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]], + Callable[ + [List[Input], List[AsyncCallbackManagerForChainRun]], + Awaitable[List[Union[Exception, Output]]], + ], + Callable[ + [ + List[Input], + List[AsyncCallbackManagerForChainRun], + List[RunnableConfig], + ], + Awaitable[List[Union[Exception, Output]]], + ], + ], + input: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + return_exceptions: bool = False, + run_type: Optional[str] = None, + ) -> List[Output]: + """Helper method to transform an Input value to an Output value, + with callbacks. Use this method to implement invoke() in subclasses.""" + configs = get_config_list(config, len(input)) + callback_managers = [get_async_callback_manager_for_config(c) for c in configs] + run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + *( + callback_manager.on_chain_start( + dumpd(self), + input, + run_type=run_type, + name=config.get("run_name"), + ) + for callback_manager, input, config in zip( + callback_managers, input, configs + ) + ) + ) + try: + if accepts_run_manager_and_config(func): + output = await func( + input, + run_manager=run_managers, + config=configs, + ) # type: ignore[call-arg] + elif accepts_run_manager(func): + output = await func(input, run_manager=run_managers) # type: ignore + else: + output = await func(input) # type: ignore[call-arg] + print("output", output) + except Exception as e: + await asyncio.gather( + *(run_manager.on_chain_error(e) for run_manager in run_managers) + ) + if return_exceptions: + return cast(List[Output], [e for _ in input]) + else: + raise + else: + first_exception: Optional[Exception] = None + coros: List[Awaitable[None]] = [] + for run_manager, out in zip(run_managers, output): + if isinstance(out, Exception): + first_exception = first_exception or out + coros.append(run_manager.on_chain_error(out)) + else: + coros.append(run_manager.on_chain_end(dumpd(out))) + await asyncio.gather(*coros) + if return_exceptions or first_exception is None: + return cast(List[Output], output) + else: + raise first_exception + def _transform_stream_with_config( self, input: Iterator[Input], @@ -596,10 +768,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import CallbackManager + if return_exceptions: + raise NotImplementedError() + # setup callbacks configs = get_config_list(config, len(inputs)) callback_managers = [ @@ -656,10 +833,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import AsyncCallbackManager + if return_exceptions: + raise NotImplementedError() + # setup callbacks configs = get_config_list(config, len(inputs)) callback_managers = [ @@ -841,6 +1023,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import CallbackManager @@ -871,29 +1055,90 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # invoke try: - for step in self.steps: - inputs = step.batch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config(config, callbacks=rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - ) + if return_exceptions: + # Track which inputs (by index) failed so far + # If an input has failed it will be present in this map, + # and the value will be the exception that was raised. + failed_inputs_map: Dict[int, Exception] = {} + stepidx = -1 + for step in self.steps: + stepidx += 1 + # Assemble the original indexes of the remaining inputs + # (i.e. the ones that haven't failed yet) + remaining_idxs = [ + i for i in range(len(configs)) if i not in failed_inputs_map + ] + # Invoke the step on the remaining inputs + inputs = step.batch( + [ + inp + for i, inp in zip(remaining_idxs, inputs) + if i not in failed_inputs_map + ], + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for i, (rm, config) in enumerate(zip(run_managers, configs)) + if i not in failed_inputs_map + ], + return_exceptions=return_exceptions, + **kwargs, + ) + # If an input failed, add it to the map + for i, inp in zip(remaining_idxs, inputs): + if isinstance(inp, Exception): + failed_inputs_map[i] = inp + inputs = [inp for inp in inputs if not isinstance(inp, Exception)] + # If all inputs have failed, stop processing + if len(failed_inputs_map) == len(configs): + break + + # Reassemble the outputs, inserting Exceptions for failed inputs + inputs_copy = inputs.copy() + inputs = [] + for i in range(len(configs)): + if i in failed_inputs_map: + inputs.append(cast(Input, failed_inputs_map[i])) + else: + inputs.append(inputs_copy.pop(0)) + else: + for step in self.steps: + inputs = step.batch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + ) + # finish the root runs except (KeyboardInterrupt, Exception) as e: for rm in run_managers: rm.on_chain_error(e) - raise + if return_exceptions: + return cast(List[Output], [e for _ in inputs]) + else: + raise else: - for rm, input in zip(run_managers, inputs): - rm.on_chain_end(input) - return cast(List[Output], inputs) + first_exception: Optional[Exception] = None + for run_manager, out in zip(run_managers, inputs): + if isinstance(out, Exception): + first_exception = first_exception or out + run_manager.on_chain_error(out) + else: + run_manager.on_chain_end(dumpd(out)) + if return_exceptions or first_exception is None: + return cast(List[Output], inputs) + else: + raise first_exception async def abatch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import ( @@ -929,24 +1174,83 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # invoke .batch() on each step # this uses batching optimizations in Runnable subclasses, like LLM try: - for step in self.steps: - inputs = await step.abatch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config(config, callbacks=rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - ) + if return_exceptions: + # Track which inputs (by index) failed so far + # If an input has failed it will be present in this map, + # and the value will be the exception that was raised. + failed_inputs_map: Dict[int, Exception] = {} + stepidx = -1 + for step in self.steps: + stepidx += 1 + # Assemble the original indexes of the remaining inputs + # (i.e. the ones that haven't failed yet) + remaining_idxs = [ + i for i in range(len(configs)) if i not in failed_inputs_map + ] + # Invoke the step on the remaining inputs + inputs = await step.abatch( + [ + inp + for i, inp in zip(remaining_idxs, inputs) + if i not in failed_inputs_map + ], + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for i, (rm, config) in enumerate(zip(run_managers, configs)) + if i not in failed_inputs_map + ], + return_exceptions=return_exceptions, + **kwargs, + ) + # If an input failed, add it to the map + for i, inp in zip(remaining_idxs, inputs): + if isinstance(inp, Exception): + failed_inputs_map[i] = inp + inputs = [inp for inp in inputs if not isinstance(inp, Exception)] + # If all inputs have failed, stop processing + if len(failed_inputs_map) == len(configs): + break + + # Reassemble the outputs, inserting Exceptions for failed inputs + inputs_copy = inputs.copy() + inputs = [] + for i in range(len(configs)): + if i in failed_inputs_map: + inputs.append(cast(Input, failed_inputs_map[i])) + else: + inputs.append(inputs_copy.pop(0)) + else: + for step in self.steps: + inputs = await step.abatch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + ) # finish the root runs except (KeyboardInterrupt, Exception) as e: await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) - raise + if return_exceptions: + return cast(List[Output], [e for _ in inputs]) + else: + raise else: - await asyncio.gather( - *(rm.on_chain_end(input) for rm, input in zip(run_managers, inputs)) - ) - return cast(List[Output], inputs) + first_exception: Optional[Exception] = None + coros: List[Awaitable[None]] = [] + for run_manager, out in zip(run_managers, inputs): + if isinstance(out, Exception): + first_exception = first_exception or out + coros.append(run_manager.on_chain_error(out)) + else: + coros.append(run_manager.on_chain_end(dumpd(out))) + await asyncio.gather(*coros) + if return_exceptions or first_exception is None: + return cast(List[Output], inputs) + else: + raise first_exception def stream( self, @@ -1555,9 +1859,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config={**self.config, **(config or {}), **kwargs}, ) - def with_retry(self, retry: BaseRetrying) -> Runnable[Input, Output]: + def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__( - bound=self.bound.with_retry(retry), + bound=self.bound.with_retry(**kwargs), kwargs=self.kwargs, config=self.config, ) @@ -1590,6 +1894,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: if isinstance(config, list): @@ -1601,12 +1907,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): patch_config(self._merge_config(config), deep_copy_locals=True) for _ in range(len(inputs)) ] - return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs}) + return self.bound.batch( + inputs, + configs, + return_exceptions=return_exceptions, + **{**self.kwargs, **kwargs}, + ) async def abatch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: if isinstance(config, list): @@ -1618,7 +1931,12 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): patch_config(self._merge_config(config), deep_copy_locals=True) for _ in range(len(inputs)) ] - return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs}) + return await self.bound.abatch( + inputs, + configs, + return_exceptions=return_exceptions, + **{**self.kwargs, **kwargs}, + ) def stream( self, diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index ce67cca399..cda41605aa 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -1,97 +1,113 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast -from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying +from tenacity import ( + AsyncRetrying, + RetryCallState, + RetryError, + Retrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.schema.runnable.base import Input, Output, RunnableBinding from langchain.schema.runnable.config import RunnableConfig, patch_config -if TYPE_CHECKING: - from langchain.callbacks.manager import ( - AsyncCallbackManager as AsyncCallbackManagerT, - ) - from langchain.callbacks.manager import CallbackManager as CallbackManagerT - - T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT) -else: - T = TypeVar("T") +T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun) +U = TypeVar("U") class RunnableRetry(RunnableBinding[Input, Output]): """Retry a Runnable if it fails.""" - retry: BaseRetrying + retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,) + + wait_exponential_jitter: bool = True + + stop_after_attempt: int = 3 def _kwargs_retrying(self) -> Dict[str, Any]: - kwargs: Dict[str, Any] = {} - if self.retry.sleep is not None: - kwargs["sleep"] = self.retry.sleep - if self.retry.stop is not None: - kwargs["stop"] = self.retry.stop - if self.retry.wait is not None: - kwargs["wait"] = self.retry.wait - if self.retry.retry is not None: - kwargs["retry"] = self.retry.retry - if self.retry.before is not None: - kwargs["before"] = self.retry.before - if self.retry.after is not None: - kwargs["after"] = self.retry.after - if self.retry.before_sleep is not None: - kwargs["before_sleep"] = self.retry.before_sleep - if self.retry.reraise is not None: - kwargs["reraise"] = self.retry.reraise - if self.retry.retry_error_cls is not None: - kwargs["retry_error_cls"] = self.retry.retry_error_cls - if self.retry.retry_error_callback is not None: - kwargs["retry_error_callback"] = self.retry.retry_error_callback + kwargs: Dict[str, Any] = dict() + + if self.stop_after_attempt: + kwargs["stop"] = stop_after_attempt(self.stop_after_attempt) + + if self.wait_exponential_jitter: + kwargs["wait"] = wait_exponential_jitter() + + if self.retry_if_exception_type: + kwargs["retry"] = retry_if_exception_type(self.retry_if_exception_type) + return kwargs - def _sync_retrying(self) -> Retrying: - return Retrying(**self._kwargs_retrying()) + def _sync_retrying(self, **kwargs: Any) -> Retrying: + return Retrying(**self._kwargs_retrying(), **kwargs) - def _async_retrying(self) -> AsyncRetrying: - return AsyncRetrying(**self._kwargs_retrying()) + def _async_retrying(self, **kwargs: Any) -> AsyncRetrying: + return AsyncRetrying(**self._kwargs_retrying(), **kwargs) def _patch_config( self, - config: Optional[RunnableConfig], + config: RunnableConfig, + run_manager: T, retry_state: RetryCallState, - cm_cls: Type[T], ) -> RunnableConfig: config = config or {} - return ( - patch_config( - config, - callbacks=cm_cls.configure( - inheritable_callbacks=config.get("callbacks"), - local_tags=["retry:attempt:{}".format(retry_state.attempt_number)], - ), - ) - if retry_state.attempt_number > 1 - else config + return patch_config( + config, + callbacks=run_manager.get_child( + "retry:attempt:{}".format(retry_state.attempt_number) + if retry_state.attempt_number > 1 + else None + ), ) def _patch_config_list( self, - config: Optional[Union[RunnableConfig, List[RunnableConfig]]], + config: List[RunnableConfig], + run_manager: List[T], retry_state: RetryCallState, - cm_cls: Type[T], - ) -> Union[RunnableConfig, List[RunnableConfig]]: - if isinstance(config, list): - return [self._patch_config(c, retry_state, cm_cls) for c in config] + ) -> List[RunnableConfig]: + return [ + self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager) + ] - return self._patch_config(config, retry_state, cm_cls) + def _invoke( + self, + input: Input, + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + for attempt in self._sync_retrying(reraise=True): + with attempt: + result = super().invoke( + input, + self._patch_config(config, run_manager, attempt.retry_state), + ) + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: - from langchain.callbacks.manager import CallbackManager + return self._call_with_config(self._invoke, input, config, **kwargs) - for attempt in self._sync_retrying(): + async def _ainvoke( + self, + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + async for attempt in self._async_retrying(reraise=True): with attempt: - result = super().invoke( + result = await super().ainvoke( input, - self._patch_config(config, attempt.retry_state, CallbackManager), - **kwargs + self._patch_config(config, run_manager, attempt.retry_state), ) if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) @@ -100,62 +116,135 @@ class RunnableRetry(RunnableBinding[Input, Output]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: - from langchain.callbacks.manager import AsyncCallbackManager + return await self._acall_with_config(self._ainvoke, input, config, **kwargs) - async for attempt in self._async_retrying(): - with attempt: - result = await super().ainvoke( - input, - self._patch_config( - config, attempt.retry_state, AsyncCallbackManager - ), - **kwargs - ) - if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: - attempt.retry_state.set_result(result) - return result + def _batch( + self, + inputs: List[Input], + run_manager: List[CallbackManagerForChainRun], + config: List[RunnableConfig], + ) -> List[Union[Output, Exception]]: + results_map: Dict[int, Output] = {} + + def pending(iterable: List[U]) -> List[U]: + return [item for idx, item in enumerate(iterable) if idx not in results_map] + + try: + for attempt in self._sync_retrying(): + with attempt: + # Get the results of the inputs that have not succeeded yet. + result = super().batch( + pending(inputs), + self._patch_config_list( + pending(config), pending(run_manager), attempt.retry_state + ), + return_exceptions=True, + ) + # Register the results of the inputs that have succeeded. + first_exception = None + for i, r in enumerate(result): + if isinstance(r, Exception): + if not first_exception: + first_exception = r + continue + results_map[i] = r + # If any exception occurred, raise it, to retry the failed ones + if first_exception: + raise first_exception + if ( + attempt.retry_state.outcome + and not attempt.retry_state.outcome.failed + ): + attempt.retry_state.set_result(result) + except RetryError as e: + try: + result + except UnboundLocalError: + result = cast(List[Output], [e] * len(inputs)) + + outputs: List[Union[Output, Exception]] = [] + for idx, _ in enumerate(inputs): + if idx in results_map: + outputs.append(results_map[idx]) + else: + outputs.append(result.pop(0)) + return outputs def batch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Any ) -> List[Output]: - from langchain.callbacks.manager import CallbackManager + return self._batch_with_config( + self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs + ) - for attempt in self._sync_retrying(): - with attempt: - result = super().batch( - inputs, - self._patch_config_list( - config, attempt.retry_state, CallbackManager - ), - **kwargs - ) - if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: - attempt.retry_state.set_result(result) - return result + async def _abatch( + self, + inputs: List[Input], + run_manager: List[AsyncCallbackManagerForChainRun], + config: List[RunnableConfig], + ) -> List[Union[Output, Exception]]: + results_map: Dict[int, Output] = {} + + def pending(iterable: List[U]) -> List[U]: + return [item for idx, item in enumerate(iterable) if idx not in results_map] + + try: + async for attempt in self._async_retrying(): + with attempt: + # Get the results of the inputs that have not succeeded yet. + result = await super().abatch( + pending(inputs), + self._patch_config_list( + pending(config), pending(run_manager), attempt.retry_state + ), + return_exceptions=True, + ) + # Register the results of the inputs that have succeeded. + first_exception = None + for i, r in enumerate(result): + if isinstance(r, Exception): + if not first_exception: + first_exception = r + continue + results_map[i] = r + # If any exception occurred, raise it, to retry the failed ones + if first_exception: + raise first_exception + if ( + attempt.retry_state.outcome + and not attempt.retry_state.outcome.failed + ): + attempt.retry_state.set_result(result) + except RetryError as e: + try: + result + except UnboundLocalError: + result = cast(List[Output], [e] * len(inputs)) + + outputs: List[Union[Output, Exception]] = [] + for idx, _ in enumerate(inputs): + if idx in results_map: + outputs.append(results_map[idx]) + else: + outputs.append(result.pop(0)) + return outputs async def abatch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Any ) -> List[Output]: - from langchain.callbacks.manager import AsyncCallbackManager - - async for attempt in self._async_retrying(): - with attempt: - result = await super().abatch( - inputs, - self._patch_config_list( - config, attempt.retry_state, AsyncCallbackManager - ), - **kwargs - ) - if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: - attempt.retry_state.set_result(result) - return result + return await self._abatch_with_config( + self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs + ) # stream() and transform() are not retried because retrying a stream # is not very intuitive. diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index 5277932543..a51d0907ca 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -1,6 +1,5 @@ from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor from typing import ( Any, AsyncIterator, @@ -12,6 +11,7 @@ from typing import ( Optional, TypedDict, Union, + cast, ) from langchain.load.serializable import Serializable @@ -23,7 +23,11 @@ from langchain.schema.runnable.base import ( RunnableSequence, coerce_to_runnable, ) -from langchain.schema.runnable.config import RunnableConfig, get_config_list +from langchain.schema.runnable.config import ( + RunnableConfig, + get_config_list, + get_executor_for_config, +) from langchain.schema.runnable.utils import gather_with_concurrency @@ -122,7 +126,7 @@ class RouterRunnable( inputs: List[RouterInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, - max_concurrency: Optional[int] = None, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: keys = [input["key"] for input in inputs] @@ -130,16 +134,23 @@ class RouterRunnable( if any(key not in self.runnables for key in keys): raise ValueError("One or more keys do not have a corresponding runnable") + def invoke( + runnable: Runnable, input: Input, config: RunnableConfig + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return runnable.invoke(input, config, **kwargs) + except Exception as e: + return e + else: + return runnable.invoke(input, config, **kwargs) + runnables = [self.runnables[key] for key in keys] configs = get_config_list(config, len(inputs)) - with ThreadPoolExecutor(max_workers=max_concurrency) as executor: - return list( - executor.map( - lambda runnable, input, config: runnable.invoke(input, config), - runnables, - actual_inputs, - configs, - ) + with get_executor_for_config(configs[0]) as executor: + return cast( + List[Output], + list(executor.map(invoke, runnables, actual_inputs, configs)), ) async def abatch( @@ -147,7 +158,7 @@ class RouterRunnable( inputs: List[RouterInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, - max_concurrency: Optional[int] = None, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: keys = [input["key"] for input in inputs] @@ -155,12 +166,23 @@ class RouterRunnable( if any(key not in self.runnables for key in keys): raise ValueError("One or more keys do not have a corresponding runnable") + async def ainvoke( + runnable: Runnable, input: Input, config: RunnableConfig + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return await runnable.ainvoke(input, config, **kwargs) + except Exception as e: + return e + else: + return await runnable.ainvoke(input, config, **kwargs) + runnables = [self.runnables[key] for key in keys] configs = get_config_list(config, len(inputs)) return await gather_with_concurrency( - max_concurrency, + configs[0].get("max_concurrency"), *( - runnable.ainvoke(input, config) + ainvoke(runnable, input, config) for runnable, input, config in zip(runnables, actual_inputs, configs) ), ) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index ca22c6c984..2e0be35ddc 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -6,7 +6,6 @@ import pytest from freezegun import freeze_time from pytest_mock import MockerFixture from syrupy import SnapshotAssertion -from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt from langchain import PromptTemplate from langchain.callbacks.manager import Callbacks @@ -1444,22 +1443,345 @@ def test_retrying(mocker: MockerFixture) -> None: assert _lambda_mock.call_count == 1 _lambda_mock.reset_mock() - with pytest.raises(RetryError): + with pytest.raises(ValueError): runnable.with_retry( - Retrying( - stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,)) - ) + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), ).invoke(1) - assert _lambda_mock.call_count == 2 + assert _lambda_mock.call_count == 2 # retried _lambda_mock.reset_mock() with pytest.raises(RuntimeError): runnable.with_retry( - Retrying( - stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,)) - ) + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), ).invoke(2) + assert _lambda_mock.call_count == 1 # did not retry + _lambda_mock.reset_mock() + + with pytest.raises(ValueError): + runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).batch([1, 2, 0]) + + # 3rd input isn't retried because it succeeded + assert _lambda_mock.call_count == 3 + 2 + _lambda_mock.reset_mock() + + output = runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).batch([1, 2, 0], return_exceptions=True) + + # 3rd input isn't retried because it succeeded + assert _lambda_mock.call_count == 3 + 2 + assert len(output) == 3 + assert isinstance(output[0], ValueError) + assert isinstance(output[1], RuntimeError) + assert output[2] == 0 + _lambda_mock.reset_mock() + + +@pytest.mark.asyncio +async def test_async_retrying(mocker: MockerFixture) -> None: + def _lambda(x: int) -> Union[int, Runnable]: + if x == 1: + raise ValueError("x is 1") + elif x == 2: + raise RuntimeError("x is 2") + else: + return x + + _lambda_mock = mocker.Mock(side_effect=_lambda) + runnable = RunnableLambda(_lambda_mock) + + with pytest.raises(ValueError): + await runnable.ainvoke(1) + assert _lambda_mock.call_count == 1 _lambda_mock.reset_mock() + + with pytest.raises(ValueError): + await runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).ainvoke(1) + + assert _lambda_mock.call_count == 2 # retried + _lambda_mock.reset_mock() + + with pytest.raises(RuntimeError): + await runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).ainvoke(2) + + assert _lambda_mock.call_count == 1 # did not retry + _lambda_mock.reset_mock() + + with pytest.raises(ValueError): + await runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).abatch([1, 2, 0]) + + # 3rd input isn't retried because it succeeded + assert _lambda_mock.call_count == 3 + 2 + _lambda_mock.reset_mock() + + output = await runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).abatch([1, 2, 0], return_exceptions=True) + + # 3rd input isn't retried because it succeeded + assert _lambda_mock.call_count == 3 + 2 + assert len(output) == 3 + assert isinstance(output[0], ValueError) + assert isinstance(output[1], RuntimeError) + assert output[2] == 0 + _lambda_mock.reset_mock() + + +@freeze_time("2023-01-01") +def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None: + class ControlledExceptionRunnable(Runnable[str, str]): + def __init__(self, fail_starts_with: str) -> None: + self.fail_starts_with = fail_starts_with + + def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: + raise NotImplementedError() + + def _batch( + self, + inputs: List[str], + ) -> List: + outputs: List[Any] = [] + for input in inputs: + if input.startswith(self.fail_starts_with): + outputs.append(ValueError()) + else: + outputs.append(input + "a") + return outputs + + def batch( + self, + inputs: List[str], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> List[str]: + return self._batch_with_config( + self._batch, + inputs, + config, + return_exceptions=return_exceptions, + **kwargs, + ) + + chain = ( + ControlledExceptionRunnable("bux") + | ControlledExceptionRunnable("bar") + | ControlledExceptionRunnable("baz") + | ControlledExceptionRunnable("foo") + ) + + assert isinstance(chain, RunnableSequence) + + # Test batch + with pytest.raises(ValueError): + chain.batch(["foo", "bar", "baz", "qux"]) + + spy = mocker.spy(ControlledExceptionRunnable, "batch") + tracer = FakeTracer() + inputs = ["foo", "bar", "baz", "qux"] + outputs = chain.batch(inputs, dict(callbacks=[tracer]), return_exceptions=True) + assert len(outputs) == 4 + assert isinstance(outputs[0], ValueError) + assert isinstance(outputs[1], ValueError) + assert isinstance(outputs[2], ValueError) + assert outputs[3] == "quxaaaa" + assert spy.call_count == 4 + inputs_to_batch = [c[0][1] for c in spy.call_args_list] + assert inputs_to_batch == [ + # inputs to sequence step 0 + # same as inputs to sequence.batch() + ["foo", "bar", "baz", "qux"], + # inputs to sequence step 1 + # == outputs of sequence step 0 as no exceptions were raised + ["fooa", "bara", "baza", "quxa"], + # inputs to sequence step 2 + # 'bar' was dropped as it raised an exception in step 1 + ["fooaa", "bazaa", "quxaa"], + # inputs to sequence step 3 + # 'baz' was dropped as it raised an exception in step 2 + ["fooaaa", "quxaaa"], + ] + parent_runs = sorted( + (r for r in tracer.runs if r.parent_run_id is None), + key=lambda run: inputs.index(run.inputs["input"]), + ) + assert len(parent_runs) == 4 + + parent_run_foo = parent_runs[0] + assert parent_run_foo.inputs["input"] == "foo" + assert parent_run_foo.error == repr(ValueError()) + assert len(parent_run_foo.child_runs) == 4 + assert [r.error for r in parent_run_foo.child_runs] == [ + None, + None, + None, + repr(ValueError()), + ] + + parent_run_bar = parent_runs[1] + assert parent_run_bar.inputs["input"] == "bar" + assert parent_run_bar.error == repr(ValueError()) + assert len(parent_run_bar.child_runs) == 2 + assert [r.error for r in parent_run_bar.child_runs] == [ + None, + repr(ValueError()), + ] + + parent_run_baz = parent_runs[2] + assert parent_run_baz.inputs["input"] == "baz" + assert parent_run_baz.error == repr(ValueError()) + assert len(parent_run_baz.child_runs) == 3 + assert [r.error for r in parent_run_baz.child_runs] == [ + None, + None, + repr(ValueError()), + ] + + parent_run_qux = parent_runs[3] + assert parent_run_qux.inputs["input"] == "qux" + assert parent_run_qux.error is None + assert parent_run_qux.outputs["output"] == "quxaaaa" + assert len(parent_run_qux.child_runs) == 4 + assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None] + + +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None: + class ControlledExceptionRunnable(Runnable[str, str]): + def __init__(self, fail_starts_with: str) -> None: + self.fail_starts_with = fail_starts_with + + def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: + raise NotImplementedError() + + async def _abatch( + self, + inputs: List[str], + ) -> List: + outputs: List[Any] = [] + for input in inputs: + if input.startswith(self.fail_starts_with): + outputs.append(ValueError()) + else: + outputs.append(input + "a") + return outputs + + async def abatch( + self, + inputs: List[str], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> List[str]: + return await self._abatch_with_config( + self._abatch, + inputs, + config, + return_exceptions=return_exceptions, + **kwargs, + ) + + chain = ( + ControlledExceptionRunnable("bux") + | ControlledExceptionRunnable("bar") + | ControlledExceptionRunnable("baz") + | ControlledExceptionRunnable("foo") + ) + + assert isinstance(chain, RunnableSequence) + + # Test abatch + with pytest.raises(ValueError): + await chain.abatch(["foo", "bar", "baz", "qux"]) + + spy = mocker.spy(ControlledExceptionRunnable, "abatch") + tracer = FakeTracer() + inputs = ["foo", "bar", "baz", "qux"] + outputs = await chain.abatch( + inputs, dict(callbacks=[tracer]), return_exceptions=True + ) + assert len(outputs) == 4 + assert isinstance(outputs[0], ValueError) + assert isinstance(outputs[1], ValueError) + assert isinstance(outputs[2], ValueError) + assert outputs[3] == "quxaaaa" + assert spy.call_count == 4 + inputs_to_batch = [c[0][1] for c in spy.call_args_list] + assert inputs_to_batch == [ + # inputs to sequence step 0 + # same as inputs to sequence.batch() + ["foo", "bar", "baz", "qux"], + # inputs to sequence step 1 + # == outputs of sequence step 0 as no exceptions were raised + ["fooa", "bara", "baza", "quxa"], + # inputs to sequence step 2 + # 'bar' was dropped as it raised an exception in step 1 + ["fooaa", "bazaa", "quxaa"], + # inputs to sequence step 3 + # 'baz' was dropped as it raised an exception in step 2 + ["fooaaa", "quxaaa"], + ] + parent_runs = sorted( + (r for r in tracer.runs if r.parent_run_id is None), + key=lambda run: inputs.index(run.inputs["input"]), + ) + assert len(parent_runs) == 4 + + parent_run_foo = parent_runs[0] + assert parent_run_foo.inputs["input"] == "foo" + assert parent_run_foo.error == repr(ValueError()) + assert len(parent_run_foo.child_runs) == 4 + assert [r.error for r in parent_run_foo.child_runs] == [ + None, + None, + None, + repr(ValueError()), + ] + + parent_run_bar = parent_runs[1] + assert parent_run_bar.inputs["input"] == "bar" + assert parent_run_bar.error == repr(ValueError()) + assert len(parent_run_bar.child_runs) == 2 + assert [r.error for r in parent_run_bar.child_runs] == [ + None, + repr(ValueError()), + ] + + parent_run_baz = parent_runs[2] + assert parent_run_baz.inputs["input"] == "baz" + assert parent_run_baz.error == repr(ValueError()) + assert len(parent_run_baz.child_runs) == 3 + assert [r.error for r in parent_run_baz.child_runs] == [ + None, + None, + repr(ValueError()), + ] + + parent_run_qux = parent_runs[3] + assert parent_run_qux.inputs["input"] == "qux" + assert parent_run_qux.error is None + assert parent_run_qux.outputs["output"] == "quxaaaa" + assert len(parent_run_qux.child_runs) == 4 + assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None] From 7966af1e9c9c7106b8540b50075b8bbe8bf3a214 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 28 Aug 2023 17:37:43 +0200 Subject: [PATCH 8/9] Lint --- libs/langchain/langchain/schema/runnable/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7d65f0d461..f428dcdf1d 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -417,8 +417,6 @@ class Runnable(Generic[Input, Output], ABC): output = func(input, run_manager=run_managers) # type: ignore[call-arg] else: output = func(input) # type: ignore[call-arg] - - print("output", output) except Exception as e: for run_manager in run_managers: run_manager.on_chain_error(e) @@ -489,7 +487,6 @@ class Runnable(Generic[Input, Output], ABC): output = await func(input, run_manager=run_managers) # type: ignore else: output = await func(input) # type: ignore[call-arg] - print("output", output) except Exception as e: await asyncio.gather( *(run_manager.on_chain_error(e) for run_manager in run_managers) From 63306899a2c0a4c4811c8966788c58126c0f657b Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 1 Sep 2023 15:30:14 +0100 Subject: [PATCH 9/9] PR review suggestions --- .../langchain/schema/runnable/base.py | 10 +++---- .../langchain/schema/runnable/retry.py | 29 ++++++++----------- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index f428dcdf1d..fb000ae1e7 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -263,9 +263,9 @@ class Runnable(Generic[Input, Output], ABC): bound=self, kwargs={}, config={}, - retry_if_exception_type=retry_if_exception_type, + retry_exception_types=retry_if_exception_type, wait_exponential_jitter=wait_exponential_jitter, - stop_after_attempt=stop_after_attempt, + max_attempt_number=stop_after_attempt, ) def map(self) -> Runnable[List[Input], List[Output]]: @@ -388,6 +388,7 @@ class Runnable(Generic[Input, Output], ABC): ], input: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, return_exceptions: bool = False, run_type: Optional[str] = None, ) -> List[Output]: @@ -456,6 +457,7 @@ class Runnable(Generic[Input, Output], ABC): ], input: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, return_exceptions: bool = False, run_type: Optional[str] = None, ) -> List[Output]: @@ -1057,9 +1059,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # If an input has failed it will be present in this map, # and the value will be the exception that was raised. failed_inputs_map: Dict[int, Exception] = {} - stepidx = -1 for step in self.steps: - stepidx += 1 # Assemble the original indexes of the remaining inputs # (i.e. the ones that haven't failed yet) remaining_idxs = [ @@ -1176,9 +1176,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # If an input has failed it will be present in this map, # and the value will be the exception that was raised. failed_inputs_map: Dict[int, Exception] = {} - stepidx = -1 for step in self.steps: - stepidx += 1 # Assemble the original indexes of the remaining inputs # (i.e. the ones that haven't failed yet) remaining_idxs = [ diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index cda41605aa..37de03f600 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -24,31 +24,32 @@ U = TypeVar("U") class RunnableRetry(RunnableBinding[Input, Output]): """Retry a Runnable if it fails.""" - retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,) + retry_exception_types: Tuple[Type[BaseException]] = (Exception,) wait_exponential_jitter: bool = True - stop_after_attempt: int = 3 + max_attempt_number: int = 3 + @property def _kwargs_retrying(self) -> Dict[str, Any]: kwargs: Dict[str, Any] = dict() - if self.stop_after_attempt: - kwargs["stop"] = stop_after_attempt(self.stop_after_attempt) + if self.max_attempt_number: + kwargs["stop"] = stop_after_attempt(self.max_attempt_number) if self.wait_exponential_jitter: kwargs["wait"] = wait_exponential_jitter() - if self.retry_if_exception_type: - kwargs["retry"] = retry_if_exception_type(self.retry_if_exception_type) + if self.retry_exception_types: + kwargs["retry"] = retry_if_exception_type(self.retry_exception_types) return kwargs def _sync_retrying(self, **kwargs: Any) -> Retrying: - return Retrying(**self._kwargs_retrying(), **kwargs) + return Retrying(**self._kwargs_retrying, **kwargs) def _async_retrying(self, **kwargs: Any) -> AsyncRetrying: - return AsyncRetrying(**self._kwargs_retrying(), **kwargs) + return AsyncRetrying(**self._kwargs_retrying, **kwargs) def _patch_config( self, @@ -56,15 +57,9 @@ class RunnableRetry(RunnableBinding[Input, Output]): run_manager: T, retry_state: RetryCallState, ) -> RunnableConfig: - config = config or {} - return patch_config( - config, - callbacks=run_manager.get_child( - "retry:attempt:{}".format(retry_state.attempt_number) - if retry_state.attempt_number > 1 - else None - ), - ) + attempt = retry_state.attempt_number + tag = "retry:attempt:{}".format(attempt) if attempt > 1 else None + return patch_config(config, callbacks=run_manager.get_child(tag)) def _patch_config_list( self,