mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Lint
This commit is contained in:
parent
b2ac835466
commit
2242e2160f
@ -1,50 +1,51 @@
|
||||
from typing import Any, List, Optional, Union
|
||||
from langchain.schema.runnable.base import Input, Output, Runnable, RunnableBinding
|
||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
|
||||
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
|
||||
|
||||
class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
"""Retry a Runnable if it fails."""
|
||||
|
||||
retry: BaseRetrying
|
||||
|
||||
def _kwargs_retrying(self) -> Dict[str, Any]:
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if self.retry.sleep is not None:
|
||||
kwargs["sleep"] = self.retry.sleep
|
||||
if self.retry.stop is not None:
|
||||
kwargs["stop"] = self.retry.stop
|
||||
if self.retry.wait is not None:
|
||||
kwargs["wait"] = self.retry.wait
|
||||
if self.retry.retry is not None:
|
||||
kwargs["retry"] = self.retry.retry
|
||||
if self.retry.before is not None:
|
||||
kwargs["before"] = self.retry.before
|
||||
if self.retry.after is not None:
|
||||
kwargs["after"] = self.retry.after
|
||||
if self.retry.before_sleep is not None:
|
||||
kwargs["before_sleep"] = self.retry.before_sleep
|
||||
if self.retry.reraise is not None:
|
||||
kwargs["reraise"] = self.retry.reraise
|
||||
if self.retry.retry_error_cls is not None:
|
||||
kwargs["retry_error_cls"] = self.retry.retry_error_cls
|
||||
if self.retry.retry_error_callback is not None:
|
||||
kwargs["retry_error_callback"] = self.retry.retry_error_callback
|
||||
return kwargs
|
||||
|
||||
def _sync_retrying(self) -> Retrying:
|
||||
return Retrying(
|
||||
sleep=self.retry.sleep,
|
||||
stop=self.retry.stop,
|
||||
wait=self.retry.wait,
|
||||
retry=self.retry.retry,
|
||||
before=self.retry.before,
|
||||
after=self.retry.after,
|
||||
before_sleep=self.retry.before_sleep,
|
||||
reraise=self.retry.reraise,
|
||||
retry_error_cls=self.retry.retry_error_cls,
|
||||
retry_error_callback=self.retry.retry_error_callback,
|
||||
)
|
||||
return Retrying(**self._kwargs_retrying())
|
||||
|
||||
def _async_retrying(self) -> AsyncRetrying:
|
||||
return AsyncRetrying(
|
||||
sleep=self.retry.sleep,
|
||||
stop=self.retry.stop,
|
||||
wait=self.retry.wait,
|
||||
retry=self.retry.retry,
|
||||
before=self.retry.before,
|
||||
after=self.retry.after,
|
||||
before_sleep=self.retry.before_sleep,
|
||||
reraise=self.retry.reraise,
|
||||
retry_error_cls=self.retry.retry_error_cls,
|
||||
retry_error_callback=self.retry.retry_error_callback,
|
||||
)
|
||||
return AsyncRetrying(**self._kwargs_retrying())
|
||||
|
||||
def _patch_config(
|
||||
self,
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
|
||||
config: Optional[RunnableConfig],
|
||||
retry_state: RetryCallState,
|
||||
) -> RunnableConfig:
|
||||
if isinstance(config, list):
|
||||
return [self._patch_config(c, retry_state) for c in config]
|
||||
|
||||
config = config or {}
|
||||
original_tags = config.get("tags") or []
|
||||
return patch_config(
|
||||
@ -53,6 +54,16 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
+ ["retry:attempt:{}".format(retry_state.attempt_number)],
|
||||
)
|
||||
|
||||
def _patch_config_list(
|
||||
self,
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
|
||||
retry_state: RetryCallState,
|
||||
) -> Union[RunnableConfig, List[RunnableConfig]]:
|
||||
if isinstance(config, list):
|
||||
return [self._patch_config(c, retry_state) for c in config]
|
||||
|
||||
return self._patch_config(config, retry_state)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Input,
|
||||
@ -64,7 +75,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
result = super().invoke(
|
||||
input, self._patch_config(config, attempt.retry_state), **kwargs
|
||||
)
|
||||
if not attempt.retry_state.outcome.failed:
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
@ -76,7 +87,7 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
result = await super().ainvoke(
|
||||
input, self._patch_config(config, attempt.retry_state), **kwargs
|
||||
)
|
||||
if not attempt.retry_state.outcome.failed:
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
@ -89,9 +100,11 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
result = super().batch(
|
||||
inputs, self._patch_config(config, attempt.retry_state), **kwargs
|
||||
inputs,
|
||||
self._patch_config_list(config, attempt.retry_state),
|
||||
**kwargs
|
||||
)
|
||||
if not attempt.retry_state.outcome.failed:
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
@ -104,8 +117,13 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
result = await super().abatch(
|
||||
inputs, self._patch_config(config, attempt.retry_state), **kwargs
|
||||
inputs,
|
||||
self._patch_config_list(config, attempt.retry_state),
|
||||
**kwargs
|
||||
)
|
||||
if not attempt.retry_state.outcome.failed:
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
# stream() and transform() are not retried because retrying a stream
|
||||
# is not very intuitive.
|
||||
|
@ -6,6 +6,7 @@ import pytest
|
||||
from freezegun import freeze_time
|
||||
from pytest_mock import MockerFixture
|
||||
from syrupy import SnapshotAssertion
|
||||
from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
|
Loading…
Reference in New Issue
Block a user