mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Re-implement retry, adding a root run, and implement return_exception for batch() and abatch()
This commit is contained in:
parent
0eba80912f
commit
4c0e1e501c
@ -263,12 +263,15 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
self,
|
||||
inputs: List[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
max_concurrency: Optional[int] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
config = get_config_list(config, len(inputs))
|
||||
max_concurrency = config[0].get("max_concurrency")
|
||||
|
||||
if max_concurrency is None:
|
||||
try:
|
||||
llm_result = self.generate_prompt(
|
||||
[self._convert_input(input) for input in inputs],
|
||||
callbacks=[c.get("callbacks") for c in config],
|
||||
@ -277,6 +280,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
**kwargs,
|
||||
)
|
||||
return [g[0].text for g in llm_result.generations]
|
||||
except Exception as e:
|
||||
if return_exceptions:
|
||||
return cast(List[str], [e for _ in inputs])
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
batches = [
|
||||
inputs[i : i + max_concurrency]
|
||||
@ -285,25 +293,30 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
return [
|
||||
output
|
||||
for batch in batches
|
||||
for output in self.batch(batch, config=config, **kwargs)
|
||||
for output in self.batch(
|
||||
batch, config=config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
]
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[LanguageModelInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
max_concurrency: Optional[int] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
if type(self)._agenerate == BaseLLM._agenerate:
|
||||
# model doesn't implement async batch, so use default implementation
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.batch, inputs, config, max_concurrency
|
||||
None, partial(self.batch, **kwargs), inputs, config
|
||||
)
|
||||
|
||||
config = get_config_list(config, len(inputs))
|
||||
max_concurrency = config[0].get("max_concurrency")
|
||||
|
||||
if max_concurrency is None:
|
||||
try:
|
||||
llm_result = await self.agenerate_prompt(
|
||||
[self._convert_input(input) for input in inputs],
|
||||
callbacks=[c.get("callbacks") for c in config],
|
||||
@ -312,6 +325,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
**kwargs,
|
||||
)
|
||||
return [g[0].text for g in llm_result.generations]
|
||||
except Exception as e:
|
||||
if return_exceptions:
|
||||
return cast(List[str], [e for _ in inputs])
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
batches = [
|
||||
inputs[i : i + max_concurrency]
|
||||
|
@ -27,8 +27,6 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from tenacity import BaseRetrying
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
@ -107,6 +105,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
"""
|
||||
@ -115,17 +115,28 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
return self.invoke(input, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return self.invoke(input, config, **kwargs)
|
||||
|
||||
# If there's only one input, don't bother with the executor
|
||||
if len(inputs) == 1:
|
||||
return [self.invoke(inputs[0], configs[0], **kwargs)]
|
||||
return cast(List[Output], [invoke(inputs[0], configs[0])])
|
||||
|
||||
with get_executor_for_config(configs[0]) as executor:
|
||||
return list(executor.map(partial(self.invoke, **kwargs), inputs, configs))
|
||||
return cast(List[Output], list(executor.map(invoke, inputs, configs)))
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
"""
|
||||
@ -133,8 +144,19 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Subclasses should override this method if they can batch more efficiently.
|
||||
"""
|
||||
configs = get_config_list(config, len(inputs))
|
||||
coros = map(partial(self.ainvoke, **kwargs), inputs, configs)
|
||||
|
||||
async def ainvoke(
|
||||
input: Input, config: RunnableConfig
|
||||
) -> Union[Output, Exception]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
return await self.ainvoke(input, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return await self.ainvoke(input, config, **kwargs)
|
||||
|
||||
coros = map(ainvoke, inputs, configs)
|
||||
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)
|
||||
|
||||
def stream(
|
||||
@ -230,11 +252,21 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def with_retry(
|
||||
self,
|
||||
retry: BaseRetrying,
|
||||
*,
|
||||
retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,),
|
||||
wait_exponential_jitter: bool = True,
|
||||
stop_after_attempt: int = 3,
|
||||
) -> Runnable[Input, Output]:
|
||||
from langchain.schema.runnable.retry import RunnableRetry
|
||||
|
||||
return RunnableRetry(bound=self, retry=retry, kwargs={}, config={})
|
||||
return RunnableRetry(
|
||||
bound=self,
|
||||
kwargs={},
|
||||
config={},
|
||||
retry_if_exception_type=retry_if_exception_type,
|
||||
wait_exponential_jitter=wait_exponential_jitter,
|
||||
stop_after_attempt=stop_after_attempt,
|
||||
)
|
||||
|
||||
def map(self) -> Runnable[List[Input], List[Output]]:
|
||||
"""
|
||||
@ -341,6 +373,146 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
await run_manager.on_chain_end(dumpd(output))
|
||||
return output
|
||||
|
||||
def _batch_with_config(
|
||||
self,
|
||||
func: Union[
|
||||
Callable[[List[Input]], List[Union[Exception, Output]]],
|
||||
Callable[
|
||||
[List[Input], List[CallbackManagerForChainRun]],
|
||||
List[Union[Exception, Output]],
|
||||
],
|
||||
Callable[
|
||||
[List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]],
|
||||
List[Union[Exception, Output]],
|
||||
],
|
||||
],
|
||||
input: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
return_exceptions: bool = False,
|
||||
run_type: Optional[str] = None,
|
||||
) -> List[Output]:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
configs = get_config_list(config, len(input))
|
||||
callback_managers = [get_callback_manager_for_config(c) for c in configs]
|
||||
run_managers = [
|
||||
callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for callback_manager, input, config in zip(
|
||||
callback_managers, input, configs
|
||||
)
|
||||
]
|
||||
try:
|
||||
if accepts_run_manager_and_config(func):
|
||||
output = func(
|
||||
input,
|
||||
run_manager=run_managers,
|
||||
config=configs,
|
||||
) # type: ignore[call-arg]
|
||||
elif accepts_run_manager(func):
|
||||
output = func(input, run_manager=run_managers) # type: ignore[call-arg]
|
||||
else:
|
||||
output = func(input) # type: ignore[call-arg]
|
||||
|
||||
print("output", output)
|
||||
except Exception as e:
|
||||
for run_manager in run_managers:
|
||||
run_manager.on_chain_error(e)
|
||||
if return_exceptions:
|
||||
return cast(List[Output], [e for _ in input])
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
first_exception: Optional[Exception] = None
|
||||
for run_manager, out in zip(run_managers, output):
|
||||
if isinstance(out, Exception):
|
||||
first_exception = first_exception or out
|
||||
run_manager.on_chain_error(out)
|
||||
else:
|
||||
run_manager.on_chain_end(dumpd(out))
|
||||
if return_exceptions or first_exception is None:
|
||||
return cast(List[Output], output)
|
||||
else:
|
||||
raise first_exception
|
||||
|
||||
async def _abatch_with_config(
|
||||
self,
|
||||
func: Union[
|
||||
Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]],
|
||||
Callable[
|
||||
[List[Input], List[AsyncCallbackManagerForChainRun]],
|
||||
Awaitable[List[Union[Exception, Output]]],
|
||||
],
|
||||
Callable[
|
||||
[
|
||||
List[Input],
|
||||
List[AsyncCallbackManagerForChainRun],
|
||||
List[RunnableConfig],
|
||||
],
|
||||
Awaitable[List[Union[Exception, Output]]],
|
||||
],
|
||||
],
|
||||
input: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
return_exceptions: bool = False,
|
||||
run_type: Optional[str] = None,
|
||||
) -> List[Output]:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
configs = get_config_list(config, len(input))
|
||||
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
|
||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
*(
|
||||
callback_manager.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name"),
|
||||
)
|
||||
for callback_manager, input, config in zip(
|
||||
callback_managers, input, configs
|
||||
)
|
||||
)
|
||||
)
|
||||
try:
|
||||
if accepts_run_manager_and_config(func):
|
||||
output = await func(
|
||||
input,
|
||||
run_manager=run_managers,
|
||||
config=configs,
|
||||
) # type: ignore[call-arg]
|
||||
elif accepts_run_manager(func):
|
||||
output = await func(input, run_manager=run_managers) # type: ignore
|
||||
else:
|
||||
output = await func(input) # type: ignore[call-arg]
|
||||
print("output", output)
|
||||
except Exception as e:
|
||||
await asyncio.gather(
|
||||
*(run_manager.on_chain_error(e) for run_manager in run_managers)
|
||||
)
|
||||
if return_exceptions:
|
||||
return cast(List[Output], [e for _ in input])
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
first_exception: Optional[Exception] = None
|
||||
coros: List[Awaitable[None]] = []
|
||||
for run_manager, out in zip(run_managers, output):
|
||||
if isinstance(out, Exception):
|
||||
first_exception = first_exception or out
|
||||
coros.append(run_manager.on_chain_error(out))
|
||||
else:
|
||||
coros.append(run_manager.on_chain_end(dumpd(out)))
|
||||
await asyncio.gather(*coros)
|
||||
if return_exceptions or first_exception is None:
|
||||
return cast(List[Output], output)
|
||||
else:
|
||||
raise first_exception
|
||||
|
||||
def _transform_stream_with_config(
|
||||
self,
|
||||
input: Iterator[Input],
|
||||
@ -596,10 +768,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
if return_exceptions:
|
||||
raise NotImplementedError()
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
@ -656,10 +833,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
if return_exceptions:
|
||||
raise NotImplementedError()
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
@ -841,6 +1023,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
@ -871,6 +1055,53 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
|
||||
# invoke
|
||||
try:
|
||||
if return_exceptions:
|
||||
# Track which inputs (by index) failed so far
|
||||
# If an input has failed it will be present in this map,
|
||||
# and the value will be the exception that was raised.
|
||||
failed_inputs_map: Dict[int, Exception] = {}
|
||||
stepidx = -1
|
||||
for step in self.steps:
|
||||
stepidx += 1
|
||||
# Assemble the original indexes of the remaining inputs
|
||||
# (i.e. the ones that haven't failed yet)
|
||||
remaining_idxs = [
|
||||
i for i in range(len(configs)) if i not in failed_inputs_map
|
||||
]
|
||||
# Invoke the step on the remaining inputs
|
||||
inputs = step.batch(
|
||||
[
|
||||
inp
|
||||
for i, inp in zip(remaining_idxs, inputs)
|
||||
if i not in failed_inputs_map
|
||||
],
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(config, callbacks=rm.get_child())
|
||||
for i, (rm, config) in enumerate(zip(run_managers, configs))
|
||||
if i not in failed_inputs_map
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
# If an input failed, add it to the map
|
||||
for i, inp in zip(remaining_idxs, inputs):
|
||||
if isinstance(inp, Exception):
|
||||
failed_inputs_map[i] = inp
|
||||
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
|
||||
# If all inputs have failed, stop processing
|
||||
if len(failed_inputs_map) == len(configs):
|
||||
break
|
||||
|
||||
# Reassemble the outputs, inserting Exceptions for failed inputs
|
||||
inputs_copy = inputs.copy()
|
||||
inputs = []
|
||||
for i in range(len(configs)):
|
||||
if i in failed_inputs_map:
|
||||
inputs.append(cast(Input, failed_inputs_map[i]))
|
||||
else:
|
||||
inputs.append(inputs_copy.pop(0))
|
||||
else:
|
||||
for step in self.steps:
|
||||
inputs = step.batch(
|
||||
inputs,
|
||||
@ -880,20 +1111,34 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
)
|
||||
|
||||
# finish the root runs
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
for rm in run_managers:
|
||||
rm.on_chain_error(e)
|
||||
if return_exceptions:
|
||||
return cast(List[Output], [e for _ in inputs])
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
for rm, input in zip(run_managers, inputs):
|
||||
rm.on_chain_end(input)
|
||||
first_exception: Optional[Exception] = None
|
||||
for run_manager, out in zip(run_managers, inputs):
|
||||
if isinstance(out, Exception):
|
||||
first_exception = first_exception or out
|
||||
run_manager.on_chain_error(out)
|
||||
else:
|
||||
run_manager.on_chain_end(dumpd(out))
|
||||
if return_exceptions or first_exception is None:
|
||||
return cast(List[Output], inputs)
|
||||
else:
|
||||
raise first_exception
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import (
|
||||
@ -929,6 +1174,53 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
# invoke .batch() on each step
|
||||
# this uses batching optimizations in Runnable subclasses, like LLM
|
||||
try:
|
||||
if return_exceptions:
|
||||
# Track which inputs (by index) failed so far
|
||||
# If an input has failed it will be present in this map,
|
||||
# and the value will be the exception that was raised.
|
||||
failed_inputs_map: Dict[int, Exception] = {}
|
||||
stepidx = -1
|
||||
for step in self.steps:
|
||||
stepidx += 1
|
||||
# Assemble the original indexes of the remaining inputs
|
||||
# (i.e. the ones that haven't failed yet)
|
||||
remaining_idxs = [
|
||||
i for i in range(len(configs)) if i not in failed_inputs_map
|
||||
]
|
||||
# Invoke the step on the remaining inputs
|
||||
inputs = await step.abatch(
|
||||
[
|
||||
inp
|
||||
for i, inp in zip(remaining_idxs, inputs)
|
||||
if i not in failed_inputs_map
|
||||
],
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
patch_config(config, callbacks=rm.get_child())
|
||||
for i, (rm, config) in enumerate(zip(run_managers, configs))
|
||||
if i not in failed_inputs_map
|
||||
],
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
# If an input failed, add it to the map
|
||||
for i, inp in zip(remaining_idxs, inputs):
|
||||
if isinstance(inp, Exception):
|
||||
failed_inputs_map[i] = inp
|
||||
inputs = [inp for inp in inputs if not isinstance(inp, Exception)]
|
||||
# If all inputs have failed, stop processing
|
||||
if len(failed_inputs_map) == len(configs):
|
||||
break
|
||||
|
||||
# Reassemble the outputs, inserting Exceptions for failed inputs
|
||||
inputs_copy = inputs.copy()
|
||||
inputs = []
|
||||
for i in range(len(configs)):
|
||||
if i in failed_inputs_map:
|
||||
inputs.append(cast(Input, failed_inputs_map[i]))
|
||||
else:
|
||||
inputs.append(inputs_copy.pop(0))
|
||||
else:
|
||||
for step in self.steps:
|
||||
inputs = await step.abatch(
|
||||
inputs,
|
||||
@ -941,12 +1233,24 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
# finish the root runs
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||
if return_exceptions:
|
||||
return cast(List[Output], [e for _ in inputs])
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
await asyncio.gather(
|
||||
*(rm.on_chain_end(input) for rm, input in zip(run_managers, inputs))
|
||||
)
|
||||
first_exception: Optional[Exception] = None
|
||||
coros: List[Awaitable[None]] = []
|
||||
for run_manager, out in zip(run_managers, inputs):
|
||||
if isinstance(out, Exception):
|
||||
first_exception = first_exception or out
|
||||
coros.append(run_manager.on_chain_error(out))
|
||||
else:
|
||||
coros.append(run_manager.on_chain_end(dumpd(out)))
|
||||
await asyncio.gather(*coros)
|
||||
if return_exceptions or first_exception is None:
|
||||
return cast(List[Output], inputs)
|
||||
else:
|
||||
raise first_exception
|
||||
|
||||
def stream(
|
||||
self,
|
||||
@ -1555,9 +1859,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
config={**self.config, **(config or {}), **kwargs},
|
||||
)
|
||||
|
||||
def with_retry(self, retry: BaseRetrying) -> Runnable[Input, Output]:
|
||||
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound.with_retry(retry),
|
||||
bound=self.bound.with_retry(**kwargs),
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
)
|
||||
@ -1590,6 +1894,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
if isinstance(config, list):
|
||||
@ -1601,12 +1907,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
patch_config(self._merge_config(config), deep_copy_locals=True)
|
||||
for _ in range(len(inputs))
|
||||
]
|
||||
return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs})
|
||||
return self.bound.batch(
|
||||
inputs,
|
||||
configs,
|
||||
return_exceptions=return_exceptions,
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
if isinstance(config, list):
|
||||
@ -1618,7 +1931,12 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
patch_config(self._merge_config(config), deep_copy_locals=True)
|
||||
for _ in range(len(inputs))
|
||||
]
|
||||
return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs})
|
||||
return await self.bound.abatch(
|
||||
inputs,
|
||||
configs,
|
||||
return_exceptions=return_exceptions,
|
||||
**{**self.kwargs, **kwargs},
|
||||
)
|
||||
|
||||
def stream(
|
||||
self,
|
||||
|
@ -1,97 +1,113 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast
|
||||
|
||||
from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying
|
||||
from tenacity import (
|
||||
AsyncRetrying,
|
||||
RetryCallState,
|
||||
RetryError,
|
||||
Retrying,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManager as AsyncCallbackManagerT,
|
||||
)
|
||||
from langchain.callbacks.manager import CallbackManager as CallbackManagerT
|
||||
|
||||
T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT)
|
||||
else:
|
||||
T = TypeVar("T")
|
||||
T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun)
|
||||
U = TypeVar("U")
|
||||
|
||||
|
||||
class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
"""Retry a Runnable if it fails."""
|
||||
|
||||
retry: BaseRetrying
|
||||
retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,)
|
||||
|
||||
wait_exponential_jitter: bool = True
|
||||
|
||||
stop_after_attempt: int = 3
|
||||
|
||||
def _kwargs_retrying(self) -> Dict[str, Any]:
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if self.retry.sleep is not None:
|
||||
kwargs["sleep"] = self.retry.sleep
|
||||
if self.retry.stop is not None:
|
||||
kwargs["stop"] = self.retry.stop
|
||||
if self.retry.wait is not None:
|
||||
kwargs["wait"] = self.retry.wait
|
||||
if self.retry.retry is not None:
|
||||
kwargs["retry"] = self.retry.retry
|
||||
if self.retry.before is not None:
|
||||
kwargs["before"] = self.retry.before
|
||||
if self.retry.after is not None:
|
||||
kwargs["after"] = self.retry.after
|
||||
if self.retry.before_sleep is not None:
|
||||
kwargs["before_sleep"] = self.retry.before_sleep
|
||||
if self.retry.reraise is not None:
|
||||
kwargs["reraise"] = self.retry.reraise
|
||||
if self.retry.retry_error_cls is not None:
|
||||
kwargs["retry_error_cls"] = self.retry.retry_error_cls
|
||||
if self.retry.retry_error_callback is not None:
|
||||
kwargs["retry_error_callback"] = self.retry.retry_error_callback
|
||||
kwargs: Dict[str, Any] = dict()
|
||||
|
||||
if self.stop_after_attempt:
|
||||
kwargs["stop"] = stop_after_attempt(self.stop_after_attempt)
|
||||
|
||||
if self.wait_exponential_jitter:
|
||||
kwargs["wait"] = wait_exponential_jitter()
|
||||
|
||||
if self.retry_if_exception_type:
|
||||
kwargs["retry"] = retry_if_exception_type(self.retry_if_exception_type)
|
||||
|
||||
return kwargs
|
||||
|
||||
def _sync_retrying(self) -> Retrying:
|
||||
return Retrying(**self._kwargs_retrying())
|
||||
def _sync_retrying(self, **kwargs: Any) -> Retrying:
|
||||
return Retrying(**self._kwargs_retrying(), **kwargs)
|
||||
|
||||
def _async_retrying(self) -> AsyncRetrying:
|
||||
return AsyncRetrying(**self._kwargs_retrying())
|
||||
def _async_retrying(self, **kwargs: Any) -> AsyncRetrying:
|
||||
return AsyncRetrying(**self._kwargs_retrying(), **kwargs)
|
||||
|
||||
def _patch_config(
|
||||
self,
|
||||
config: Optional[RunnableConfig],
|
||||
config: RunnableConfig,
|
||||
run_manager: T,
|
||||
retry_state: RetryCallState,
|
||||
cm_cls: Type[T],
|
||||
) -> RunnableConfig:
|
||||
config = config or {}
|
||||
return (
|
||||
patch_config(
|
||||
return patch_config(
|
||||
config,
|
||||
callbacks=cm_cls.configure(
|
||||
inheritable_callbacks=config.get("callbacks"),
|
||||
local_tags=["retry:attempt:{}".format(retry_state.attempt_number)],
|
||||
),
|
||||
)
|
||||
callbacks=run_manager.get_child(
|
||||
"retry:attempt:{}".format(retry_state.attempt_number)
|
||||
if retry_state.attempt_number > 1
|
||||
else config
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
def _patch_config_list(
|
||||
self,
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]],
|
||||
config: List[RunnableConfig],
|
||||
run_manager: List[T],
|
||||
retry_state: RetryCallState,
|
||||
cm_cls: Type[T],
|
||||
) -> Union[RunnableConfig, List[RunnableConfig]]:
|
||||
if isinstance(config, list):
|
||||
return [self._patch_config(c, retry_state, cm_cls) for c in config]
|
||||
) -> List[RunnableConfig]:
|
||||
return [
|
||||
self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager)
|
||||
]
|
||||
|
||||
return self._patch_config(config, retry_state, cm_cls)
|
||||
def _invoke(
|
||||
self,
|
||||
input: Input,
|
||||
run_manager: CallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
for attempt in self._sync_retrying(reraise=True):
|
||||
with attempt:
|
||||
result = super().invoke(
|
||||
input,
|
||||
self._patch_config(config, run_manager, attempt.retry_state),
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
return self._call_with_config(self._invoke, input, config, **kwargs)
|
||||
|
||||
for attempt in self._sync_retrying():
|
||||
async def _ainvoke(
|
||||
self,
|
||||
input: Input,
|
||||
run_manager: AsyncCallbackManagerForChainRun,
|
||||
config: RunnableConfig,
|
||||
) -> Output:
|
||||
async for attempt in self._async_retrying(reraise=True):
|
||||
with attempt:
|
||||
result = super().invoke(
|
||||
result = await super().ainvoke(
|
||||
input,
|
||||
self._patch_config(config, attempt.retry_state, CallbackManager),
|
||||
**kwargs
|
||||
self._patch_config(config, run_manager, attempt.retry_state),
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
@ -100,62 +116,135 @@ class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||
) -> Output:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
|
||||
|
||||
async for attempt in self._async_retrying():
|
||||
def _batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
run_manager: List[CallbackManagerForChainRun],
|
||||
config: List[RunnableConfig],
|
||||
) -> List[Union[Output, Exception]]:
|
||||
results_map: Dict[int, Output] = {}
|
||||
|
||||
def pending(iterable: List[U]) -> List[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
try:
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
result = await super().ainvoke(
|
||||
input,
|
||||
self._patch_config(
|
||||
config, attempt.retry_state, AsyncCallbackManager
|
||||
# Get the results of the inputs that have not succeeded yet.
|
||||
result = super().batch(
|
||||
pending(inputs),
|
||||
self._patch_config_list(
|
||||
pending(config), pending(run_manager), attempt.retry_state
|
||||
),
|
||||
**kwargs
|
||||
return_exceptions=True,
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
# Register the results of the inputs that have succeeded.
|
||||
first_exception = None
|
||||
for i, r in enumerate(result):
|
||||
if isinstance(r, Exception):
|
||||
if not first_exception:
|
||||
first_exception = r
|
||||
continue
|
||||
results_map[i] = r
|
||||
# If any exception occurred, raise it, to retry the failed ones
|
||||
if first_exception:
|
||||
raise first_exception
|
||||
if (
|
||||
attempt.retry_state.outcome
|
||||
and not attempt.retry_state.outcome.failed
|
||||
):
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
except RetryError as e:
|
||||
try:
|
||||
result
|
||||
except UnboundLocalError:
|
||||
result = cast(List[Output], [e] * len(inputs))
|
||||
|
||||
outputs: List[Union[Output, Exception]] = []
|
||||
for idx, _ in enumerate(inputs):
|
||||
if idx in results_map:
|
||||
outputs.append(results_map[idx])
|
||||
else:
|
||||
outputs.append(result.pop(0))
|
||||
return outputs
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
for attempt in self._sync_retrying():
|
||||
with attempt:
|
||||
result = super().batch(
|
||||
inputs,
|
||||
self._patch_config_list(
|
||||
config, attempt.retry_state, CallbackManager
|
||||
),
|
||||
**kwargs
|
||||
return self._batch_with_config(
|
||||
self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
|
||||
async def _abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
run_manager: List[AsyncCallbackManagerForChainRun],
|
||||
config: List[RunnableConfig],
|
||||
) -> List[Union[Output, Exception]]:
|
||||
results_map: Dict[int, Output] = {}
|
||||
|
||||
def pending(iterable: List[U]) -> List[U]:
|
||||
return [item for idx, item in enumerate(iterable) if idx not in results_map]
|
||||
|
||||
try:
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
# Get the results of the inputs that have not succeeded yet.
|
||||
result = await super().abatch(
|
||||
pending(inputs),
|
||||
self._patch_config_list(
|
||||
pending(config), pending(run_manager), attempt.retry_state
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
# Register the results of the inputs that have succeeded.
|
||||
first_exception = None
|
||||
for i, r in enumerate(result):
|
||||
if isinstance(r, Exception):
|
||||
if not first_exception:
|
||||
first_exception = r
|
||||
continue
|
||||
results_map[i] = r
|
||||
# If any exception occurred, raise it, to retry the failed ones
|
||||
if first_exception:
|
||||
raise first_exception
|
||||
if (
|
||||
attempt.retry_state.outcome
|
||||
and not attempt.retry_state.outcome.failed
|
||||
):
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
except RetryError as e:
|
||||
try:
|
||||
result
|
||||
except UnboundLocalError:
|
||||
result = cast(List[Output], [e] * len(inputs))
|
||||
|
||||
outputs: List[Union[Output, Exception]] = []
|
||||
for idx, _ in enumerate(inputs):
|
||||
if idx in results_map:
|
||||
outputs.append(results_map[idx])
|
||||
else:
|
||||
outputs.append(result.pop(0))
|
||||
return outputs
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[Input],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import AsyncCallbackManager
|
||||
|
||||
async for attempt in self._async_retrying():
|
||||
with attempt:
|
||||
result = await super().abatch(
|
||||
inputs,
|
||||
self._patch_config_list(
|
||||
config, attempt.retry_state, AsyncCallbackManager
|
||||
),
|
||||
**kwargs
|
||||
return await self._abatch_with_config(
|
||||
self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs
|
||||
)
|
||||
if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed:
|
||||
attempt.retry_state.set_result(result)
|
||||
return result
|
||||
|
||||
# stream() and transform() are not retried because retrying a stream
|
||||
# is not very intuitive.
|
||||
|
@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
@ -12,6 +11,7 @@ from typing import (
|
||||
Optional,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
@ -23,7 +23,11 @@ from langchain.schema.runnable.base import (
|
||||
RunnableSequence,
|
||||
coerce_to_runnable,
|
||||
)
|
||||
from langchain.schema.runnable.config import RunnableConfig, get_config_list
|
||||
from langchain.schema.runnable.config import (
|
||||
RunnableConfig,
|
||||
get_config_list,
|
||||
get_executor_for_config,
|
||||
)
|
||||
from langchain.schema.runnable.utils import gather_with_concurrency
|
||||
|
||||
|
||||
@ -122,7 +126,7 @@ class RouterRunnable(
|
||||
inputs: List[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
keys = [input["key"] for input in inputs]
|
||||
@ -130,16 +134,23 @@ class RouterRunnable(
|
||||
if any(key not in self.runnables for key in keys):
|
||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||
|
||||
def invoke(
|
||||
runnable: Runnable, input: Input, config: RunnableConfig
|
||||
) -> Union[Output, Exception]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
return runnable.invoke(input, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return runnable.invoke(input, config, **kwargs)
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = get_config_list(config, len(inputs))
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
return list(
|
||||
executor.map(
|
||||
lambda runnable, input, config: runnable.invoke(input, config),
|
||||
runnables,
|
||||
actual_inputs,
|
||||
configs,
|
||||
)
|
||||
with get_executor_for_config(configs[0]) as executor:
|
||||
return cast(
|
||||
List[Output],
|
||||
list(executor.map(invoke, runnables, actual_inputs, configs)),
|
||||
)
|
||||
|
||||
async def abatch(
|
||||
@ -147,7 +158,7 @@ class RouterRunnable(
|
||||
inputs: List[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
keys = [input["key"] for input in inputs]
|
||||
@ -155,12 +166,23 @@ class RouterRunnable(
|
||||
if any(key not in self.runnables for key in keys):
|
||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||
|
||||
async def ainvoke(
|
||||
runnable: Runnable, input: Input, config: RunnableConfig
|
||||
) -> Union[Output, Exception]:
|
||||
if return_exceptions:
|
||||
try:
|
||||
return await runnable.ainvoke(input, config, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
else:
|
||||
return await runnable.ainvoke(input, config, **kwargs)
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = get_config_list(config, len(inputs))
|
||||
return await gather_with_concurrency(
|
||||
max_concurrency,
|
||||
configs[0].get("max_concurrency"),
|
||||
*(
|
||||
runnable.ainvoke(input, config)
|
||||
ainvoke(runnable, input, config)
|
||||
for runnable, input, config in zip(runnables, actual_inputs, configs)
|
||||
),
|
||||
)
|
||||
|
@ -6,7 +6,6 @@ import pytest
|
||||
from freezegun import freeze_time
|
||||
from pytest_mock import MockerFixture
|
||||
from syrupy import SnapshotAssertion
|
||||
from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt
|
||||
|
||||
from langchain import PromptTemplate
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
@ -1444,22 +1443,345 @@ def test_retrying(mocker: MockerFixture) -> None:
|
||||
assert _lambda_mock.call_count == 1
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(RetryError):
|
||||
with pytest.raises(ValueError):
|
||||
runnable.with_retry(
|
||||
Retrying(
|
||||
stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,))
|
||||
)
|
||||
stop_after_attempt=2,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).invoke(1)
|
||||
|
||||
assert _lambda_mock.call_count == 2
|
||||
assert _lambda_mock.call_count == 2 # retried
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
runnable.with_retry(
|
||||
Retrying(
|
||||
stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,))
|
||||
)
|
||||
stop_after_attempt=2,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).invoke(2)
|
||||
|
||||
assert _lambda_mock.call_count == 1 # did not retry
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).batch([1, 2, 0])
|
||||
|
||||
# 3rd input isn't retried because it succeeded
|
||||
assert _lambda_mock.call_count == 3 + 2
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
output = runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).batch([1, 2, 0], return_exceptions=True)
|
||||
|
||||
# 3rd input isn't retried because it succeeded
|
||||
assert _lambda_mock.call_count == 3 + 2
|
||||
assert len(output) == 3
|
||||
assert isinstance(output[0], ValueError)
|
||||
assert isinstance(output[1], RuntimeError)
|
||||
assert output[2] == 0
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_retrying(mocker: MockerFixture) -> None:
|
||||
def _lambda(x: int) -> Union[int, Runnable]:
|
||||
if x == 1:
|
||||
raise ValueError("x is 1")
|
||||
elif x == 2:
|
||||
raise RuntimeError("x is 2")
|
||||
else:
|
||||
return x
|
||||
|
||||
_lambda_mock = mocker.Mock(side_effect=_lambda)
|
||||
runnable = RunnableLambda(_lambda_mock)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await runnable.ainvoke(1)
|
||||
|
||||
assert _lambda_mock.call_count == 1
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).ainvoke(1)
|
||||
|
||||
assert _lambda_mock.call_count == 2 # retried
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).ainvoke(2)
|
||||
|
||||
assert _lambda_mock.call_count == 1 # did not retry
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).abatch([1, 2, 0])
|
||||
|
||||
# 3rd input isn't retried because it succeeded
|
||||
assert _lambda_mock.call_count == 3 + 2
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
output = await runnable.with_retry(
|
||||
stop_after_attempt=2,
|
||||
retry_if_exception_type=(ValueError,),
|
||||
).abatch([1, 2, 0], return_exceptions=True)
|
||||
|
||||
# 3rd input isn't retried because it succeeded
|
||||
assert _lambda_mock.call_count == 3 + 2
|
||||
assert len(output) == 3
|
||||
assert isinstance(output[0], ValueError)
|
||||
assert isinstance(output[1], RuntimeError)
|
||||
assert output[2] == 0
|
||||
_lambda_mock.reset_mock()
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
class ControlledExceptionRunnable(Runnable[str, str]):
|
||||
def __init__(self, fail_starts_with: str) -> None:
|
||||
self.fail_starts_with = fail_starts_with
|
||||
|
||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
def _batch(
|
||||
self,
|
||||
inputs: List[str],
|
||||
) -> List:
|
||||
outputs: List[Any] = []
|
||||
for input in inputs:
|
||||
if input.startswith(self.fail_starts_with):
|
||||
outputs.append(ValueError())
|
||||
else:
|
||||
outputs.append(input + "a")
|
||||
return outputs
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[str],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
return self._batch_with_config(
|
||||
self._batch,
|
||||
inputs,
|
||||
config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
chain = (
|
||||
ControlledExceptionRunnable("bux")
|
||||
| ControlledExceptionRunnable("bar")
|
||||
| ControlledExceptionRunnable("baz")
|
||||
| ControlledExceptionRunnable("foo")
|
||||
)
|
||||
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
|
||||
# Test batch
|
||||
with pytest.raises(ValueError):
|
||||
chain.batch(["foo", "bar", "baz", "qux"])
|
||||
|
||||
spy = mocker.spy(ControlledExceptionRunnable, "batch")
|
||||
tracer = FakeTracer()
|
||||
inputs = ["foo", "bar", "baz", "qux"]
|
||||
outputs = chain.batch(inputs, dict(callbacks=[tracer]), return_exceptions=True)
|
||||
assert len(outputs) == 4
|
||||
assert isinstance(outputs[0], ValueError)
|
||||
assert isinstance(outputs[1], ValueError)
|
||||
assert isinstance(outputs[2], ValueError)
|
||||
assert outputs[3] == "quxaaaa"
|
||||
assert spy.call_count == 4
|
||||
inputs_to_batch = [c[0][1] for c in spy.call_args_list]
|
||||
assert inputs_to_batch == [
|
||||
# inputs to sequence step 0
|
||||
# same as inputs to sequence.batch()
|
||||
["foo", "bar", "baz", "qux"],
|
||||
# inputs to sequence step 1
|
||||
# == outputs of sequence step 0 as no exceptions were raised
|
||||
["fooa", "bara", "baza", "quxa"],
|
||||
# inputs to sequence step 2
|
||||
# 'bar' was dropped as it raised an exception in step 1
|
||||
["fooaa", "bazaa", "quxaa"],
|
||||
# inputs to sequence step 3
|
||||
# 'baz' was dropped as it raised an exception in step 2
|
||||
["fooaaa", "quxaaa"],
|
||||
]
|
||||
parent_runs = sorted(
|
||||
(r for r in tracer.runs if r.parent_run_id is None),
|
||||
key=lambda run: inputs.index(run.inputs["input"]),
|
||||
)
|
||||
assert len(parent_runs) == 4
|
||||
|
||||
parent_run_foo = parent_runs[0]
|
||||
assert parent_run_foo.inputs["input"] == "foo"
|
||||
assert parent_run_foo.error == repr(ValueError())
|
||||
assert len(parent_run_foo.child_runs) == 4
|
||||
assert [r.error for r in parent_run_foo.child_runs] == [
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
|
||||
parent_run_bar = parent_runs[1]
|
||||
assert parent_run_bar.inputs["input"] == "bar"
|
||||
assert parent_run_bar.error == repr(ValueError())
|
||||
assert len(parent_run_bar.child_runs) == 2
|
||||
assert [r.error for r in parent_run_bar.child_runs] == [
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
|
||||
parent_run_baz = parent_runs[2]
|
||||
assert parent_run_baz.inputs["input"] == "baz"
|
||||
assert parent_run_baz.error == repr(ValueError())
|
||||
assert len(parent_run_baz.child_runs) == 3
|
||||
assert [r.error for r in parent_run_baz.child_runs] == [
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
|
||||
parent_run_qux = parent_runs[3]
|
||||
assert parent_run_qux.inputs["input"] == "qux"
|
||||
assert parent_run_qux.error is None
|
||||
assert parent_run_qux.outputs["output"] == "quxaaaa"
|
||||
assert len(parent_run_qux.child_runs) == 4
|
||||
assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
|
||||
class ControlledExceptionRunnable(Runnable[str, str]):
|
||||
def __init__(self, fail_starts_with: str) -> None:
|
||||
self.fail_starts_with = fail_starts_with
|
||||
|
||||
def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def _abatch(
|
||||
self,
|
||||
inputs: List[str],
|
||||
) -> List:
|
||||
outputs: List[Any] = []
|
||||
for input in inputs:
|
||||
if input.startswith(self.fail_starts_with):
|
||||
outputs.append(ValueError())
|
||||
else:
|
||||
outputs.append(input + "a")
|
||||
return outputs
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[str],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
return await self._abatch_with_config(
|
||||
self._abatch,
|
||||
inputs,
|
||||
config,
|
||||
return_exceptions=return_exceptions,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
chain = (
|
||||
ControlledExceptionRunnable("bux")
|
||||
| ControlledExceptionRunnable("bar")
|
||||
| ControlledExceptionRunnable("baz")
|
||||
| ControlledExceptionRunnable("foo")
|
||||
)
|
||||
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
|
||||
# Test abatch
|
||||
with pytest.raises(ValueError):
|
||||
await chain.abatch(["foo", "bar", "baz", "qux"])
|
||||
|
||||
spy = mocker.spy(ControlledExceptionRunnable, "abatch")
|
||||
tracer = FakeTracer()
|
||||
inputs = ["foo", "bar", "baz", "qux"]
|
||||
outputs = await chain.abatch(
|
||||
inputs, dict(callbacks=[tracer]), return_exceptions=True
|
||||
)
|
||||
assert len(outputs) == 4
|
||||
assert isinstance(outputs[0], ValueError)
|
||||
assert isinstance(outputs[1], ValueError)
|
||||
assert isinstance(outputs[2], ValueError)
|
||||
assert outputs[3] == "quxaaaa"
|
||||
assert spy.call_count == 4
|
||||
inputs_to_batch = [c[0][1] for c in spy.call_args_list]
|
||||
assert inputs_to_batch == [
|
||||
# inputs to sequence step 0
|
||||
# same as inputs to sequence.batch()
|
||||
["foo", "bar", "baz", "qux"],
|
||||
# inputs to sequence step 1
|
||||
# == outputs of sequence step 0 as no exceptions were raised
|
||||
["fooa", "bara", "baza", "quxa"],
|
||||
# inputs to sequence step 2
|
||||
# 'bar' was dropped as it raised an exception in step 1
|
||||
["fooaa", "bazaa", "quxaa"],
|
||||
# inputs to sequence step 3
|
||||
# 'baz' was dropped as it raised an exception in step 2
|
||||
["fooaaa", "quxaaa"],
|
||||
]
|
||||
parent_runs = sorted(
|
||||
(r for r in tracer.runs if r.parent_run_id is None),
|
||||
key=lambda run: inputs.index(run.inputs["input"]),
|
||||
)
|
||||
assert len(parent_runs) == 4
|
||||
|
||||
parent_run_foo = parent_runs[0]
|
||||
assert parent_run_foo.inputs["input"] == "foo"
|
||||
assert parent_run_foo.error == repr(ValueError())
|
||||
assert len(parent_run_foo.child_runs) == 4
|
||||
assert [r.error for r in parent_run_foo.child_runs] == [
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
|
||||
parent_run_bar = parent_runs[1]
|
||||
assert parent_run_bar.inputs["input"] == "bar"
|
||||
assert parent_run_bar.error == repr(ValueError())
|
||||
assert len(parent_run_bar.child_runs) == 2
|
||||
assert [r.error for r in parent_run_bar.child_runs] == [
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
|
||||
parent_run_baz = parent_runs[2]
|
||||
assert parent_run_baz.inputs["input"] == "baz"
|
||||
assert parent_run_baz.error == repr(ValueError())
|
||||
assert len(parent_run_baz.child_runs) == 3
|
||||
assert [r.error for r in parent_run_baz.child_runs] == [
|
||||
None,
|
||||
None,
|
||||
repr(ValueError()),
|
||||
]
|
||||
|
||||
parent_run_qux = parent_runs[3]
|
||||
assert parent_run_qux.inputs["input"] == "qux"
|
||||
assert parent_run_qux.error is None
|
||||
assert parent_run_qux.outputs["output"] == "quxaaaa"
|
||||
assert len(parent_run_qux.child_runs) == 4
|
||||
assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]
|
||||
|
Loading…
Reference in New Issue
Block a user