Add .with_retry() to Runnables

pull/9711/head
Nuno Campos 1 year ago
parent 50a5c5bcf8
commit b2ac835466

@ -27,6 +27,8 @@ from typing import (
cast,
)
from tenacity import BaseRetrying
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@ -226,6 +228,14 @@ class Runnable(Generic[Input, Output], ABC):
bound=self, config={**(config or {}), **kwargs}, kwargs={}
)
def with_retry(
self,
retry: BaseRetrying,
) -> Runnable[Input, Output]:
from langchain.schema.runnable.retry import RunnableRetry
return RunnableRetry(bound=self, retry=retry, kwargs={}, config={})
def map(self) -> Runnable[List[Input], List[Output]]:
"""
Return a new Runnable that maps a list of inputs to a list of outputs,

@ -98,6 +98,7 @@ def patch_config(
recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None,
run_name: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> RunnableConfig:
config = ensure_config(config)
if deep_copy_locals:
@ -114,6 +115,8 @@ def patch_config(
config["max_concurrency"] = max_concurrency
if run_name is not None:
config["run_name"] = run_name
if tags is not None:
config["tags"] = tags
return config

@ -0,0 +1,111 @@
from typing import Any, List, Optional, Union
from langchain.schema.runnable.base import Input, Output, Runnable, RunnableBinding
from langchain.schema.runnable.config import RunnableConfig, patch_config
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
class RunnableRetry(RunnableBinding[Input, Output]):
"""Retry a Runnable if it fails."""
retry: BaseRetrying
def _sync_retrying(self) -> Retrying:
return Retrying(
sleep=self.retry.sleep,
stop=self.retry.stop,
wait=self.retry.wait,
retry=self.retry.retry,
before=self.retry.before,
after=self.retry.after,
before_sleep=self.retry.before_sleep,
reraise=self.retry.reraise,
retry_error_cls=self.retry.retry_error_cls,
retry_error_callback=self.retry.retry_error_callback,
)
def _async_retrying(self) -> AsyncRetrying:
return AsyncRetrying(
sleep=self.retry.sleep,
stop=self.retry.stop,
wait=self.retry.wait,
retry=self.retry.retry,
before=self.retry.before,
after=self.retry.after,
before_sleep=self.retry.before_sleep,
reraise=self.retry.reraise,
retry_error_cls=self.retry.retry_error_cls,
retry_error_callback=self.retry.retry_error_callback,
)
def _patch_config(
self,
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
retry_state: RetryCallState,
) -> RunnableConfig:
if isinstance(config, list):
return [self._patch_config(c, retry_state) for c in config]
config = config or {}
original_tags = config.get("tags") or []
return patch_config(
config,
tags=original_tags
+ ["retry:attempt:{}".format(retry_state.attempt_number)],
)
def invoke(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any | None
) -> Output:
for attempt in self._sync_retrying():
with attempt:
result = super().invoke(
input, self._patch_config(config, attempt.retry_state), **kwargs
)
if not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
async def ainvoke(
self, input: Input, config: RunnableConfig | None = None, **kwargs: Any | None
) -> Output:
async for attempt in self._async_retrying():
with attempt:
result = await super().ainvoke(
input, self._patch_config(config, attempt.retry_state), **kwargs
)
if not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Any
) -> List[Output]:
for attempt in self._sync_retrying():
with attempt:
result = super().batch(
inputs, self._patch_config(config, attempt.retry_state), **kwargs
)
if not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Any
) -> List[Output]:
async for attempt in self._async_retrying():
with attempt:
result = await super().abatch(
inputs, self._patch_config(config, attempt.retry_state), **kwargs
)
if not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result

@ -41,6 +41,7 @@ from langchain.schema.runnable import (
RunnableSequence,
RunnableWithFallbacks,
)
from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt
class FakeTracer(BaseTracer):
@ -141,7 +142,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
else:
assert call.args[2].get("tags") == ["b-tag"]
assert call.args[2].get("max_concurrency") == 5
spy_seq_step.reset_mock()
mocker.stop(spy_seq_step)
assert [
*fake.with_config(tags=["a-tag"]).stream(
@ -1423,3 +1424,42 @@ def test_recursive_lambda() -> None:
with pytest.raises(RecursionError):
runnable.invoke(0, {"recursion_limit": 9})
def test_retrying(mocker: MockerFixture) -> None:
def _lambda(x: int) -> Union[int, Runnable]:
if x == 1:
raise ValueError("x is 1")
elif x == 2:
raise RuntimeError("x is 2")
else:
return x
_lambda_mock = mocker.Mock(side_effect=_lambda)
runnable = RunnableLambda(_lambda_mock)
with pytest.raises(ValueError):
runnable.invoke(1)
assert _lambda_mock.call_count == 1
_lambda_mock.reset_mock()
with pytest.raises(RetryError):
runnable.with_retry(
Retrying(
stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,))
)
).invoke(1)
assert _lambda_mock.call_count == 2
_lambda_mock.reset_mock()
with pytest.raises(RuntimeError):
runnable.with_retry(
Retrying(
stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,))
)
).invoke(2)
assert _lambda_mock.call_count == 1
_lambda_mock.reset_mock()

Loading…
Cancel
Save