Nc/runnables retry (#9711)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live in the docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->
This commit is contained in:
Nuno Campos 2023-09-01 15:52:20 +01:00 committed by GitHub
commit b1c87da2b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 1044 additions and 67 deletions

View File

@ -263,20 +263,28 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
max_concurrency: Optional[int] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[str]:
config = get_config_list(config, len(inputs))
max_concurrency = config[0].get("max_concurrency")
if max_concurrency is None:
llm_result = self.generate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
try:
llm_result = self.generate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
except Exception as e:
if return_exceptions:
return cast(List[str], [e for _ in inputs])
else:
raise e
else:
batches = [
inputs[i : i + max_concurrency]
@ -285,33 +293,43 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return [
output
for batch in batches
for output in self.batch(batch, config=config, **kwargs)
for output in self.batch(
batch, config=config, return_exceptions=return_exceptions, **kwargs
)
]
async def abatch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
max_concurrency: Optional[int] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[str]:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async batch, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, self.batch, inputs, config, max_concurrency
None, partial(self.batch, **kwargs), inputs, config
)
config = get_config_list(config, len(inputs))
max_concurrency = config[0].get("max_concurrency")
if max_concurrency is None:
llm_result = await self.agenerate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
try:
llm_result = await self.agenerate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
except Exception as e:
if return_exceptions:
return cast(List[str], [e for _ in inputs])
else:
raise e
else:
batches = [
inputs[i : i + max_concurrency]

View File

@ -105,6 +105,8 @@ class Runnable(Generic[Input, Output], ABC):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
"""
@ -113,17 +115,28 @@ class Runnable(Generic[Input, Output], ABC):
"""
configs = get_config_list(config, len(inputs))
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
if return_exceptions:
try:
return self.invoke(input, config, **kwargs)
except Exception as e:
return e
else:
return self.invoke(input, config, **kwargs)
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
return [self.invoke(inputs[0], configs[0], **kwargs)]
return cast(List[Output], [invoke(inputs[0], configs[0])])
with get_executor_for_config(configs[0]) as executor:
return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
return cast(List[Output], list(executor.map(invoke, inputs, configs)))
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
"""
@ -131,8 +144,19 @@ class Runnable(Generic[Input, Output], ABC):
Subclasses should override this method if they can batch more efficiently.
"""
configs = get_config_list(config, len(inputs))
coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
async def ainvoke(
input: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return await self.ainvoke(input, config, **kwargs)
except Exception as e:
return e
else:
return await self.ainvoke(input, config, **kwargs)
coros = map(ainvoke, inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
def stream(
@ -226,6 +250,24 @@ class Runnable(Generic[Input, Output], ABC):
bound=self, config={**(config or {}), **kwargs}, kwargs={}
)
def with_retry(
    self,
    *,
    retry_if_exception_type: Tuple[Type[BaseException], ...] = (Exception,),
    wait_exponential_jitter: bool = True,
    stop_after_attempt: int = 3,
) -> Runnable[Input, Output]:
    """Return a new Runnable that wraps this one in a ``RunnableRetry``.

    Args:
        retry_if_exception_type: Exception classes forwarded to the wrapper
            as ``retry_exception_types``. Annotated as a variable-length
            tuple so that more than one class can be declared.
        wait_exponential_jitter: Forwarded unchanged to the wrapper's
            ``wait_exponential_jitter`` setting.
        stop_after_attempt: Forwarded to the wrapper as
            ``max_attempt_number``.

    Returns:
        A ``RunnableRetry`` bound to ``self`` with empty ``kwargs`` and
        ``config``.
    """
    # Imported at call time rather than at module level.
    # NOTE(review): presumably this avoids a circular import, since the
    # retry module imports RunnableBinding from this module - confirm.
    from langchain.schema.runnable.retry import RunnableRetry

    return RunnableRetry(
        bound=self,
        kwargs={},
        config={},
        retry_exception_types=retry_if_exception_type,
        wait_exponential_jitter=wait_exponential_jitter,
        max_attempt_number=stop_after_attempt,
    )
def map(self) -> Runnable[List[Input], List[Output]]:
"""
Return a new Runnable that maps a list of inputs to a list of outputs,
@ -331,6 +373,145 @@ class Runnable(Generic[Input, Output], ABC):
await run_manager.on_chain_end(dumpd(output))
return output
def _batch_with_config(
    self,
    func: Union[
        Callable[[List[Input]], List[Union[Exception, Output]]],
        Callable[
            [List[Input], List[CallbackManagerForChainRun]],
            List[Union[Exception, Output]],
        ],
        Callable[
            [List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]],
            List[Union[Exception, Output]],
        ],
    ],
    input: List[Input],
    config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
    *,
    return_exceptions: bool = False,
    run_type: Optional[str] = None,
) -> List[Output]:
    """Helper method to transform a list of inputs into a list of outputs,
    with one chain run (and its callbacks) per input. Use this method to
    implement batch() in subclasses.

    ``func`` receives the whole input list, plus the list of run managers
    and the list of configs when its signature accepts them. It may return
    Exception instances in place of outputs; each such entry is reported
    through ``on_chain_error`` on the matching run.

    With ``return_exceptions=False`` the first exception found in the
    output (or any exception raised by ``func`` itself) is raised. With
    ``return_exceptions=True`` exceptions are returned in the list; if
    ``func`` itself raised, every position holds that same exception.
    """
    configs = get_config_list(config, len(input))
    callback_managers = [get_callback_manager_for_config(c) for c in configs]

    # Open one run per input before invoking func.
    run_managers = []
    for manager, item, item_config in zip(callback_managers, input, configs):
        run_managers.append(
            manager.on_chain_start(
                dumpd(self),
                item,
                run_type=run_type,
                name=item_config.get("run_name"),
            )
        )

    try:
        if accepts_run_manager_and_config(func):
            output = func(
                input,
                run_manager=run_managers,
                config=configs,
            )  # type: ignore[call-arg]
        elif accepts_run_manager(func):
            output = func(input, run_manager=run_managers)  # type: ignore[call-arg]
        else:
            output = func(input)  # type: ignore[call-arg]
    except Exception as e:
        # func failed as a whole: every run is closed with the same error.
        for run_manager in run_managers:
            run_manager.on_chain_error(e)
        if not return_exceptions:
            raise
        return cast(List[Output], [e] * len(input))

    # func returned: close each run according to its own outcome.
    first_exception: Optional[Exception] = None
    for run_manager, out in zip(run_managers, output):
        if not isinstance(out, Exception):
            run_manager.on_chain_end(dumpd(out))
            continue
        first_exception = first_exception or out
        run_manager.on_chain_error(out)

    if first_exception is not None and not return_exceptions:
        raise first_exception
    return cast(List[Output], output)
async def _abatch_with_config(
    self,
    func: Union[
        Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]],
        Callable[
            [List[Input], List[AsyncCallbackManagerForChainRun]],
            Awaitable[List[Union[Exception, Output]]],
        ],
        Callable[
            [
                List[Input],
                List[AsyncCallbackManagerForChainRun],
                List[RunnableConfig],
            ],
            Awaitable[List[Union[Exception, Output]]],
        ],
    ],
    input: List[Input],
    config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
    *,
    return_exceptions: bool = False,
    run_type: Optional[str] = None,
) -> List[Output]:
    """Async helper method to transform a list of inputs into a list of
    outputs, with one chain run (and its callbacks) per input. Use this
    method to implement abatch() in subclasses.

    ``func`` is awaited with the whole input list, plus the list of run
    managers and the list of configs when its signature accepts them. It
    may return Exception instances in place of outputs; each such entry is
    reported through ``on_chain_error`` on the matching run.

    With ``return_exceptions=False`` the first exception found in the
    output (or any exception raised by ``func`` itself) is raised. With
    ``return_exceptions=True`` exceptions are returned in the list; if
    ``func`` itself raised, every position holds that same exception.
    """
    configs = get_config_list(config, len(input))
    callback_managers = [get_async_callback_manager_for_config(c) for c in configs]

    # Open one run per input; the start callbacks are awaited together.
    start_coros = [
        manager.on_chain_start(
            dumpd(self),
            item,
            run_type=run_type,
            name=item_config.get("run_name"),
        )
        for manager, item, item_config in zip(callback_managers, input, configs)
    ]
    run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
        *start_coros
    )

    try:
        if accepts_run_manager_and_config(func):
            output = await func(
                input,
                run_manager=run_managers,
                config=configs,
            )  # type: ignore[call-arg]
        elif accepts_run_manager(func):
            output = await func(input, run_manager=run_managers)  # type: ignore
        else:
            output = await func(input)  # type: ignore[call-arg]
    except Exception as e:
        # func failed as a whole: every run is closed with the same error.
        error_coros = [run_manager.on_chain_error(e) for run_manager in run_managers]
        await asyncio.gather(*error_coros)
        if not return_exceptions:
            raise
        return cast(List[Output], [e] * len(input))

    # func returned: close each run according to its own outcome.
    first_exception: Optional[Exception] = None
    coros: List[Awaitable[None]] = []
    for run_manager, out in zip(run_managers, output):
        if not isinstance(out, Exception):
            coros.append(run_manager.on_chain_end(dumpd(out)))
            continue
        first_exception = first_exception or out
        coros.append(run_manager.on_chain_error(out))
    await asyncio.gather(*coros)

    if first_exception is not None and not return_exceptions:
        raise first_exception
    return cast(List[Output], output)
def _transform_stream_with_config(
self,
input: Iterator[Input],
@ -586,10 +767,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
if return_exceptions:
raise NotImplementedError()
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
@ -646,10 +832,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
if return_exceptions:
raise NotImplementedError()
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
@ -831,6 +1022,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
@ -861,29 +1054,88 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke
try:
for step in self.steps:
inputs = step.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
)
if return_exceptions:
# Track which inputs (by index) failed so far
# If an input has failed it will be present in this map,
# and the value will be the exception that was raised.
failed_inputs_map: Dict[int, Exception] = {}
for step in self.steps:
# Assemble the original indexes of the remaining inputs
# (i.e. the ones that haven't failed yet)
remaining_idxs = [
i for i in range(len(configs)) if i not in failed_inputs_map
]
# Invoke the step on the remaining inputs
inputs = step.batch(
[
inp
for i, inp in zip(remaining_idxs, inputs)
if i not in failed_inputs_map
],
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for i, (rm, config) in enumerate(zip(run_managers, configs))
if i not in failed_inputs_map
],
return_exceptions=return_exceptions,
**kwargs,
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
if isinstance(inp, Exception):
failed_inputs_map[i] = inp
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
# If all inputs have failed, stop processing
if len(failed_inputs_map) == len(configs):
break
# Reassemble the outputs, inserting Exceptions for failed inputs
inputs_copy = inputs.copy()
inputs = []
for i in range(len(configs)):
if i in failed_inputs_map:
inputs.append(cast(Input, failed_inputs_map[i]))
else:
inputs.append(inputs_copy.pop(0))
else:
for step in self.steps:
inputs = step.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
for rm in run_managers:
rm.on_chain_error(e)
raise
if return_exceptions:
return cast(List[Output], [e for _ in inputs])
else:
raise
else:
for rm, input in zip(run_managers, inputs):
rm.on_chain_end(input)
return cast(List[Output], inputs)
first_exception: Optional[Exception] = None
for run_manager, out in zip(run_managers, inputs):
if isinstance(out, Exception):
first_exception = first_exception or out
run_manager.on_chain_error(out)
else:
run_manager.on_chain_end(dumpd(out))
if return_exceptions or first_exception is None:
return cast(List[Output], inputs)
else:
raise first_exception
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import (
@ -919,24 +1171,81 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke .batch() on each step
# this uses batching optimizations in Runnable subclasses, like LLM
try:
for step in self.steps:
inputs = await step.abatch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
)
if return_exceptions:
# Track which inputs (by index) failed so far
# If an input has failed it will be present in this map,
# and the value will be the exception that was raised.
failed_inputs_map: Dict[int, Exception] = {}
for step in self.steps:
# Assemble the original indexes of the remaining inputs
# (i.e. the ones that haven't failed yet)
remaining_idxs = [
i for i in range(len(configs)) if i not in failed_inputs_map
]
# Invoke the step on the remaining inputs
inputs = await step.abatch(
[
inp
for i, inp in zip(remaining_idxs, inputs)
if i not in failed_inputs_map
],
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for i, (rm, config) in enumerate(zip(run_managers, configs))
if i not in failed_inputs_map
],
return_exceptions=return_exceptions,
**kwargs,
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
if isinstance(inp, Exception):
failed_inputs_map[i] = inp
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
# If all inputs have failed, stop processing
if len(failed_inputs_map) == len(configs):
break
# Reassemble the outputs, inserting Exceptions for failed inputs
inputs_copy = inputs.copy()
inputs = []
for i in range(len(configs)):
if i in failed_inputs_map:
inputs.append(cast(Input, failed_inputs_map[i]))
else:
inputs.append(inputs_copy.pop(0))
else:
for step in self.steps:
inputs = await step.abatch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
raise
if return_exceptions:
return cast(List[Output], [e for _ in inputs])
else:
raise
else:
await asyncio.gather(
*(rm.on_chain_end(input) for rm, input in zip(run_managers, inputs))
)
return cast(List[Output], inputs)
first_exception: Optional[Exception] = None
coros: List[Awaitable[None]] = []
for run_manager, out in zip(run_managers, inputs):
if isinstance(out, Exception):
first_exception = first_exception or out
coros.append(run_manager.on_chain_error(out))
else:
coros.append(run_manager.on_chain_end(dumpd(out)))
await asyncio.gather(*coros)
if return_exceptions or first_exception is None:
return cast(List[Output], inputs)
else:
raise first_exception
def stream(
self,
@ -1545,6 +1854,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config={**self.config, **(config or {}), **kwargs},
)
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
    """Return a copy of this binding whose bound runnable retries.

    The retry wrapper is applied to ``self.bound`` (all keyword arguments
    are forwarded to its ``with_retry``); this binding's own ``kwargs``
    and ``config`` are carried over to the new instance unchanged.
    """
    retrying_bound = self.bound.with_retry(**kwargs)
    binding_cls = self.__class__
    return binding_cls(
        bound=retrying_bound,
        kwargs=self.kwargs,
        config=self.config,
    )
def invoke(
self,
input: Input,
@ -1573,6 +1889,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
if isinstance(config, list):
@ -1584,12 +1902,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
patch_config(self._merge_config(config), deep_copy_locals=True)
for _ in range(len(inputs))
]
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
return self.bound.batch(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
)
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
if isinstance(config, list):
@ -1601,7 +1926,12 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
patch_config(self._merge_config(config), deep_copy_locals=True)
for _ in range(len(inputs))
]
return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs})
return await self.bound.abatch(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
)
def stream(
self,

View File

@ -0,0 +1,245 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast
from tenacity import (
AsyncRetrying,
RetryCallState,
RetryError,
Retrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.config import RunnableConfig, patch_config
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
U = TypeVar("U")
class RunnableRetry(RunnableBinding[Input, Output]):
    """Retry a Runnable if it fails.

    ``invoke``/``ainvoke`` re-run the bound runnable through tenacity.
    ``batch``/``abatch`` re-submit only the inputs that have not produced a
    successful result yet, and always report every outcome at the position
    of the input it belongs to.
    """

    # Exception classes handed to tenacity's retry_if_exception_type.
    # Variable-length so that several classes can be configured.
    retry_exception_types: Tuple[Type[BaseException], ...] = (Exception,)
    # When true, tenacity's wait_exponential_jitter() is used between attempts.
    wait_exponential_jitter: bool = True
    # Handed to tenacity's stop_after_attempt; a falsy value installs no
    # stop condition.
    max_attempt_number: int = 3

    @property
    def _kwargs_retrying(self) -> Dict[str, Any]:
        """Build the tenacity keyword arguments from this instance's settings."""
        kwargs: Dict[str, Any] = dict()

        if self.max_attempt_number:
            kwargs["stop"] = stop_after_attempt(self.max_attempt_number)

        if self.wait_exponential_jitter:
            kwargs["wait"] = wait_exponential_jitter()

        if self.retry_exception_types:
            kwargs["retry"] = retry_if_exception_type(self.retry_exception_types)

        return kwargs

    def _sync_retrying(self, **kwargs: Any) -> Retrying:
        """Create a fresh tenacity ``Retrying`` controller for one call."""
        return Retrying(**self._kwargs_retrying, **kwargs)

    def _async_retrying(self, **kwargs: Any) -> AsyncRetrying:
        """Create a fresh tenacity ``AsyncRetrying`` controller for one call."""
        return AsyncRetrying(**self._kwargs_retrying, **kwargs)

    def _patch_config(
        self,
        config: RunnableConfig,
        run_manager: T,
        retry_state: RetryCallState,
    ) -> RunnableConfig:
        """Return ``config`` with child callbacks for the current attempt.

        Attempts after the first are tagged ``retry:attempt:<n>``.
        """
        attempt = retry_state.attempt_number
        tag = "retry:attempt:{}".format(attempt) if attempt > 1 else None
        return patch_config(config, callbacks=run_manager.get_child(tag))

    def _patch_config_list(
        self,
        config: List[RunnableConfig],
        run_manager: List[T],
        retry_state: RetryCallState,
    ) -> List[RunnableConfig]:
        """Apply ``_patch_config`` pairwise to configs and run managers."""
        return [
            self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
        ]

    def _record_outcomes(
        self,
        idxs: List[int],
        outcomes: List[Any],
        results_map: Dict[int, Output],
        errors_map: Dict[int, Exception],
    ) -> Optional[Exception]:
        """Store one attempt's outcomes under their ORIGINAL input indexes.

        ``idxs[k]`` is the original position of the input that produced
        ``outcomes[k]``. Successes go to ``results_map``, exceptions to
        ``errors_map`` (replacing the error of an earlier attempt).

        Returns:
            The first exception found in ``outcomes``, or None.
        """
        first_exception: Optional[Exception] = None
        for idx, outcome in zip(idxs, outcomes):
            if isinstance(outcome, Exception):
                errors_map[idx] = outcome
                if first_exception is None:
                    first_exception = outcome
            else:
                results_map[idx] = outcome
        return first_exception

    def _collect_outputs(
        self,
        total: int,
        results_map: Dict[int, Output],
        errors_map: Dict[int, Exception],
        fallback: Optional[Exception],
    ) -> List[Union[Output, Exception]]:
        """Assemble the final per-input list in original input order.

        A success always wins over an error recorded on an earlier attempt.
        ``fallback`` is used for inputs that never got a per-input outcome,
        i.e. when the batch call itself raised on every attempt.
        """
        outputs: List[Union[Output, Exception]] = []
        for idx in range(total):
            if idx in results_map:
                outputs.append(results_map[idx])
            elif idx in errors_map:
                outputs.append(errors_map[idx])
            elif fallback is not None:
                outputs.append(fallback)
            else:
                # Only reachable if the bound runnable returned fewer
                # outcomes than it was given inputs.
                raise IndexError(
                    "no outcome was recorded for input at index {}".format(idx)
                )
        return outputs

    def _invoke(
        self,
        input: Input,
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
    ) -> Output:
        """Invoke the bound runnable, retrying on the configured exceptions.

        ``reraise=True`` asks tenacity to surface the last underlying
        exception rather than a ``RetryError`` once attempts are exhausted.
        """
        for attempt in self._sync_retrying(reraise=True):
            with attempt:
                result = super().invoke(
                    input,
                    self._patch_config(config, run_manager, attempt.retry_state),
                )
            # Leaving the `with` block records an outcome on the retry state;
            # on success, replace it with the real result.
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
        return result

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke with retries, traced as a single chain run."""
        return self._call_with_config(self._invoke, input, config, **kwargs)

    async def _ainvoke(
        self,
        input: Input,
        run_manager: AsyncCallbackManagerForChainRun,
        config: RunnableConfig,
    ) -> Output:
        """Async counterpart of ``_invoke``."""
        async for attempt in self._async_retrying(reraise=True):
            with attempt:
                result = await super().ainvoke(
                    input,
                    self._patch_config(config, run_manager, attempt.retry_state),
                )
            if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
                attempt.retry_state.set_result(result)
        return result

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Async invoke with retries, traced as a single chain run."""
        return await self._acall_with_config(self._ainvoke, input, config, **kwargs)

    def _batch(
        self,
        inputs: List[Input],
        run_manager: List[CallbackManagerForChainRun],
        config: List[RunnableConfig],
    ) -> List[Union[Output, Exception]]:
        """Run the bound runnable over ``inputs``, retrying the failed ones.

        Each attempt submits only the inputs without a successful result.
        If any of them fails, the first failure of that attempt is raised
        inside the attempt so tenacity can decide whether to retry. An
        exception tenacity does not retry propagates out of this method.

        Returns:
            One entry per input, in input order: the result, or the most
            recent exception recorded for that input.
        """
        results_map: Dict[int, Output] = {}
        errors_map: Dict[int, Exception] = {}
        retry_error: Optional[Exception] = None

        try:
            for attempt in self._sync_retrying():
                with attempt:
                    # Original positions of the inputs still awaiting a
                    # result. Outcomes are mapped back through this list,
                    # never through their position in the pending sub-list.
                    remaining_idxs = [
                        idx for idx in range(len(inputs)) if idx not in results_map
                    ]
                    result = super().batch(
                        [inputs[idx] for idx in remaining_idxs],
                        self._patch_config_list(
                            [config[idx] for idx in remaining_idxs],
                            [run_manager[idx] for idx in remaining_idxs],
                            attempt.retry_state,
                        ),
                        return_exceptions=True,
                    )
                    first_exception = self._record_outcomes(
                        remaining_idxs, result, results_map, errors_map
                    )
                    # If any exception occurred, raise it, to retry the failed ones
                    if first_exception is not None:
                        raise first_exception
                if (
                    attempt.retry_state.outcome
                    and not attempt.retry_state.outcome.failed
                ):
                    attempt.retry_state.set_result(result)
        except RetryError as e:
            # Attempts exhausted: report whatever was recorded per input.
            retry_error = e

        return self._collect_outputs(len(inputs), results_map, errors_map, retry_error)

    def batch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any
    ) -> List[Output]:
        """Batch with retries, traced as one chain run per input."""
        return self._batch_with_config(
            self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
        )

    async def _abatch(
        self,
        inputs: List[Input],
        run_manager: List[AsyncCallbackManagerForChainRun],
        config: List[RunnableConfig],
    ) -> List[Union[Output, Exception]]:
        """Async counterpart of ``_batch``; see it for the retry contract."""
        results_map: Dict[int, Output] = {}
        errors_map: Dict[int, Exception] = {}
        retry_error: Optional[Exception] = None

        try:
            async for attempt in self._async_retrying():
                with attempt:
                    # Original positions of the inputs still awaiting a
                    # result. Outcomes are mapped back through this list,
                    # never through their position in the pending sub-list.
                    remaining_idxs = [
                        idx for idx in range(len(inputs)) if idx not in results_map
                    ]
                    result = await super().abatch(
                        [inputs[idx] for idx in remaining_idxs],
                        self._patch_config_list(
                            [config[idx] for idx in remaining_idxs],
                            [run_manager[idx] for idx in remaining_idxs],
                            attempt.retry_state,
                        ),
                        return_exceptions=True,
                    )
                    first_exception = self._record_outcomes(
                        remaining_idxs, result, results_map, errors_map
                    )
                    # If any exception occurred, raise it, to retry the failed ones
                    if first_exception is not None:
                        raise first_exception
                if (
                    attempt.retry_state.outcome
                    and not attempt.retry_state.outcome.failed
                ):
                    attempt.retry_state.set_result(result)
        except RetryError as e:
            # Attempts exhausted: report whatever was recorded per input.
            retry_error = e

        return self._collect_outputs(len(inputs), results_map, errors_map, retry_error)

    async def abatch(
        self,
        inputs: List[Input],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any
    ) -> List[Output]:
        """Async batch with retries, traced as one chain run per input."""
        return await self._abatch_with_config(
            self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
        )

    # stream() and transform() are not retried because retrying a stream
    # is not very intuitive.

View File

@ -1,6 +1,5 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncIterator,
@ -12,6 +11,7 @@ from typing import (
Optional,
TypedDict,
Union,
cast,
)
from langchain.load.serializable import Serializable
@ -23,7 +23,11 @@ from langchain.schema.runnable.base import (
RunnableSequence,
coerce_to_runnable,
)
from langchain.schema.runnable.config import RunnableConfig, get_config_list
from langchain.schema.runnable.config import (
RunnableConfig,
get_config_list,
get_executor_for_config,
)
from langchain.schema.runnable.utils import gather_with_concurrency
@ -122,7 +126,7 @@ class RouterRunnable(
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
keys = [input["key"] for input in inputs]
@ -130,16 +134,23 @@ class RouterRunnable(
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
def invoke(
runnable: Runnable, input: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return runnable.invoke(input, config, **kwargs)
except Exception as e:
return e
else:
return runnable.invoke(input, config, **kwargs)
runnables = [self.runnables[key] for key in keys]
configs = get_config_list(config, len(inputs))
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(
executor.map(
lambda runnable, input, config: runnable.invoke(input, config),
runnables,
actual_inputs,
configs,
)
with get_executor_for_config(configs[0]) as executor:
return cast(
List[Output],
list(executor.map(invoke, runnables, actual_inputs, configs)),
)
async def abatch(
@ -147,7 +158,7 @@ class RouterRunnable(
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
keys = [input["key"] for input in inputs]
@ -155,12 +166,23 @@ class RouterRunnable(
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
async def ainvoke(
runnable: Runnable, input: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return await runnable.ainvoke(input, config, **kwargs)
except Exception as e:
return e
else:
return await runnable.ainvoke(input, config, **kwargs)
runnables = [self.runnables[key] for key in keys]
configs = get_config_list(config, len(inputs))
return await gather_with_concurrency(
max_concurrency,
configs[0].get("max_concurrency"),
*(
runnable.ainvoke(input, config)
ainvoke(runnable, input, config)
for runnable, input, config in zip(runnables, actual_inputs, configs)
),
)

View File

@ -141,7 +141,7 @@ async def test_with_config(mocker: MockerFixture) -> None:
else:
assert call.args[2].get("tags") == ["b-tag"]
assert call.args[2].get("max_concurrency") == 5
spy_seq_step.reset_mock()
mocker.stop(spy_seq_step)
assert [
*fake.with_config(tags=["a-tag"]).stream(
@ -1423,3 +1423,365 @@ def test_recursive_lambda() -> None:
with pytest.raises(RecursionError):
runnable.invoke(0, {"recursion_limit": 9})
def test_retrying(mocker: MockerFixture) -> None:
    """with_retry() retries only the configured exception type, and in a
    batch re-runs only the inputs that have not succeeded."""

    def _lambda(x: int) -> Union[int, Runnable]:
        if x == 1:
            raise ValueError("x is 1")
        if x == 2:
            raise RuntimeError("x is 2")
        return x

    _lambda_mock = mocker.Mock(side_effect=_lambda)
    runnable = RunnableLambda(_lambda_mock)

    def retrying() -> Runnable:
        # A fresh wrapper per use: two attempts, retrying ValueError only.
        return runnable.with_retry(
            stop_after_attempt=2,
            retry_if_exception_type=(ValueError,),
        )

    # Baseline: no retry wrapper, a single call.
    with pytest.raises(ValueError):
        runnable.invoke(1)
    assert _lambda_mock.call_count == 1
    _lambda_mock.reset_mock()

    # ValueError is configured for retry -> two calls.
    with pytest.raises(ValueError):
        retrying().invoke(1)
    assert _lambda_mock.call_count == 2  # retried
    _lambda_mock.reset_mock()

    # RuntimeError is not configured for retry -> one call.
    with pytest.raises(RuntimeError):
        retrying().invoke(2)
    assert _lambda_mock.call_count == 1  # did not retry
    _lambda_mock.reset_mock()

    with pytest.raises(ValueError):
        retrying().batch([1, 2, 0])
    # 3rd input isn't retried because it succeeded
    assert _lambda_mock.call_count == 3 + 2
    _lambda_mock.reset_mock()

    output = retrying().batch([1, 2, 0], return_exceptions=True)
    # 3rd input isn't retried because it succeeded
    assert _lambda_mock.call_count == 3 + 2
    assert len(output) == 3
    assert isinstance(output[0], ValueError)
    assert isinstance(output[1], RuntimeError)
    assert output[2] == 0
    _lambda_mock.reset_mock()
@pytest.mark.asyncio
async def test_async_retrying(mocker: MockerFixture) -> None:
    """Async variant: with_retry() retries only the configured exception
    type, and in a batch re-runs only the inputs that have not succeeded."""

    def _lambda(x: int) -> Union[int, Runnable]:
        if x == 1:
            raise ValueError("x is 1")
        if x == 2:
            raise RuntimeError("x is 2")
        return x

    _lambda_mock = mocker.Mock(side_effect=_lambda)
    runnable = RunnableLambda(_lambda_mock)

    def retrying() -> Runnable:
        # A fresh wrapper per use: two attempts, retrying ValueError only.
        return runnable.with_retry(
            stop_after_attempt=2,
            retry_if_exception_type=(ValueError,),
        )

    # Baseline: no retry wrapper, a single call.
    with pytest.raises(ValueError):
        await runnable.ainvoke(1)
    assert _lambda_mock.call_count == 1
    _lambda_mock.reset_mock()

    # ValueError is configured for retry -> two calls.
    with pytest.raises(ValueError):
        await retrying().ainvoke(1)
    assert _lambda_mock.call_count == 2  # retried
    _lambda_mock.reset_mock()

    # RuntimeError is not configured for retry -> one call.
    with pytest.raises(RuntimeError):
        await retrying().ainvoke(2)
    assert _lambda_mock.call_count == 1  # did not retry
    _lambda_mock.reset_mock()

    with pytest.raises(ValueError):
        await retrying().abatch([1, 2, 0])
    # 3rd input isn't retried because it succeeded
    assert _lambda_mock.call_count == 3 + 2
    _lambda_mock.reset_mock()

    output = await retrying().abatch([1, 2, 0], return_exceptions=True)
    # 3rd input isn't retried because it succeeded
    assert _lambda_mock.call_count == 3 + 2
    assert len(output) == 3
    assert isinstance(output[0], ValueError)
    assert isinstance(output[1], RuntimeError)
    assert output[2] == 0
    _lambda_mock.reset_mock()
@freeze_time("2023-01-01")
def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
    """Check ``RunnableSequence.batch`` with ``return_exceptions=True``.

    A four-step sequence is built from runnables that append ``"a"`` to each
    input, or emit a ``ValueError`` for inputs starting with a given prefix
    (``bux``, ``bar``, ``baz`` and ``foo`` for steps 0 to 3).  The test
    asserts on the final outputs, on the inputs each step's ``batch``
    received, and on the runs recorded by the tracer.
    """

    class ControlledExceptionRunnable(Runnable[str, str]):
        """Runnable that fails for inputs starting with a chosen prefix."""

        def __init__(self, fail_starts_with: str) -> None:
            self.fail_starts_with = fail_starts_with

        def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
            # Only the batch path is exercised by this test.
            raise NotImplementedError()

        def _batch(
            self,
            inputs: List[str],
        ) -> List:
            # Failures are returned as exception instances, not raised.
            return [
                ValueError() if item.startswith(self.fail_starts_with) else item + "a"
                for item in inputs
            ]

        def batch(
            self,
            inputs: List[str],
            config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
            *,
            return_exceptions: bool = False,
            **kwargs: Any,
        ) -> List[str]:
            return self._batch_with_config(
                self._batch,
                inputs,
                config,
                return_exceptions=return_exceptions,
                **kwargs,
            )

    chain = (
        ControlledExceptionRunnable("bux")
        | ControlledExceptionRunnable("bar")
        | ControlledExceptionRunnable("baz")
        | ControlledExceptionRunnable("foo")
    )
    assert isinstance(chain, RunnableSequence)

    # Test batch
    with pytest.raises(ValueError):
        chain.batch(["foo", "bar", "baz", "qux"])

    batch_spy = mocker.spy(ControlledExceptionRunnable, "batch")
    tracer = FakeTracer()
    inputs = ["foo", "bar", "baz", "qux"]
    results = chain.batch(inputs, {"callbacks": [tracer]}, return_exceptions=True)

    assert len(results) == 4
    assert isinstance(results[0], ValueError)
    assert isinstance(results[1], ValueError)
    assert isinstance(results[2], ValueError)
    assert results[3] == "quxaaaa"

    assert batch_spy.call_count == 4
    # Positional arg 0 of each spied call is the instance, arg 1 the inputs.
    per_step_inputs = [recorded.args[1] for recorded in batch_spy.call_args_list]
    assert per_step_inputs == [
        # inputs to sequence step 0
        # same as inputs to sequence.batch()
        ["foo", "bar", "baz", "qux"],
        # inputs to sequence step 1
        # == outputs of sequence step 0 as no exceptions were raised
        ["fooa", "bara", "baza", "quxa"],
        # inputs to sequence step 2
        # 'bar' was dropped as it raised an exception in step 1
        ["fooaa", "bazaa", "quxaa"],
        # inputs to sequence step 3
        # 'baz' was dropped as it raised an exception in step 2
        ["fooaaa", "quxaaa"],
    ]

    # Top-level runs, ordered like the original inputs.
    root_runs = [run for run in tracer.runs if run.parent_run_id is None]
    root_runs.sort(key=lambda run: inputs.index(run.inputs["input"]))
    assert len(root_runs) == 4

    failure = repr(ValueError())
    # (input, expected error of each child run) for the three failing inputs.
    expected_failures = [
        ("foo", [None, None, None, failure]),
        ("bar", [None, failure]),
        ("baz", [None, None, failure]),
    ]
    for run, (name, child_errors) in zip(root_runs, expected_failures):
        assert run.inputs["input"] == name
        assert run.error == failure
        assert len(run.child_runs) == len(child_errors)
        assert [child.error for child in run.child_runs] == child_errors

    # 'qux' is the only input that makes it through all four steps.
    qux_run = root_runs[3]
    assert qux_run.inputs["input"] == "qux"
    assert qux_run.error is None
    assert qux_run.outputs["output"] == "quxaaaa"
    assert len(qux_run.child_runs) == 4
    assert [child.error for child in qux_run.child_runs] == [None, None, None, None]
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
    """Check ``RunnableSequence.abatch`` with ``return_exceptions=True``.

    A four-step sequence is built from runnables that append ``"a"`` to each
    input, or emit a ``ValueError`` for inputs starting with a given prefix
    (``bux``, ``bar``, ``baz`` and ``foo`` for steps 0 to 3).  The test
    asserts on the final outputs, on the inputs each step's ``abatch``
    received, and on the runs recorded by the tracer.
    """

    class ControlledExceptionRunnable(Runnable[str, str]):
        """Runnable that fails for inputs starting with a chosen prefix."""

        def __init__(self, fail_starts_with: str) -> None:
            self.fail_starts_with = fail_starts_with

        def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
            # Only the async batch path is exercised by this test.
            raise NotImplementedError()

        async def _abatch(
            self,
            inputs: List[str],
        ) -> List:
            # Failures are returned as exception instances, not raised.
            return [
                ValueError() if item.startswith(self.fail_starts_with) else item + "a"
                for item in inputs
            ]

        async def abatch(
            self,
            inputs: List[str],
            config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
            *,
            return_exceptions: bool = False,
            **kwargs: Any,
        ) -> List[str]:
            return await self._abatch_with_config(
                self._abatch,
                inputs,
                config,
                return_exceptions=return_exceptions,
                **kwargs,
            )

    chain = (
        ControlledExceptionRunnable("bux")
        | ControlledExceptionRunnable("bar")
        | ControlledExceptionRunnable("baz")
        | ControlledExceptionRunnable("foo")
    )
    assert isinstance(chain, RunnableSequence)

    # Test abatch
    with pytest.raises(ValueError):
        await chain.abatch(["foo", "bar", "baz", "qux"])

    abatch_spy = mocker.spy(ControlledExceptionRunnable, "abatch")
    tracer = FakeTracer()
    inputs = ["foo", "bar", "baz", "qux"]
    results = await chain.abatch(
        inputs, {"callbacks": [tracer]}, return_exceptions=True
    )

    assert len(results) == 4
    assert isinstance(results[0], ValueError)
    assert isinstance(results[1], ValueError)
    assert isinstance(results[2], ValueError)
    assert results[3] == "quxaaaa"

    assert abatch_spy.call_count == 4
    # Positional arg 0 of each spied call is the instance, arg 1 the inputs.
    per_step_inputs = [recorded.args[1] for recorded in abatch_spy.call_args_list]
    assert per_step_inputs == [
        # inputs to sequence step 0
        # same as inputs to sequence.abatch()
        ["foo", "bar", "baz", "qux"],
        # inputs to sequence step 1
        # == outputs of sequence step 0 as no exceptions were raised
        ["fooa", "bara", "baza", "quxa"],
        # inputs to sequence step 2
        # 'bar' was dropped as it raised an exception in step 1
        ["fooaa", "bazaa", "quxaa"],
        # inputs to sequence step 3
        # 'baz' was dropped as it raised an exception in step 2
        ["fooaaa", "quxaaa"],
    ]

    # Top-level runs, ordered like the original inputs.
    root_runs = [run for run in tracer.runs if run.parent_run_id is None]
    root_runs.sort(key=lambda run: inputs.index(run.inputs["input"]))
    assert len(root_runs) == 4

    failure = repr(ValueError())
    # (input, expected error of each child run) for the three failing inputs.
    expected_failures = [
        ("foo", [None, None, None, failure]),
        ("bar", [None, failure]),
        ("baz", [None, None, failure]),
    ]
    for run, (name, child_errors) in zip(root_runs, expected_failures):
        assert run.inputs["input"] == name
        assert run.error == failure
        assert len(run.child_runs) == len(child_errors)
        assert [child.error for child in run.child_runs] == child_errors

    # 'qux' is the only input that makes it through all four steps.
    qux_run = root_runs[3]
    assert qux_run.inputs["input"] == "qux"
    assert qux_run.error is None
    assert qux_run.outputs["output"] == "quxaaaa"
    assert len(qux_run.child_runs) == 4
    assert [child.error for child in qux_run.child_runs] == [None, None, None, None]