mirror of https://github.com/hwchase17/langchain
Re-implement retry, adding a root run, and implement return_exceptions for batch() and abatch()
parent
0eba80912f
commit
4c0e1e501c
@ -1,161 +1,250 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast
|
||||
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
RetryCallState,
|
||||
RetryError,
|
||||
Retrying,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager as AsyncCallbackManagerT,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManager as CallbackManagerT
|
||||
|
||||
T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT)
|
||||
else:
|
||||
T = TypeVar("T")
|
||||
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
|
||||
U = TypeVar("U")
|
||||
|
||||
|
||||
class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
"""Retry a Runnable if it fails."""
|
||||
|
||||
retry: BaseRetrying
|
||||
retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,)
|
||||
|
||||
wait_exponential_jitter: bool = True
|
||||
|
||||
stop_after_attempt: int = 3
|
||||
|
||||
def _kwargs_retrying(self) -> Dict[str, Any]:
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if self.retry.sleep is not None:
|
||||
kwargs["sleep"] = self.retry.sleep
|
||||
if self.retry.stop is not None:
|
||||
kwargs["stop"] = self.retry.stop
|
||||
if self.retry.wait is not None:
|
||||
kwargs["wait"] = self.retry.wait
|
||||
if self.retry.retry is not None:
|
||||
kwargs["retry"] = self.retry.retry
|
||||
if self.retry.before is not None:
|
||||
kwargs["before"] = self.retry.before
|
||||
if self.retry.after is not None:
|
||||
kwargs["after"] = self.retry.after
|
||||
if self.retry.before_sleep is not None:
|
||||
kwargs["before_sleep"] = self.retry.before_sleep
|
||||
if self.retry.reraise is not None:
|
||||
kwargs["reraise"] = self.retry.reraise
|
||||
if self.retry.retry_error_cls is not None:
|
||||
kwargs["retry_error_cls"] = self.retry.retry_error_cls
|
||||
if self.retry.retry_error_callback is not None:
|
||||
kwargs["retry_error_callback"] = self.retry.retry_error_callback
|
||||
kwargs: Dict[str, Any] = dict()
|
||||
|
||||
if self.stop_after_attempt:
|
||||
kwargs["stop"] = stop_after_attempt(self.stop_after_attempt)
|
||||
|
||||
if self.wait_exponential_jitter:
|
||||
kwargs["wait"] = wait_exponential_jitter()
|
||||
|
||||
if self.retry_if_exception_type:
|
||||
kwargs["retry"] = retry_if_exception_type(self.retry_if_exception_type)
|
||||
|
||||
return kwargs
|
||||
|
||||
def _sync_retrying(self) -> Retrying:
|
||||
return Retrying(**self._kwargs_retrying())
|
||||
def _sync_retrying(self, **kwargs: Any) -> Retrying:
|
||||
return Retrying(**self._kwargs_retrying(), **kwargs)
|
||||
|
||||
def _async_retrying(self) -> AsyncRetrying:
|
||||
return AsyncRetrying(**self._kwargs_retrying())
|
||||
def _async_retrying(self, **kwargs: Any) -> AsyncRetrying:
|
||||
return AsyncRetrying(**self._kwargs_retrying(), **kwargs)
|
||||
|
||||
def _patch_config(
|
||||
self,
|
||||
config: Optional[RunnableConfig],
|
||||
config: RunnableConfig,
|
||||
run_manager: T,
|
||||
retry_state: RetryCallState,
|
||||
cm_cls: Type[T],
|
||||
) -> RunnableConfig:
|
||||
config = config or {}
|
||||
return (
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=cm_cls.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_tags=["retry:attempt:{}".format(retry_state.attempt_number)],
|
||||
),
|
||||
)
|
||||
if retry_state.attempt_number > 1
|
||||
else config
|
||||
return patch_config(
|
||||
config,
|
||||
callbacks=run_manager.get_child(
|
||||
"retry:attempt:{}".format(retry_state.attempt_number)
|
||||
if retry_state.attempt_number > 1
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
def _patch_config_list(
|
||||
self,
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
|
||||
config: List[RunnableConfig],
|
||||
run_manager: List[T],
|
||||
retry_state: RetryCallState,
|
||||
cm_cls: Type[T],
|
||||
) -> Union[RunnableConfig, List[RunnableConfig]]:
|
||||
if isinstance(config, list):
|
||||
return [self._patch_config(c, retry_state, cm_cls) for c in config]
|
||||
|
||||
return self._patch_config(config, retry_state, cm_cls)
|
||||
) -> List[RunnableConfig]:
|
||||
return [
|
||||
self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
|
||||
]
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
def _invoke(
|
||||
self,
|
||||
input: Input,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
for attempt in self._sync_retrying():
|
||||
for attempt in self._sync_retrying(reraise=True):
|
||||
with attempt:
|
||||
result = super().invoke(
|
||||
input,
|
||||
self._patch_config(config, attempt.retry_state, CallbackManager),
|
||||
**kwargs
|
||||
self._patch_config(config, run_manager, attempt.retry_state),
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
async def ainvoke(
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
return self._call_with_config(self._invoke, input, config, **kwargs)
|
||||
|
||||
async for attempt in self._async_retrying():
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
async for attempt in self._async_retrying(reraise=True):
|
||||
with attempt:
|
||||
result = await super().ainvoke(
|
||||
input,
|
||||
self._patch_config(
|
||||
config, attempt.retry_state, AsyncCallbackManager
|
||||
),
|
||||
**kwargs
|
||||
self._patch_config(config, run_manager, attempt.retry_state),
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
|
||||
|
||||
def _batch(
    self,
    inputs: List[Input],
    run_manager: List[CallbackManagerForChainRun],
    config: List[RunnableConfig],
) -> List[Union[Output, Exception]]:
    """Run the bound runnable over ``inputs``, retrying only the failed ones.

    Each attempt re-submits just the inputs that have not succeeded yet.
    Successes and failures are tracked by their position in ``inputs``,
    so the returned list always lines up with ``inputs`` regardless of
    which items needed retries.

    Args:
        inputs: The inputs to process.
        run_manager: One run manager per input, aligned with ``inputs``.
        config: One config per input, aligned with ``inputs``.

    Returns:
        A list the same length as ``inputs``. Each slot holds the output
        for that input, or, once retries are exhausted, the last exception
        recorded for it (the ``RetryError`` itself if no per-input
        exception was ever recorded).

    Raises:
        Any exception the retry policy re-raises instead of wrapping in
        ``RetryError``; only ``RetryError`` is handled here.
    """
    # Successful outputs, keyed by index into ``inputs``.
    results_map: Dict[int, Output] = {}
    # Latest exception for each input that is still failing, same keying.
    failures: Dict[int, Exception] = {}
    retry_error: Optional[RetryError] = None

    try:
        for attempt in self._sync_retrying():
            with attempt:
                # Original indices of the inputs that have not succeeded yet.
                remaining = [
                    idx for idx in range(len(inputs)) if idx not in results_map
                ]
                result = super().batch(
                    [inputs[idx] for idx in remaining],
                    self._patch_config_list(
                        [config[idx] for idx in remaining],
                        [run_manager[idx] for idx in remaining],
                        attempt.retry_state,
                    ),
                    return_exceptions=True,
                )
                # ``result`` is positional w.r.t. ``remaining``; map every
                # item back to its index in ``inputs`` before recording it.
                first_exception = None
                for idx, item in zip(remaining, result):
                    if isinstance(item, Exception):
                        failures[idx] = item
                        if first_exception is None:
                            first_exception = item
                        continue
                    results_map[idx] = item
                    failures.pop(idx, None)
                # Raising inside the attempt is what tells the retry loop
                # that another attempt is needed for the failed inputs.
                if first_exception is not None:
                    raise first_exception
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
    except RetryError as e:
        # Retries exhausted: fall through and report per-input outcomes.
        retry_error = e

    outputs: List[Union[Output, Exception]] = []
    for idx in range(len(inputs)):
        if idx in results_map:
            outputs.append(results_map[idx])
        elif idx in failures:
            outputs.append(failures[idx])
        else:
            # No per-input exception was recorded (the batch call itself
            # kept raising), so surface the RetryError for this input.
            outputs.append(cast(Exception, retry_error))
    return outputs
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
return self._batch_with_config(
|
||||
self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
result = super().batch(
|
||||
inputs,
|
||||
self._patch_config_list(
|
||||
config, attempt.retry_state, CallbackManager
|
||||
),
|
||||
**kwargs
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
async def _abatch(
    self,
    inputs: List[Input],
    run_manager: List[AsyncCallbackManagerForChainRun],
    config: List[RunnableConfig],
) -> List[Union[Output, Exception]]:
    """Asynchronously run the bound runnable, retrying only failed inputs.

    Each attempt re-submits just the inputs that have not succeeded yet.
    Successes and failures are tracked by their position in ``inputs``,
    so the returned list always lines up with ``inputs`` regardless of
    which items needed retries.

    Args:
        inputs: The inputs to process.
        run_manager: One async run manager per input, aligned with
            ``inputs``.
        config: One config per input, aligned with ``inputs``.

    Returns:
        A list the same length as ``inputs``. Each slot holds the output
        for that input, or, once retries are exhausted, the last exception
        recorded for it (the ``RetryError`` itself if no per-input
        exception was ever recorded).

    Raises:
        Any exception the retry policy re-raises instead of wrapping in
        ``RetryError``; only ``RetryError`` is handled here.
    """
    # Successful outputs, keyed by index into ``inputs``.
    results_map: Dict[int, Output] = {}
    # Latest exception for each input that is still failing, same keying.
    failures: Dict[int, Exception] = {}
    retry_error: Optional[RetryError] = None

    try:
        async for attempt in self._async_retrying():
            with attempt:
                # Original indices of the inputs that have not succeeded yet.
                remaining = [
                    idx for idx in range(len(inputs)) if idx not in results_map
                ]
                result = await super().abatch(
                    [inputs[idx] for idx in remaining],
                    self._patch_config_list(
                        [config[idx] for idx in remaining],
                        [run_manager[idx] for idx in remaining],
                        attempt.retry_state,
                    ),
                    return_exceptions=True,
                )
                # ``result`` is positional w.r.t. ``remaining``; map every
                # item back to its index in ``inputs`` before recording it.
                first_exception = None
                for idx, item in zip(remaining, result):
                    if isinstance(item, Exception):
                        failures[idx] = item
                        if first_exception is None:
                            first_exception = item
                        continue
                    results_map[idx] = item
                    failures.pop(idx, None)
                # Raising inside the attempt is what tells the retry loop
                # that another attempt is needed for the failed inputs.
                if first_exception is not None:
                    raise first_exception
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
    except RetryError as e:
        # Retries exhausted: fall through and report per-input outcomes.
        retry_error = e

    outputs: List[Union[Output, Exception]] = []
    for idx in range(len(inputs)):
        if idx in results_map:
            outputs.append(results_map[idx])
        elif idx in failures:
            outputs.append(failures[idx])
        else:
            # No per-input exception was recorded (the batch call itself
            # kept raising), so surface the RetryError for this input.
            outputs.append(cast(Exception, retry_error))
    return outputs
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
result = await super().abatch(
|
||||
inputs,
|
||||
self._patch_config_list(
|
||||
config, attempt.retry_state, AsyncCallbackManager
|
||||
),
|
||||
**kwargs
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
return await self._abatch_with_config(
|
||||
self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
|
||||
# stream() and transform() are not retried because retrying a stream
|
||||
# is not very intuitive.
|
||||
|
Loading…
Reference in New Issue