This commit is contained in:
Nuno Campos 2023-08-24 18:24:13 +02:00
parent b2ac835466
commit 2242e2160f
2 changed files with 56 additions and 37 deletions

View File

@ -1,50 +1,51 @@
from typing import Any, List, Optional, Union
from langchain.schema.runnable.base import Input, Output, Runnable, RunnableBinding
from langchain.schema.runnable.config import RunnableConfig, patch_config
from typing import Any, Dict, List, Optional, Union
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.config import RunnableConfig, patch_config
class RunnableRetry(RunnableBinding[Input, Output]):
"""Retry a Runnable if it fails."""
retry: BaseRetrying
def _kwargs_retrying(self) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {}
if self.retry.sleep is not None:
kwargs["sleep"] = self.retry.sleep
if self.retry.stop is not None:
kwargs["stop"] = self.retry.stop
if self.retry.wait is not None:
kwargs["wait"] = self.retry.wait
if self.retry.retry is not None:
kwargs["retry"] = self.retry.retry
if self.retry.before is not None:
kwargs["before"] = self.retry.before
if self.retry.after is not None:
kwargs["after"] = self.retry.after
if self.retry.before_sleep is not None:
kwargs["before_sleep"] = self.retry.before_sleep
if self.retry.reraise is not None:
kwargs["reraise"] = self.retry.reraise
if self.retry.retry_error_cls is not None:
kwargs["retry_error_cls"] = self.retry.retry_error_cls
if self.retry.retry_error_callback is not None:
kwargs["retry_error_callback"] = self.retry.retry_error_callback
return kwargs
def _sync_retrying(self) -> Retrying:
return Retrying(
sleep=self.retry.sleep,
stop=self.retry.stop,
wait=self.retry.wait,
retry=self.retry.retry,
before=self.retry.before,
after=self.retry.after,
before_sleep=self.retry.before_sleep,
reraise=self.retry.reraise,
retry_error_cls=self.retry.retry_error_cls,
retry_error_callback=self.retry.retry_error_callback,
)
return Retrying(**self._kwargs_retrying())
def _async_retrying(self) -> AsyncRetrying:
return AsyncRetrying(
sleep=self.retry.sleep,
stop=self.retry.stop,
wait=self.retry.wait,
retry=self.retry.retry,
before=self.retry.before,
after=self.retry.after,
before_sleep=self.retry.before_sleep,
reraise=self.retry.reraise,
retry_error_cls=self.retry.retry_error_cls,
retry_error_callback=self.retry.retry_error_callback,
)
return AsyncRetrying(**self._kwargs_retrying())
def _patch_config(
self,
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
config: Optional[RunnableConfig],
retry_state: RetryCallState,
) -> RunnableConfig:
if isinstance(config, list):
return [self._patch_config(c, retry_state) for c in config]
config = config or {}
original_tags = config.get("tags") or []
return patch_config(
@ -53,6 +54,16 @@ class RunnableRetry(RunnableBinding[Input, Output]):
+ ["retry:attempt:{}".format(retry_state.attempt_number)],
)
def _patch_config_list(
self,
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
retry_state: RetryCallState,
) -> Union[RunnableConfig, List[RunnableConfig]]:
if isinstance(config, list):
return [self._patch_config(c, retry_state) for c in config]
return self._patch_config(config, retry_state)
def invoke(
self,
input: Input,
@ -64,7 +75,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
result = super().invoke(
input, self._patch_config(config, attempt.retry_state), **kwargs
)
if not attempt.retry_state.outcome.failed:
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
@ -76,7 +87,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
result = await super().ainvoke(
input, self._patch_config(config, attempt.retry_state), **kwargs
)
if not attempt.retry_state.outcome.failed:
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
@ -89,9 +100,11 @@ class RunnableRetry(RunnableBinding[Input, Output]):
for attempt in self._sync_retrying():
with attempt:
result = super().batch(
inputs, self._patch_config(config, attempt.retry_state), **kwargs
inputs,
self._patch_config_list(config, attempt.retry_state),
**kwargs
)
if not attempt.retry_state.outcome.failed:
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
@ -104,8 +117,13 @@ class RunnableRetry(RunnableBinding[Input, Output]):
async for attempt in self._async_retrying():
with attempt:
result = await super().abatch(
inputs, self._patch_config(config, attempt.retry_state), **kwargs
inputs,
self._patch_config_list(config, attempt.retry_state),
**kwargs
)
if not attempt.retry_state.outcome.failed:
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
# stream() and transform() are not retried because retrying a stream
# is not very intuitive.

View File

@ -6,6 +6,7 @@ import pytest
from freezegun import freeze_time
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt
from langchain import PromptTemplate
from langchain.callbacks.manager import Callbacks