Use a non-inheritable tag

This commit is contained in:
Nuno Campos 2023-08-25 17:14:40 +02:00
parent 85088dc5df
commit af2e4ce2cd
3 changed files with 54 additions and 15 deletions

View File

@ -1555,6 +1555,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config={**self.config, **(config or {}), **kwargs},
)
def with_retry(self, retry: BaseRetrying) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound.with_retry(retry),
kwargs=self.kwargs,
config=self.config,
)
def invoke(
self,
input: Input,

View File

@ -98,7 +98,6 @@ def patch_config(
recursion_limit: Optional[int] = None,
max_concurrency: Optional[int] = None,
run_name: Optional[str] = None,
tags: Optional[List[str]] = None,
) -> RunnableConfig:
config = ensure_config(config)
if deep_copy_locals:
@ -115,8 +114,6 @@ def patch_config(
config["max_concurrency"] = max_concurrency
if run_name is not None:
config["run_name"] = run_name
if tags is not None:
config["tags"] = tags
return config

View File

@ -1,10 +1,20 @@
from typing import Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.config import RunnableConfig, patch_config
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManager as AsyncCallbackManagerT,
CallbackManager as CallbackManagerT,
)
T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT)
else:
T = TypeVar("T")
class RunnableRetry(RunnableBinding[Input, Output]):
"""Retry a Runnable if it fails."""
@ -45,32 +55,43 @@ class RunnableRetry(RunnableBinding[Input, Output]):
self,
config: Optional[RunnableConfig],
retry_state: RetryCallState,
cm_cls: Type[T],
) -> RunnableConfig:
config = config or {}
original_tags = config.get("tags") or []
return patch_config(
return (
patch_config(
config,
tags=original_tags
+ ["retry:attempt:{}".format(retry_state.attempt_number)],
callbacks=cm_cls.configure(
inheritable_callbacks=config.get("callbacks"),
local_tags=["retry:attempt:{}".format(retry_state.attempt_number)],
),
)
if retry_state.attempt_number > 1
else config
)
def _patch_config_list(
self,
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
retry_state: RetryCallState,
cm_cls: Type[T],
) -> Union[RunnableConfig, List[RunnableConfig]]:
if isinstance(config, list):
return [self._patch_config(c, retry_state) for c in config]
return [self._patch_config(c, retry_state, cm_cls) for c in config]
return self._patch_config(config, retry_state)
return self._patch_config(config, retry_state, cm_cls)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
from langchain.callbacks.manager import CallbackManager
for attempt in self._sync_retrying():
with attempt:
result = super().invoke(
input, self._patch_config(config, attempt.retry_state), **kwargs
input,
self._patch_config(config, attempt.retry_state, CallbackManager),
**kwargs
)
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
@ -79,10 +100,16 @@ class RunnableRetry(RunnableBinding[Input, Output]):
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
async for attempt in self._async_retrying():
with attempt:
result = await super().ainvoke(
input, self._patch_config(config, attempt.retry_state), **kwargs
input,
self._patch_config(
config, attempt.retry_state, AsyncCallbackManager
),
**kwargs
)
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
@ -94,11 +121,15 @@ class RunnableRetry(RunnableBinding[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Any
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
for attempt in self._sync_retrying():
with attempt:
result = super().batch(
inputs,
self._patch_config_list(config, attempt.retry_state),
self._patch_config_list(
config, attempt.retry_state, CallbackManager
),
**kwargs
)
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
@ -111,11 +142,15 @@ class RunnableRetry(RunnableBinding[Input, Output]):
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
**kwargs: Any
) -> List[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
async for attempt in self._async_retrying():
with attempt:
result = await super().abatch(
inputs,
self._patch_config_list(config, attempt.retry_state),
self._patch_config_list(
config, attempt.retry_state, AsyncCallbackManager
),
**kwargs
)
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: