Re-implement retry, adding a root run, and implement return_exceptions for batch() and abatch()

pull/9711/head
Nuno Campos 1 year ago
parent 0eba80912f
commit 4c0e1e501c

@ -263,20 +263,28 @@ class BaseLLM(BaseLanguageModel[str], ABC):
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
max_concurrency: Optional[int] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[str]:
config = get_config_list(config, len(inputs))
max_concurrency = config[0].get("max_concurrency")
if max_concurrency is None:
llm_result = self.generate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
try:
llm_result = self.generate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
except Exception as e:
if return_exceptions:
return cast(List[str], [e for _ in inputs])
else:
raise e
else:
batches = [
inputs[i : i + max_concurrency]
@ -285,33 +293,43 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return [
output
for batch in batches
for output in self.batch(batch, config=config, **kwargs)
for output in self.batch(
batch, config=config, return_exceptions=return_exceptions, **kwargs
)
]
async def abatch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
max_concurrency: Optional[int] = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> List[str]:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async batch, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, self.batch, inputs, config, max_concurrency
None, partial(self.batch, **kwargs), inputs, config
)
config = get_config_list(config, len(inputs))
max_concurrency = config[0].get("max_concurrency")
if max_concurrency is None:
llm_result = await self.agenerate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
try:
llm_result = await self.agenerate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
**kwargs,
)
return [g[0].text for g in llm_result.generations]
except Exception as e:
if return_exceptions:
return cast(List[str], [e for _ in inputs])
else:
raise e
else:
batches = [
inputs[i : i + max_concurrency]

@ -27,8 +27,6 @@ from typing import (
cast,
)
from tenacity import BaseRetrying
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@ -107,6 +105,8 @@ class Runnable(Generic[Input, Output], ABC):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
"""
@ -115,17 +115,28 @@ class Runnable(Generic[Input, Output], ABC):
"""
configs = get_config_list(config, len(inputs))
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
if return_exceptions:
try:
return self.invoke(input, config, **kwargs)
except Exception as e:
return e
else:
return self.invoke(input, config, **kwargs)
# If there's only one input, don't bother with the executor
if len(inputs) == 1:
return [self.invoke(inputs[0], configs[0], **kwargs)]
return cast(List[Output], [invoke(inputs[0], configs[0])])
with get_executor_for_config(configs[0]) as executor:
return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
return cast(List[Output], list(executor.map(invoke, inputs, configs)))
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
"""
@ -133,8 +144,19 @@ class Runnable(Generic[Input, Output], ABC):
Subclasses should override this method if they can batch more efficiently.
"""
configs = get_config_list(config, len(inputs))
coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
async def ainvoke(
input: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return await self.ainvoke(input, config, **kwargs)
except Exception as e:
return e
else:
return await self.ainvoke(input, config, **kwargs)
coros = map(ainvoke, inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
def stream(
@ -230,11 +252,21 @@ class Runnable(Generic[Input, Output], ABC):
def with_retry(
self,
retry: BaseRetrying,
*,
retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,),
wait_exponential_jitter: bool = True,
stop_after_attempt: int = 3,
) -> Runnable[Input, Output]:
from langchain.schema.runnable.retry import RunnableRetry
return RunnableRetry(bound=self, retry=retry, kwargs={}, config={})
return RunnableRetry(
bound=self,
kwargs={},
config={},
retry_if_exception_type=retry_if_exception_type,
wait_exponential_jitter=wait_exponential_jitter,
stop_after_attempt=stop_after_attempt,
)
def map(self) -> Runnable[List[Input], List[Output]]:
"""
@ -341,6 +373,146 @@ class Runnable(Generic[Input, Output], ABC):
await run_manager.on_chain_end(dumpd(output))
return output
def _batch_with_config(
    self,
    func: Union[
        Callable[[List[Input]], List[Union[Exception, Output]]],
        Callable[
            [List[Input], List[CallbackManagerForChainRun]],
            List[Union[Exception, Output]],
        ],
        Callable[
            [List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]],
            List[Union[Exception, Output]],
        ],
    ],
    input: List[Input],
    config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
    return_exceptions: bool = False,
    run_type: Optional[str] = None,
) -> List[Output]:
    """Helper method to transform a list of inputs into a list of outputs,
    with callbacks. Use this method to implement batch() in subclasses.

    One chain run is started per input (named after that input's
    ``run_name`` config entry), ``func`` is called once with the whole list,
    and every run is then closed with either its output or its error.

    Args:
        func: Batch implementation. Depending on what
            ``accepts_run_manager_and_config`` / ``accepts_run_manager``
            report, it is called as ``func(input, run_manager=..., config=...)``,
            ``func(input, run_manager=...)`` or ``func(input)``. It may return
            ``Exception`` instances in place of individual outputs.
        input: The inputs of the batch.
        config: A single config, a list of configs, or None; expanded with
            ``get_config_list`` to one config per input.
        return_exceptions: When True, failures are returned in the output
            list instead of being raised.
        run_type: Forwarded to ``on_chain_start``.

    Returns:
        One entry per input. With ``return_exceptions=True`` entries may be
        ``Exception`` instances (the list is only cast to ``List[Output]``).

    Raises:
        Exception: With ``return_exceptions=False``, whatever ``func`` raised,
            or otherwise the first ``Exception`` instance found in the list
            that ``func`` returned.
    """
    configs = get_config_list(config, len(input))
    callback_managers = [get_callback_manager_for_config(c) for c in configs]
    # Distinct loop names so the `input` / `config` parameters are not shadowed.
    run_managers = [
        callback_manager.on_chain_start(
            dumpd(self),
            item,
            run_type=run_type,
            name=cfg.get("run_name"),
        )
        for callback_manager, item, cfg in zip(callback_managers, input, configs)
    ]
    try:
        if accepts_run_manager_and_config(func):
            output = func(
                input,
                run_manager=run_managers,
                config=configs,
            )  # type: ignore[call-arg]
        elif accepts_run_manager(func):
            output = func(input, run_manager=run_managers)  # type: ignore[call-arg]
        else:
            output = func(input)  # type: ignore[call-arg]
    except Exception as e:
        # func failed as a whole: every run of the batch is marked as errored
        # with the same exception.
        for run_manager in run_managers:
            run_manager.on_chain_error(e)
        if return_exceptions:
            return cast(List[Output], [e for _ in input])
        else:
            raise
    else:
        # func returned: close each run individually, remembering the first
        # per-item exception so it can be raised when exceptions are not
        # being returned.
        first_exception: Optional[Exception] = None
        for run_manager, out in zip(run_managers, output):
            if isinstance(out, Exception):
                first_exception = first_exception or out
                run_manager.on_chain_error(out)
            else:
                run_manager.on_chain_end(dumpd(out))
        if return_exceptions or first_exception is None:
            return cast(List[Output], output)
        else:
            raise first_exception
async def _abatch_with_config(
    self,
    func: Union[
        Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]],
        Callable[
            [List[Input], List[AsyncCallbackManagerForChainRun]],
            Awaitable[List[Union[Exception, Output]]],
        ],
        Callable[
            [
                List[Input],
                List[AsyncCallbackManagerForChainRun],
                List[RunnableConfig],
            ],
            Awaitable[List[Union[Exception, Output]]],
        ],
    ],
    input: List[Input],
    config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
    return_exceptions: bool = False,
    run_type: Optional[str] = None,
) -> List[Output]:
    """Helper method to transform a list of inputs into a list of outputs,
    with callbacks. Use this method to implement abatch() in subclasses.

    One chain run is started per input (named after that input's
    ``run_name`` config entry), ``func`` is awaited once with the whole list,
    and every run is then closed with either its output or its error.

    Args:
        func: Async batch implementation. Depending on what
            ``accepts_run_manager_and_config`` / ``accepts_run_manager``
            report, it is called as ``func(input, run_manager=..., config=...)``,
            ``func(input, run_manager=...)`` or ``func(input)``. It may return
            ``Exception`` instances in place of individual outputs.
        input: The inputs of the batch.
        config: A single config, a list of configs, or None; expanded with
            ``get_config_list`` to one config per input.
        return_exceptions: When True, failures are returned in the output
            list instead of being raised.
        run_type: Forwarded to ``on_chain_start``.

    Returns:
        One entry per input. With ``return_exceptions=True`` entries may be
        ``Exception`` instances (the list is only cast to ``List[Output]``).

    Raises:
        Exception: With ``return_exceptions=False``, whatever ``func`` raised,
            or otherwise the first ``Exception`` instance found in the list
            that ``func`` returned.
    """
    configs = get_config_list(config, len(input))
    callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
    # Start all runs concurrently. Distinct loop names keep the `input` /
    # `config` parameters from being shadowed.
    run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
        *(
            callback_manager.on_chain_start(
                dumpd(self),
                item,
                run_type=run_type,
                name=cfg.get("run_name"),
            )
            for callback_manager, item, cfg in zip(callback_managers, input, configs)
        )
    )
    try:
        if accepts_run_manager_and_config(func):
            output = await func(
                input,
                run_manager=run_managers,
                config=configs,
            )  # type: ignore[call-arg]
        elif accepts_run_manager(func):
            output = await func(input, run_manager=run_managers)  # type: ignore
        else:
            output = await func(input)  # type: ignore[call-arg]
    except Exception as e:
        # func failed as a whole: every run of the batch is marked as errored
        # with the same exception.
        await asyncio.gather(
            *(run_manager.on_chain_error(e) for run_manager in run_managers)
        )
        if return_exceptions:
            return cast(List[Output], [e for _ in input])
        else:
            raise
    else:
        # func returned: close each run individually, remembering the first
        # per-item exception so it can be raised when exceptions are not
        # being returned.
        first_exception: Optional[Exception] = None
        coros: List[Awaitable[None]] = []
        for run_manager, out in zip(run_managers, output):
            if isinstance(out, Exception):
                first_exception = first_exception or out
                coros.append(run_manager.on_chain_error(out))
            else:
                coros.append(run_manager.on_chain_end(dumpd(out)))
        await asyncio.gather(*coros)
        if return_exceptions or first_exception is None:
            return cast(List[Output], output)
        else:
            raise first_exception
def _transform_stream_with_config(
self,
input: Iterator[Input],
@ -596,10 +768,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
if return_exceptions:
raise NotImplementedError()
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
@ -656,10 +833,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
if return_exceptions:
raise NotImplementedError()
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
@ -841,6 +1023,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
@ -871,29 +1055,90 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke
try:
for step in self.steps:
inputs = step.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
)
if return_exceptions:
# Track which inputs (by index) failed so far
# If an input has failed it will be present in this map,
# and the value will be the exception that was raised.
failed_inputs_map: Dict[int, Exception] = {}
stepidx = -1
for step in self.steps:
stepidx += 1
# Assemble the original indexes of the remaining inputs
# (i.e. the ones that haven't failed yet)
remaining_idxs = [
i for i in range(len(configs)) if i not in failed_inputs_map
]
# Invoke the step on the remaining inputs
inputs = step.batch(
[
inp
for i, inp in zip(remaining_idxs, inputs)
if i not in failed_inputs_map
],
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for i, (rm, config) in enumerate(zip(run_managers, configs))
if i not in failed_inputs_map
],
return_exceptions=return_exceptions,
**kwargs,
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
if isinstance(inp, Exception):
failed_inputs_map[i] = inp
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
# If all inputs have failed, stop processing
if len(failed_inputs_map) == len(configs):
break
# Reassemble the outputs, inserting Exceptions for failed inputs
inputs_copy = inputs.copy()
inputs = []
for i in range(len(configs)):
if i in failed_inputs_map:
inputs.append(cast(Input, failed_inputs_map[i]))
else:
inputs.append(inputs_copy.pop(0))
else:
for step in self.steps:
inputs = step.batch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
for rm in run_managers:
rm.on_chain_error(e)
raise
if return_exceptions:
return cast(List[Output], [e for _ in inputs])
else:
raise
else:
for rm, input in zip(run_managers, inputs):
rm.on_chain_end(input)
return cast(List[Output], inputs)
first_exception: Optional[Exception] = None
for run_manager, out in zip(run_managers, inputs):
if isinstance(out, Exception):
first_exception = first_exception or out
run_manager.on_chain_error(out)
else:
run_manager.on_chain_end(dumpd(out))
if return_exceptions or first_exception is None:
return cast(List[Output], inputs)
else:
raise first_exception
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
from langchain.callbacks.manager import (
@ -929,24 +1174,83 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke .batch() on each step
# this uses batching optimizations in Runnable subclasses, like LLM
try:
for step in self.steps:
inputs = await step.abatch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
)
if return_exceptions:
# Track which inputs (by index) failed so far
# If an input has failed it will be present in this map,
# and the value will be the exception that was raised.
failed_inputs_map: Dict[int, Exception] = {}
stepidx = -1
for step in self.steps:
stepidx += 1
# Assemble the original indexes of the remaining inputs
# (i.e. the ones that haven't failed yet)
remaining_idxs = [
i for i in range(len(configs)) if i not in failed_inputs_map
]
# Invoke the step on the remaining inputs
inputs = await step.abatch(
[
inp
for i, inp in zip(remaining_idxs, inputs)
if i not in failed_inputs_map
],
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for i, (rm, config) in enumerate(zip(run_managers, configs))
if i not in failed_inputs_map
],
return_exceptions=return_exceptions,
**kwargs,
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
if isinstance(inp, Exception):
failed_inputs_map[i] = inp
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
# If all inputs have failed, stop processing
if len(failed_inputs_map) == len(configs):
break
# Reassemble the outputs, inserting Exceptions for failed inputs
inputs_copy = inputs.copy()
inputs = []
for i in range(len(configs)):
if i in failed_inputs_map:
inputs.append(cast(Input, failed_inputs_map[i]))
else:
inputs.append(inputs_copy.pop(0))
else:
for step in self.steps:
inputs = await step.abatch(
inputs,
[
# each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child())
for rm, config in zip(run_managers, configs)
],
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
raise
if return_exceptions:
return cast(List[Output], [e for _ in inputs])
else:
raise
else:
await asyncio.gather(
*(rm.on_chain_end(input) for rm, input in zip(run_managers, inputs))
)
return cast(List[Output], inputs)
first_exception: Optional[Exception] = None
coros: List[Awaitable[None]] = []
for run_manager, out in zip(run_managers, inputs):
if isinstance(out, Exception):
first_exception = first_exception or out
coros.append(run_manager.on_chain_error(out))
else:
coros.append(run_manager.on_chain_end(dumpd(out)))
await asyncio.gather(*coros)
if return_exceptions or first_exception is None:
return cast(List[Output], inputs)
else:
raise first_exception
def stream(
self,
@ -1555,9 +1859,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
config={**self.config, **(config or {}), **kwargs},
)
def with_retry(self, retry: BaseRetrying) -> Runnable[Input, Output]:
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound.with_retry(retry),
bound=self.bound.with_retry(**kwargs),
kwargs=self.kwargs,
config=self.config,
)
@ -1590,6 +1894,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
if isinstance(config, list):
@ -1601,12 +1907,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
patch_config(self._merge_config(config), deep_copy_locals=True)
for _ in range(len(inputs))
]
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
return self.bound.batch(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
)
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
if isinstance(config, list):
@ -1618,7 +1931,12 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
patch_config(self._merge_config(config), deep_copy_locals=True)
for _ in range(len(inputs))
]
return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs})
return await self.bound.abatch(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
)
def stream(
self,

@ -1,161 +1,250 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast
from tenacity import (
AsyncRetrying,
RetryCallState,
RetryError,
Retrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.config import RunnableConfig, patch_config
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManager as AsyncCallbackManagerT,
)
from langchain.callbacks.manager import CallbackManager as CallbackManagerT
T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT)
else:
T = TypeVar("T")
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
U = TypeVar("U")
class RunnableRetry(RunnableBinding[Input, Output]):
"""Retry a Runnable if it fails."""
retry: BaseRetrying
retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,)
wait_exponential_jitter: bool = True
stop_after_attempt: int = 3
def _kwargs_retrying(self) -> Dict[str, Any]:
kwargs: Dict[str, Any] = {}
if self.retry.sleep is not None:
kwargs["sleep"] = self.retry.sleep
if self.retry.stop is not None:
kwargs["stop"] = self.retry.stop
if self.retry.wait is not None:
kwargs["wait"] = self.retry.wait
if self.retry.retry is not None:
kwargs["retry"] = self.retry.retry
if self.retry.before is not None:
kwargs["before"] = self.retry.before
if self.retry.after is not None:
kwargs["after"] = self.retry.after
if self.retry.before_sleep is not None:
kwargs["before_sleep"] = self.retry.before_sleep
if self.retry.reraise is not None:
kwargs["reraise"] = self.retry.reraise
if self.retry.retry_error_cls is not None:
kwargs["retry_error_cls"] = self.retry.retry_error_cls
if self.retry.retry_error_callback is not None:
kwargs["retry_error_callback"] = self.retry.retry_error_callback
kwargs: Dict[str, Any] = dict()
if self.stop_after_attempt:
kwargs["stop"] = stop_after_attempt(self.stop_after_attempt)
if self.wait_exponential_jitter:
kwargs["wait"] = wait_exponential_jitter()
if self.retry_if_exception_type:
kwargs["retry"] = retry_if_exception_type(self.retry_if_exception_type)
return kwargs
def _sync_retrying(self) -> Retrying:
return Retrying(**self._kwargs_retrying())
def _sync_retrying(self, **kwargs: Any) -> Retrying:
return Retrying(**self._kwargs_retrying(), **kwargs)
def _async_retrying(self) -> AsyncRetrying:
return AsyncRetrying(**self._kwargs_retrying())
def _async_retrying(self, **kwargs: Any) -> AsyncRetrying:
return AsyncRetrying(**self._kwargs_retrying(), **kwargs)
def _patch_config(
self,
config: Optional[RunnableConfig],
config: RunnableConfig,
run_manager: T,
retry_state: RetryCallState,
cm_cls: Type[T],
) -> RunnableConfig:
config = config or {}
return (
patch_config(
config,
callbacks=cm_cls.configure(
inheritable_callbacks=config.get("callbacks"),
local_tags=["retry:attempt:{}".format(retry_state.attempt_number)],
),
)
if retry_state.attempt_number > 1
else config
return patch_config(
config,
callbacks=run_manager.get_child(
"retry:attempt:{}".format(retry_state.attempt_number)
if retry_state.attempt_number > 1
else None
),
)
def _patch_config_list(
self,
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
config: List[RunnableConfig],
run_manager: List[T],
retry_state: RetryCallState,
cm_cls: Type[T],
) -> Union[RunnableConfig, List[RunnableConfig]]:
if isinstance(config, list):
return [self._patch_config(c, retry_state, cm_cls) for c in config]
return self._patch_config(config, retry_state, cm_cls)
) -> List[RunnableConfig]:
return [
self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
]
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
def _invoke(
self,
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
from langchain.callbacks.manager import CallbackManager
for attempt in self._sync_retrying():
for attempt in self._sync_retrying(reraise=True):
with attempt:
result = super().invoke(
input,
self._patch_config(config, attempt.retry_state, CallbackManager),
**kwargs
self._patch_config(config, run_manager, attempt.retry_state),
)
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
async def ainvoke(
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
return self._call_with_config(self._invoke, input, config, **kwargs)
async for attempt in self._async_retrying():
async def _ainvoke(
self,
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
) -> Output:
async for attempt in self._async_retrying(reraise=True):
with attempt:
result = await super().ainvoke(
input,
self._patch_config(
config, attempt.retry_state, AsyncCallbackManager
),
**kwargs
self._patch_config(config, run_manager, attempt.retry_state),
)
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
def _batch(
    self,
    inputs: List[Input],
    run_manager: List[CallbackManagerForChainRun],
    config: List[RunnableConfig],
) -> List[Union[Output, Exception]]:
    """Run the bound runnable on a batch, retrying only the inputs that failed.

    Each attempt calls ``super().batch(..., return_exceptions=True)`` on the
    inputs that have not succeeded yet. If any of them failed, the first
    failure is raised inside the attempt so the retry policy decides whether
    another attempt is made.

    Args:
        inputs: The inputs of the batch.
        run_manager: One run manager per input, in the same order.
        config: One config per input, in the same order.

    Returns:
        One entry per input, in input order: the output if the input
        eventually succeeded, otherwise the most recent exception recorded
        for it (or the ``RetryError`` itself if none was recorded).

    Raises:
        Exception: Anything other than ``RetryError`` that escapes the retry
            loop is propagated unchanged.
    """
    # Successful outputs, keyed by the ORIGINAL index of the input.
    results_map: Dict[int, Output] = {}
    # Most recent failure of each input, keyed by the ORIGINAL index.
    errors_map: Dict[int, Exception] = {}

    def pending(iterable: List[U]) -> List[U]:
        return [item for idx, item in enumerate(iterable) if idx not in results_map]

    try:
        for attempt in self._sync_retrying():
            with attempt:
                # Original indexes of the inputs sent in this attempt. Results
                # come back positionally for the pending sub-list, so they
                # must be mapped back through these indexes.
                remaining_idxs = pending(list(range(len(inputs))))
                # Get the results of the inputs that have not succeeded yet.
                result = super().batch(
                    pending(inputs),
                    self._patch_config_list(
                        pending(config), pending(run_manager), attempt.retry_state
                    ),
                    return_exceptions=True,
                )
                # Register the results of the inputs that have succeeded.
                first_exception = None
                for idx, r in zip(remaining_idxs, result):
                    if isinstance(r, Exception):
                        errors_map[idx] = r
                        if not first_exception:
                            first_exception = r
                        continue
                    results_map[idx] = r
                # If any exception occurred, raise it, to retry the failed ones
                if first_exception:
                    raise first_exception
            if (
                attempt.retry_state.outcome
                and not attempt.retry_state.outcome.failed
            ):
                attempt.retry_state.set_result(result)
    except RetryError as e:
        # The retry loop gave up. Inputs with no recorded failure of their own
        # (the batch call itself raised) report the RetryError.
        for idx in range(len(inputs)):
            if idx not in results_map:
                errors_map.setdefault(idx, e)

    return [
        results_map[idx] if idx in results_map else errors_map[idx]
        for idx in range(len(inputs))
    ]
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
return self._batch_with_config(
self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
)
for attempt in self._sync_retrying():
with attempt:
result = super().batch(
inputs,
self._patch_config_list(
config, attempt.retry_state, CallbackManager
),
**kwargs
)
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
async def _abatch(
    self,
    inputs: List[Input],
    run_manager: List[AsyncCallbackManagerForChainRun],
    config: List[RunnableConfig],
) -> List[Union[Output, Exception]]:
    """Asynchronously run the bound runnable on a batch, retrying only the
    inputs that failed.

    Each attempt awaits ``super().abatch(..., return_exceptions=True)`` on the
    inputs that have not succeeded yet. If any of them failed, the first
    failure is raised inside the attempt so the retry policy decides whether
    another attempt is made.

    Args:
        inputs: The inputs of the batch.
        run_manager: One run manager per input, in the same order.
        config: One config per input, in the same order.

    Returns:
        One entry per input, in input order: the output if the input
        eventually succeeded, otherwise the most recent exception recorded
        for it (or the ``RetryError`` itself if none was recorded).

    Raises:
        Exception: Anything other than ``RetryError`` that escapes the retry
            loop is propagated unchanged.
    """
    # Successful outputs, keyed by the ORIGINAL index of the input.
    results_map: Dict[int, Output] = {}
    # Most recent failure of each input, keyed by the ORIGINAL index.
    errors_map: Dict[int, Exception] = {}

    def pending(iterable: List[U]) -> List[U]:
        return [item for idx, item in enumerate(iterable) if idx not in results_map]

    try:
        async for attempt in self._async_retrying():
            with attempt:
                # Original indexes of the inputs sent in this attempt. Results
                # come back positionally for the pending sub-list, so they
                # must be mapped back through these indexes.
                remaining_idxs = pending(list(range(len(inputs))))
                # Get the results of the inputs that have not succeeded yet.
                result = await super().abatch(
                    pending(inputs),
                    self._patch_config_list(
                        pending(config), pending(run_manager), attempt.retry_state
                    ),
                    return_exceptions=True,
                )
                # Register the results of the inputs that have succeeded.
                first_exception = None
                for idx, r in zip(remaining_idxs, result):
                    if isinstance(r, Exception):
                        errors_map[idx] = r
                        if not first_exception:
                            first_exception = r
                        continue
                    results_map[idx] = r
                # If any exception occurred, raise it, to retry the failed ones
                if first_exception:
                    raise first_exception
            if (
                attempt.retry_state.outcome
                and not attempt.retry_state.outcome.failed
            ):
                attempt.retry_state.set_result(result)
    except RetryError as e:
        # The retry loop gave up. Inputs with no recorded failure of their own
        # (the batch call itself raised) report the RetryError.
        for idx in range(len(inputs)):
            if idx not in results_map:
                errors_map.setdefault(idx, e)

    return [
        results_map[idx] if idx in results_map else errors_map[idx]
        for idx in range(len(inputs))
    ]
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Any
) -> List[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
async for attempt in self._async_retrying():
with attempt:
result = await super().abatch(
inputs,
self._patch_config_list(
config, attempt.retry_state, AsyncCallbackManager
),
**kwargs
)
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
attempt.retry_state.set_result(result)
return result
return await self._abatch_with_config(
self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
)
# stream() and transform() are not retried because retrying a stream
# is not very intuitive.

@ -1,6 +1,5 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncIterator,
@ -12,6 +11,7 @@ from typing import (
Optional,
TypedDict,
Union,
cast,
)
from langchain.load.serializable import Serializable
@ -23,7 +23,11 @@ from langchain.schema.runnable.base import (
RunnableSequence,
coerce_to_runnable,
)
from langchain.schema.runnable.config import RunnableConfig, get_config_list
from langchain.schema.runnable.config import (
RunnableConfig,
get_config_list,
get_executor_for_config,
)
from langchain.schema.runnable.utils import gather_with_concurrency
@ -122,7 +126,7 @@ class RouterRunnable(
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
keys = [input["key"] for input in inputs]
@ -130,16 +134,23 @@ class RouterRunnable(
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
def invoke(
runnable: Runnable, input: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return runnable.invoke(input, config, **kwargs)
except Exception as e:
return e
else:
return runnable.invoke(input, config, **kwargs)
runnables = [self.runnables[key] for key in keys]
configs = get_config_list(config, len(inputs))
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(
executor.map(
lambda runnable, input, config: runnable.invoke(input, config),
runnables,
actual_inputs,
configs,
)
with get_executor_for_config(configs[0]) as executor:
return cast(
List[Output],
list(executor.map(invoke, runnables, actual_inputs, configs)),
)
async def abatch(
@ -147,7 +158,7 @@ class RouterRunnable(
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
keys = [input["key"] for input in inputs]
@ -155,12 +166,23 @@ class RouterRunnable(
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
async def ainvoke(
runnable: Runnable, input: Input, config: RunnableConfig
) -> Union[Output, Exception]:
if return_exceptions:
try:
return await runnable.ainvoke(input, config, **kwargs)
except Exception as e:
return e
else:
return await runnable.ainvoke(input, config, **kwargs)
runnables = [self.runnables[key] for key in keys]
configs = get_config_list(config, len(inputs))
return await gather_with_concurrency(
max_concurrency,
configs[0].get("max_concurrency"),
*(
runnable.ainvoke(input, config)
ainvoke(runnable, input, config)
for runnable, input, config in zip(runnables, actual_inputs, configs)
),
)

@ -6,7 +6,6 @@ import pytest
from freezegun import freeze_time
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt
from langchain import PromptTemplate
from langchain.callbacks.manager import Callbacks
@ -1444,22 +1443,345 @@ def test_retrying(mocker: MockerFixture) -> None:
assert _lambda_mock.call_count == 1
_lambda_mock.reset_mock()
with pytest.raises(RetryError):
with pytest.raises(ValueError):
runnable.with_retry(
Retrying(
stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,))
)
stop_after_attempt=2,
retry_if_exception_type=(ValueError,),
).invoke(1)
assert _lambda_mock.call_count == 2
assert _lambda_mock.call_count == 2 # retried
_lambda_mock.reset_mock()
with pytest.raises(RuntimeError):
runnable.with_retry(
Retrying(
stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,))
)
stop_after_attempt=2,
retry_if_exception_type=(ValueError,),
).invoke(2)
assert _lambda_mock.call_count == 1 # did not retry
_lambda_mock.reset_mock()
with pytest.raises(ValueError):
runnable.with_retry(
stop_after_attempt=2,
retry_if_exception_type=(ValueError,),
).batch([1, 2, 0])
# 3rd input isn't retried because it succeeded
assert _lambda_mock.call_count == 3 + 2
_lambda_mock.reset_mock()
output = runnable.with_retry(
stop_after_attempt=2,
retry_if_exception_type=(ValueError,),
).batch([1, 2, 0], return_exceptions=True)
# 3rd input isn't retried because it succeeded
assert _lambda_mock.call_count == 3 + 2
assert len(output) == 3
assert isinstance(output[0], ValueError)
assert isinstance(output[1], RuntimeError)
assert output[2] == 0
_lambda_mock.reset_mock()
@pytest.mark.asyncio
async def test_async_retrying(mocker: MockerFixture) -> None:
    """Check ``with_retry`` behaviour through ``ainvoke`` and ``abatch``.

    The wrapped callable raises ``ValueError`` for input 1, ``RuntimeError``
    for input 2 and echoes any other input. Every retrying runnable is
    configured with two attempts and retries on ``ValueError`` only.
    """

    def _raise_for_one_or_two(x: int) -> Union[int, Runnable]:
        if x == 1:
            raise ValueError("x is 1")
        if x == 2:
            raise RuntimeError("x is 2")
        return x

    call_tracker = mocker.Mock(side_effect=_raise_for_one_or_two)
    base = RunnableLambda(call_tracker)

    def _retrying() -> Runnable:
        # Build a fresh retry wrapper for each scenario.
        return base.with_retry(
            stop_after_attempt=2,
            retry_if_exception_type=(ValueError,),
        )

    # Without a retry wrapper the callable is invoked exactly once.
    with pytest.raises(ValueError):
        await base.ainvoke(1)
    assert call_tracker.call_count == 1
    call_tracker.reset_mock()

    # A ValueError matches the retry filter, so a second attempt is made.
    with pytest.raises(ValueError):
        await _retrying().ainvoke(1)
    assert call_tracker.call_count == 2  # retried
    call_tracker.reset_mock()

    # A RuntimeError does not match the retry filter.
    with pytest.raises(RuntimeError):
        await _retrying().ainvoke(2)
    assert call_tracker.call_count == 1  # did not retry
    call_tracker.reset_mock()

    # Batch without return_exceptions: the error propagates to the caller.
    with pytest.raises(ValueError):
        await _retrying().abatch([1, 2, 0])
    # 3 first-pass calls plus 2 follow-up calls; the 3rd input succeeded
    # and therefore isn't retried.
    assert call_tracker.call_count == 3 + 2
    call_tracker.reset_mock()

    # Batch with return_exceptions: errors come back in the matching slots.
    results = await _retrying().abatch([1, 2, 0], return_exceptions=True)
    # 3rd input isn't retried because it succeeded
    assert call_tracker.call_count == 3 + 2
    assert len(results) == 3
    first, second, third = results
    assert isinstance(first, ValueError)
    assert isinstance(second, RuntimeError)
    assert third == 0
    call_tracker.reset_mock()
@freeze_time("2023-01-01")
def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
    """Check ``RunnableSequence.batch`` with ``return_exceptions=True``.

    A four-step sequence is built from runnables that append ``"a"`` to each
    input, except that an input starting with the step's configured prefix
    yields a ``ValueError`` in its slot instead. The test verifies the
    returned values, the inputs each step's ``batch`` receives, and the
    per-input runs recorded by the tracer.
    """

    class ControlledExceptionRunnable(Runnable[str, str]):
        """Runnable that fails for inputs starting with a given prefix."""

        def __init__(self, fail_starts_with: str) -> None:
            self.fail_starts_with = fail_starts_with

        def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
            # Only the batch path is exercised by this test.
            raise NotImplementedError()

        def _batch(
            self,
            inputs: List[str],
        ) -> List:
            # Failures are returned as exception instances, not raised.
            results: List[Any] = [
                ValueError() if item.startswith(self.fail_starts_with) else item + "a"
                for item in inputs
            ]
            return results

        def batch(
            self,
            inputs: List[str],
            config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
            *,
            return_exceptions: bool = False,
            **kwargs: Any,
        ) -> List[str]:
            return self._batch_with_config(
                self._batch,
                inputs,
                config,
                return_exceptions=return_exceptions,
                **kwargs,
            )

    steps = [
        ControlledExceptionRunnable(prefix) for prefix in ("bux", "bar", "baz", "foo")
    ]
    chain = steps[0] | steps[1] | steps[2] | steps[3]

    assert isinstance(chain, RunnableSequence)

    # Test batch
    # Without return_exceptions the error is raised to the caller.
    with pytest.raises(ValueError):
        chain.batch(["foo", "bar", "baz", "qux"])

    # Install the spy only now so the call above is not counted.
    spy = mocker.spy(ControlledExceptionRunnable, "batch")
    tracer = FakeTracer()
    inputs = ["foo", "bar", "baz", "qux"]
    outputs = chain.batch(inputs, {"callbacks": [tracer]}, return_exceptions=True)

    assert len(outputs) == 4
    for failed in outputs[:3]:
        assert isinstance(failed, ValueError)
    assert outputs[3] == "quxaaaa"

    assert spy.call_count == 4
    # Positional argument 0 is ``self``; argument 1 is the list of inputs.
    inputs_per_step = [args[1] for args, _kwargs in spy.call_args_list]
    assert inputs_per_step == [
        # step 0 receives exactly what was passed to sequence.batch()
        ["foo", "bar", "baz", "qux"],
        # step 1 receives every output of step 0, since nothing failed there
        ["fooa", "bara", "baza", "quxa"],
        # step 2: 'bar' is missing because it failed in step 1
        ["fooaa", "bazaa", "quxaa"],
        # step 3: 'baz' is missing because it failed in step 2
        ["fooaaa", "quxaaa"],
    ]

    # One root run per input, ordered like ``inputs``.
    root_runs = sorted(
        (run for run in tracer.runs if run.parent_run_id is None),
        key=lambda run: inputs.index(run.inputs["input"]),
    )
    assert len(root_runs) == 4

    value_error = repr(ValueError())
    # For each failing input: the expected error of every child run, in order.
    failing_expectations = [
        ("foo", [None, None, None, value_error]),
        ("bar", [None, value_error]),
        ("baz", [None, None, value_error]),
    ]
    for root, (text, child_errors) in zip(root_runs, failing_expectations):
        assert root.inputs["input"] == text
        assert root.error == value_error
        assert len(root.child_runs) == len(child_errors)
        assert [child.error for child in root.child_runs] == child_errors

    # 'qux' passes through all four steps without error.
    qux_run = root_runs[3]
    assert qux_run.inputs["input"] == "qux"
    assert qux_run.error is None
    assert qux_run.outputs["output"] == "quxaaaa"
    assert len(qux_run.child_runs) == 4
    assert [child.error for child in qux_run.child_runs] == [None] * 4
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
    """Check ``RunnableSequence.abatch`` with ``return_exceptions=True``.

    A four-step sequence is built from runnables that append ``"a"`` to each
    input, except that an input starting with the step's configured prefix
    yields a ``ValueError`` in its slot instead. The test verifies the
    returned values, the inputs each step's ``abatch`` receives, and the
    per-input runs recorded by the tracer.
    """

    class ControlledExceptionRunnable(Runnable[str, str]):
        """Runnable that fails for inputs starting with a given prefix."""

        def __init__(self, fail_starts_with: str) -> None:
            self.fail_starts_with = fail_starts_with

        def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
            # Only the async batch path is exercised by this test.
            raise NotImplementedError()

        async def _abatch(
            self,
            inputs: List[str],
        ) -> List:
            # Failures are returned as exception instances, not raised.
            results: List[Any] = [
                ValueError() if item.startswith(self.fail_starts_with) else item + "a"
                for item in inputs
            ]
            return results

        async def abatch(
            self,
            inputs: List[str],
            config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
            *,
            return_exceptions: bool = False,
            **kwargs: Any,
        ) -> List[str]:
            return await self._abatch_with_config(
                self._abatch,
                inputs,
                config,
                return_exceptions=return_exceptions,
                **kwargs,
            )

    steps = [
        ControlledExceptionRunnable(prefix) for prefix in ("bux", "bar", "baz", "foo")
    ]
    chain = steps[0] | steps[1] | steps[2] | steps[3]

    assert isinstance(chain, RunnableSequence)

    # Test abatch
    # Without return_exceptions the error is raised to the caller.
    with pytest.raises(ValueError):
        await chain.abatch(["foo", "bar", "baz", "qux"])

    # Install the spy only now so the call above is not counted.
    spy = mocker.spy(ControlledExceptionRunnable, "abatch")
    tracer = FakeTracer()
    inputs = ["foo", "bar", "baz", "qux"]
    outputs = await chain.abatch(
        inputs, {"callbacks": [tracer]}, return_exceptions=True
    )

    assert len(outputs) == 4
    for failed in outputs[:3]:
        assert isinstance(failed, ValueError)
    assert outputs[3] == "quxaaaa"

    assert spy.call_count == 4
    # Positional argument 0 is ``self``; argument 1 is the list of inputs.
    inputs_per_step = [args[1] for args, _kwargs in spy.call_args_list]
    assert inputs_per_step == [
        # step 0 receives exactly what was passed to sequence.abatch()
        ["foo", "bar", "baz", "qux"],
        # step 1 receives every output of step 0, since nothing failed there
        ["fooa", "bara", "baza", "quxa"],
        # step 2: 'bar' is missing because it failed in step 1
        ["fooaa", "bazaa", "quxaa"],
        # step 3: 'baz' is missing because it failed in step 2
        ["fooaaa", "quxaaa"],
    ]

    # One root run per input, ordered like ``inputs``.
    root_runs = sorted(
        (run for run in tracer.runs if run.parent_run_id is None),
        key=lambda run: inputs.index(run.inputs["input"]),
    )
    assert len(root_runs) == 4

    value_error = repr(ValueError())
    # For each failing input: the expected error of every child run, in order.
    failing_expectations = [
        ("foo", [None, None, None, value_error]),
        ("bar", [None, value_error]),
        ("baz", [None, None, value_error]),
    ]
    for root, (text, child_errors) in zip(root_runs, failing_expectations):
        assert root.inputs["input"] == text
        assert root.error == value_error
        assert len(root.child_runs) == len(child_errors)
        assert [child.error for child in root.child_runs] == child_errors

    # 'qux' passes through all four steps without error.
    qux_run = root_runs[3]
    assert qux_run.inputs["input"] == "qux"
    assert qux_run.error is None
    assert qux_run.outputs["output"] == "quxaaaa"
    assert len(qux_run.child_runs) == 4
    assert [child.error for child in qux_run.child_runs] == [None] * 4

Loading…
Cancel
Save