|
|
|
@ -1,10 +1,20 @@
|
|
|
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union
|
|
|
|
|
|
|
|
|
|
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
|
|
|
|
|
|
|
|
|
|
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
|
|
|
|
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManager as AsyncCallbackManagerT,
|
|
|
|
|
CallbackManager as CallbackManagerT,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT)
|
|
|
|
|
else:
|
|
|
|
|
T = TypeVar("T")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RunnableRetry(RunnableBinding[Input, Output]):
|
|
|
|
|
"""Retry a Runnable if it fails."""
|
|
|
|
@ -45,32 +55,43 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
|
|
|
|
self,
|
|
|
|
|
config: Optional[RunnableConfig],
|
|
|
|
|
retry_state: RetryCallState,
|
|
|
|
|
cm_cls: Type[T],
|
|
|
|
|
) -> RunnableConfig:
|
|
|
|
|
config = config or {}
|
|
|
|
|
original_tags = config.get("tags") or []
|
|
|
|
|
return patch_config(
|
|
|
|
|
config,
|
|
|
|
|
tags=original_tags
|
|
|
|
|
+ ["retry:attempt:{}".format(retry_state.attempt_number)],
|
|
|
|
|
return (
|
|
|
|
|
patch_config(
|
|
|
|
|
config,
|
|
|
|
|
callbacks=cm_cls.configure(
|
|
|
|
|
inheritable_callbacks=config.get("callbacks"),
|
|
|
|
|
local_tags=["retry:attempt:{}".format(retry_state.attempt_number)],
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
if retry_state.attempt_number > 1
|
|
|
|
|
else config
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _patch_config_list(
|
|
|
|
|
self,
|
|
|
|
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
|
|
|
|
|
retry_state: RetryCallState,
|
|
|
|
|
cm_cls: Type[T],
|
|
|
|
|
) -> Union[RunnableConfig, List[RunnableConfig]]:
|
|
|
|
|
if isinstance(config, list):
|
|
|
|
|
return [self._patch_config(c, retry_state) for c in config]
|
|
|
|
|
return [self._patch_config(c, retry_state, cm_cls) for c in config]
|
|
|
|
|
|
|
|
|
|
return self._patch_config(config, retry_state)
|
|
|
|
|
return self._patch_config(config, retry_state, cm_cls)
|
|
|
|
|
|
|
|
|
|
def invoke(
|
|
|
|
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
|
|
|
|
) -> Output:
|
|
|
|
|
from langchain.callbacks.manager import CallbackManager
|
|
|
|
|
|
|
|
|
|
for attempt in self._sync_retrying():
|
|
|
|
|
with attempt:
|
|
|
|
|
result = super().invoke(
|
|
|
|
|
input, self._patch_config(config, attempt.retry_state), **kwargs
|
|
|
|
|
input,
|
|
|
|
|
self._patch_config(config, attempt.retry_state, CallbackManager),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
|
|
|
|
attempt.retry_state.set_result(result)
|
|
|
|
@ -79,10 +100,16 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
|
|
|
|
async def ainvoke(
|
|
|
|
|
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
|
|
|
|
) -> Output:
|
|
|
|
|
from langchain.callbacks.manager import AsyncCallbackManager
|
|
|
|
|
|
|
|
|
|
async for attempt in self._async_retrying():
|
|
|
|
|
with attempt:
|
|
|
|
|
result = await super().ainvoke(
|
|
|
|
|
input, self._patch_config(config, attempt.retry_state), **kwargs
|
|
|
|
|
input,
|
|
|
|
|
self._patch_config(
|
|
|
|
|
config, attempt.retry_state, AsyncCallbackManager
|
|
|
|
|
),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
|
|
|
|
attempt.retry_state.set_result(result)
|
|
|
|
@ -94,11 +121,15 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
|
|
|
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
|
|
|
|
**kwargs: Any
|
|
|
|
|
) -> List[Output]:
|
|
|
|
|
from langchain.callbacks.manager import CallbackManager
|
|
|
|
|
|
|
|
|
|
for attempt in self._sync_retrying():
|
|
|
|
|
with attempt:
|
|
|
|
|
result = super().batch(
|
|
|
|
|
inputs,
|
|
|
|
|
self._patch_config_list(config, attempt.retry_state),
|
|
|
|
|
self._patch_config_list(
|
|
|
|
|
config, attempt.retry_state, CallbackManager
|
|
|
|
|
),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
|
|
|
@ -111,11 +142,15 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
|
|
|
|
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
|
|
|
|
**kwargs: Any
|
|
|
|
|
) -> List[Output]:
|
|
|
|
|
from langchain.callbacks.manager import AsyncCallbackManager
|
|
|
|
|
|
|
|
|
|
async for attempt in self._async_retrying():
|
|
|
|
|
with attempt:
|
|
|
|
|
result = await super().abatch(
|
|
|
|
|
inputs,
|
|
|
|
|
self._patch_config_list(config, attempt.retry_state),
|
|
|
|
|
self._patch_config_list(
|
|
|
|
|
config, attempt.retry_state, AsyncCallbackManager
|
|
|
|
|
),
|
|
|
|
|
**kwargs
|
|
|
|
|
)
|
|
|
|
|
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
|
|
|
|