mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Use a non-inheritable tag
This commit is contained in:
parent
85088dc5df
commit
af2e4ce2cd
@ -1555,6 +1555,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
config={**self.config, **(config or {}), **kwargs},
|
||||
)
|
||||
|
||||
def with_retry(self, retry: BaseRetrying) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound.with_retry(retry),
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: Input,
|
||||
|
@ -98,7 +98,6 @@ def patch_config(
|
||||
recursion_limit: Optional[int] = None,
|
||||
max_concurrency: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
) -> RunnableConfig:
|
||||
config = ensure_config(config)
|
||||
if deep_copy_locals:
|
||||
@ -115,8 +114,6 @@ def patch_config(
|
||||
config["max_concurrency"] = max_concurrency
|
||||
if run_name is not None:
|
||||
config["run_name"] = run_name
|
||||
if tags is not None:
|
||||
config["tags"] = tags
|
||||
return config
|
||||
|
||||
|
||||
|
@ -1,10 +1,20 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
|
||||
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager as AsyncCallbackManagerT,
|
||||
CallbackManager as CallbackManagerT,
|
||||
)
|
||||
|
||||
T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT)
|
||||
else:
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
"""Retry a Runnable if it fails."""
|
||||
@ -45,32 +55,43 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
self,
|
||||
config: Optional[RunnableConfig],
|
||||
retry_state: RetryCallState,
|
||||
cm_cls: Type[T],
|
||||
) -> RunnableConfig:
|
||||
config = config or {}
|
||||
original_tags = config.get("tags") or []
|
||||
return patch_config(
|
||||
config,
|
||||
tags=original_tags
|
||||
+ ["retry:attempt:{}".format(retry_state.attempt_number)],
|
||||
return (
|
||||
patch_config(
|
||||
config,
|
||||
callbacks=cm_cls.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_tags=["retry:attempt:{}".format(retry_state.attempt_number)],
|
||||
),
|
||||
)
|
||||
if retry_state.attempt_number > 1
|
||||
else config
|
||||
)
|
||||
|
||||
def _patch_config_list(
|
||||
self,
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
|
||||
retry_state: RetryCallState,
|
||||
cm_cls: Type[T],
|
||||
) -> Union[RunnableConfig, List[RunnableConfig]]:
|
||||
if isinstance(config, list):
|
||||
return [self._patch_config(c, retry_state) for c in config]
|
||||
return [self._patch_config(c, retry_state, cm_cls) for c in config]
|
||||
|
||||
return self._patch_config(config, retry_state)
|
||||
return self._patch_config(config, retry_state, cm_cls)
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
result = super().invoke(
|
||||
input, self._patch_config(config, attempt.retry_state), **kwargs
|
||||
input,
|
||||
self._patch_config(config, attempt.retry_state, CallbackManager),
|
||||
**kwargs
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
@ -79,10 +100,16 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
result = await super().ainvoke(
|
||||
input, self._patch_config(config, attempt.retry_state), **kwargs
|
||||
input,
|
||||
self._patch_config(
|
||||
config, attempt.retry_state, AsyncCallbackManager
|
||||
),
|
||||
**kwargs
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
@ -94,11 +121,15 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
**kwargs: Any
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
result = super().batch(
|
||||
inputs,
|
||||
self._patch_config_list(config, attempt.retry_state),
|
||||
self._patch_config_list(
|
||||
config, attempt.retry_state, CallbackManager
|
||||
),
|
||||
**kwargs
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
@ -111,11 +142,15 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
**kwargs: Any
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
result = await super().abatch(
|
||||
inputs,
|
||||
self._patch_config_list(config, attempt.retry_state),
|
||||
self._patch_config_list(
|
||||
config, attempt.retry_state, AsyncCallbackManager
|
||||
),
|
||||
**kwargs
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
|
Loading…
Reference in New Issue
Block a user