diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 6f7dcc2008..3724db869f 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -263,20 +263,28 @@ class BaseLLM(BaseLanguageModel[str], ABC): self, inputs: List[LanguageModelInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - max_concurrency: Optional[int] = None, + *, + return_exceptions: bool = False, **kwargs: Any, ) -> List[str]: config = get_config_list(config, len(inputs)) + max_concurrency = config[0].get("max_concurrency") if max_concurrency is None: - llm_result = self.generate_prompt( - [self._convert_input(input) for input in inputs], - callbacks=[c.get("callbacks") for c in config], - tags=[c.get("tags") for c in config], - metadata=[c.get("metadata") for c in config], - **kwargs, - ) - return [g[0].text for g in llm_result.generations] + try: + llm_result = self.generate_prompt( + [self._convert_input(input) for input in inputs], + callbacks=[c.get("callbacks") for c in config], + tags=[c.get("tags") for c in config], + metadata=[c.get("metadata") for c in config], + **kwargs, + ) + return [g[0].text for g in llm_result.generations] + except Exception as e: + if return_exceptions: + return cast(List[str], [e for _ in inputs]) + else: + raise e else: batches = [ inputs[i : i + max_concurrency] @@ -285,33 +293,43 @@ class BaseLLM(BaseLanguageModel[str], ABC): return [ output for batch in batches - for output in self.batch(batch, config=config, **kwargs) + for output in self.batch( + batch, config=config, return_exceptions=return_exceptions, **kwargs + ) ] async def abatch( self, inputs: List[LanguageModelInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, - max_concurrency: Optional[int] = None, + *, + return_exceptions: bool = False, **kwargs: Any, ) -> List[str]: if type(self)._agenerate == BaseLLM._agenerate: # model doesn't implement async batch, so use default implementation return await asyncio.get_running_loop().run_in_executor( - None, self.batch, inputs, config, max_concurrency + None, partial(self.batch, **kwargs), inputs, config ) config = get_config_list(config, len(inputs)) + max_concurrency = config[0].get("max_concurrency") if max_concurrency is None: - llm_result = await self.agenerate_prompt( - [self._convert_input(input) for input in inputs], - callbacks=[c.get("callbacks") for c in config], - tags=[c.get("tags") for c in config], - metadata=[c.get("metadata") for c in config], - **kwargs, - ) - return [g[0].text for g in llm_result.generations] + try: + llm_result = await self.agenerate_prompt( + [self._convert_input(input) for input in inputs], + callbacks=[c.get("callbacks") for c in config], + tags=[c.get("tags") for c in config], + metadata=[c.get("metadata") for c in config], + **kwargs, + ) + return [g[0].text for g in llm_result.generations] + except Exception as e: + if return_exceptions: + return cast(List[str], [e for _ in inputs]) + else: + raise e else: batches = [ inputs[i : i + max_concurrency] diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index ad5c8cfe84..7d65f0d461 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -27,8 +27,6 @@ from typing import ( cast, ) -from tenacity import BaseRetrying - if TYPE_CHECKING: from langchain.callbacks.manager import ( AsyncCallbackManagerForChainRun, @@ -107,6 +105,8 @@ class Runnable(Generic[Input, Output], ABC): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: """ @@ -115,17 +115,28 @@ class Runnable(Generic[Input, Output], ABC): """ configs = get_config_list(config, len(inputs)) + def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]: + if return_exceptions: + try: + return self.invoke(input, config, **kwargs) + except Exception as e: + return e + else: + return self.invoke(input, config, **kwargs) + # If there's only one input, don't bother with the executor if len(inputs) == 1: - return [self.invoke(inputs[0], configs[0], **kwargs)] + return cast(List[Output], [invoke(inputs[0], configs[0])]) with get_executor_for_config(configs[0]) as executor: - return list(executor.map(partial(self.invoke, **kwargs), inputs, configs)) + return cast(List[Output], list(executor.map(invoke, inputs, configs))) async def abatch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: """ @@ -133,8 +144,19 @@ class Runnable(Generic[Input, Output], ABC): Subclasses should override this method if they can batch more efficiently. """ configs = get_config_list(config, len(inputs)) - coros = map(partial(self.ainvoke, **kwargs), inputs, configs) + async def ainvoke( + input: Input, config: RunnableConfig + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return await self.ainvoke(input, config, **kwargs) + except Exception as e: + return e + else: + return await self.ainvoke(input, config, **kwargs) + + coros = map(ainvoke, inputs, configs) return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros) def stream( @@ -230,11 +252,21 @@ class Runnable(Generic[Input, Output], ABC): def with_retry( self, - retry: BaseRetrying, + *, + retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,), + wait_exponential_jitter: bool = True, + stop_after_attempt: int = 3, ) -> Runnable[Input, Output]: from langchain.schema.runnable.retry import RunnableRetry - return RunnableRetry(bound=self, retry=retry, kwargs={}, config={}) + return RunnableRetry( + bound=self, + kwargs={}, + config={}, + retry_if_exception_type=retry_if_exception_type, + wait_exponential_jitter=wait_exponential_jitter, + stop_after_attempt=stop_after_attempt, + ) def map(self) -> Runnable[List[Input], List[Output]]: """ @@ -341,6 +373,146 @@ class Runnable(Generic[Input, Output], ABC): await run_manager.on_chain_end(dumpd(output)) return output + def _batch_with_config( + self, + func: Union[ + Callable[[List[Input]], List[Union[Exception, Output]]], + Callable[ + [List[Input], List[CallbackManagerForChainRun]], + List[Union[Exception, Output]], + ], + Callable[ + [List[Input], List[CallbackManagerForChainRun], List[RunnableConfig]], + List[Union[Exception, Output]], + ], + ], + input: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + return_exceptions: bool = False, + run_type: Optional[str] = None, + ) -> List[Output]: + """Helper method to transform an Input value to an Output value, + with callbacks. Use this method to implement invoke() in subclasses.""" + configs = get_config_list(config, len(input)) + callback_managers = [get_callback_manager_for_config(c) for c in configs] + run_managers = [ + callback_manager.on_chain_start( + dumpd(self), + input, + run_type=run_type, + name=config.get("run_name"), + ) + for callback_manager, input, config in zip( + callback_managers, input, configs + ) + ] + try: + if accepts_run_manager_and_config(func): + output = func( + input, + run_manager=run_managers, + config=configs, + ) # type: ignore[call-arg] + elif accepts_run_manager(func): + output = func(input, run_manager=run_managers) # type: ignore[call-arg] + else: + output = func(input) # type: ignore[call-arg] + + print("output", output) + except Exception as e: + for run_manager in run_managers: + run_manager.on_chain_error(e) + if return_exceptions: + return cast(List[Output], [e for _ in input]) + else: + raise + else: + first_exception: Optional[Exception] = None + for run_manager, out in zip(run_managers, output): + if isinstance(out, Exception): + first_exception = first_exception or out + run_manager.on_chain_error(out) + else: + run_manager.on_chain_end(dumpd(out)) + if return_exceptions or first_exception is None: + return cast(List[Output], output) + else: + raise first_exception + + async def _abatch_with_config( + self, + func: Union[ + Callable[[List[Input]], Awaitable[List[Union[Exception, Output]]]], + Callable[ + [List[Input], List[AsyncCallbackManagerForChainRun]], + Awaitable[List[Union[Exception, Output]]], + ], + Callable[ + [ + List[Input], + List[AsyncCallbackManagerForChainRun], + List[RunnableConfig], + ], + Awaitable[List[Union[Exception, Output]]], + ], + ], + input: List[Input], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + return_exceptions: bool = False, + run_type: Optional[str] = None, + ) -> List[Output]: + """Helper method to transform an Input value to an Output value, + with callbacks. Use this method to implement invoke() in subclasses.""" + configs = get_config_list(config, len(input)) + callback_managers = [get_async_callback_manager_for_config(c) for c in configs] + run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( + *( + callback_manager.on_chain_start( + dumpd(self), + input, + run_type=run_type, + name=config.get("run_name"), + ) + for callback_manager, input, config in zip( + callback_managers, input, configs + ) + ) + ) + try: + if accepts_run_manager_and_config(func): + output = await func( + input, + run_manager=run_managers, + config=configs, + ) # type: ignore[call-arg] + elif accepts_run_manager(func): + output = await func(input, run_manager=run_managers) # type: ignore + else: + output = await func(input) # type: ignore[call-arg] + print("output", output) + except Exception as e: + await asyncio.gather( + *(run_manager.on_chain_error(e) for run_manager in run_managers) + ) + if return_exceptions: + return cast(List[Output], [e for _ in input]) + else: + raise + else: + first_exception: Optional[Exception] = None + coros: List[Awaitable[None]] = [] + for run_manager, out in zip(run_managers, output): + if isinstance(out, Exception): + first_exception = first_exception or out + coros.append(run_manager.on_chain_error(out)) + else: + coros.append(run_manager.on_chain_end(dumpd(out))) + await asyncio.gather(*coros) + if return_exceptions or first_exception is None: + return cast(List[Output], output) + else: + raise first_exception + def _transform_stream_with_config( self, input: Iterator[Input], @@ -596,10 +768,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import CallbackManager + if return_exceptions: + raise NotImplementedError() + # setup callbacks configs = get_config_list(config, len(inputs)) callback_managers = [ @@ -656,10 +833,15 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import AsyncCallbackManager + if return_exceptions: + raise NotImplementedError() + # setup callbacks configs = get_config_list(config, len(inputs)) callback_managers = [ @@ -841,6 +1023,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import CallbackManager @@ -871,29 +1055,90 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # invoke try: - for step in self.steps: - inputs = step.batch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config(config, callbacks=rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - ) + if return_exceptions: + # Track which inputs (by index) failed so far + # If an input has failed it will be present in this map, + # and the value will be the exception that was raised. + failed_inputs_map: Dict[int, Exception] = {} + stepidx = -1 + for step in self.steps: + stepidx += 1 + # Assemble the original indexes of the remaining inputs + # (i.e. the ones that haven't failed yet) + remaining_idxs = [ + i for i in range(len(configs)) if i not in failed_inputs_map + ] + # Invoke the step on the remaining inputs + inputs = step.batch( + [ + inp + for i, inp in zip(remaining_idxs, inputs) + if i not in failed_inputs_map + ], + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for i, (rm, config) in enumerate(zip(run_managers, configs)) + if i not in failed_inputs_map + ], + return_exceptions=return_exceptions, + **kwargs, + ) + # If an input failed, add it to the map + for i, inp in zip(remaining_idxs, inputs): + if isinstance(inp, Exception): + failed_inputs_map[i] = inp + inputs = [inp for inp in inputs if not isinstance(inp, Exception)] + # If all inputs have failed, stop processing + if len(failed_inputs_map) == len(configs): + break + + # Reassemble the outputs, inserting Exceptions for failed inputs + inputs_copy = inputs.copy() + inputs = [] + for i in range(len(configs)): + if i in failed_inputs_map: + inputs.append(cast(Input, failed_inputs_map[i])) + else: + inputs.append(inputs_copy.pop(0)) + else: + for step in self.steps: + inputs = step.batch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + ) + # finish the root runs except (KeyboardInterrupt, Exception) as e: for rm in run_managers: rm.on_chain_error(e) - raise + if return_exceptions: + return cast(List[Output], [e for _ in inputs]) + else: + raise else: - for rm, input in zip(run_managers, inputs): - rm.on_chain_end(input) - return cast(List[Output], inputs) + first_exception: Optional[Exception] = None + for run_manager, out in zip(run_managers, inputs): + if isinstance(out, Exception): + first_exception = first_exception or out + run_manager.on_chain_error(out) + else: + run_manager.on_chain_end(dumpd(out)) + if return_exceptions or first_exception is None: + return cast(List[Output], inputs) + else: + raise first_exception async def abatch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: from langchain.callbacks.manager import ( @@ -929,24 +1174,83 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): # invoke .batch() on each step # this uses batching optimizations in Runnable subclasses, like LLM try: - for step in self.steps: - inputs = await step.abatch( - inputs, - [ - # each step a child run of the corresponding root run - patch_config(config, callbacks=rm.get_child()) - for rm, config in zip(run_managers, configs) - ], - ) + if return_exceptions: + # Track which inputs (by index) failed so far + # If an input has failed it will be present in this map, + # and the value will be the exception that was raised. + failed_inputs_map: Dict[int, Exception] = {} + stepidx = -1 + for step in self.steps: + stepidx += 1 + # Assemble the original indexes of the remaining inputs + # (i.e. the ones that haven't failed yet) + remaining_idxs = [ + i for i in range(len(configs)) if i not in failed_inputs_map + ] + # Invoke the step on the remaining inputs + inputs = await step.abatch( + [ + inp + for i, inp in zip(remaining_idxs, inputs) + if i not in failed_inputs_map + ], + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for i, (rm, config) in enumerate(zip(run_managers, configs)) + if i not in failed_inputs_map + ], + return_exceptions=return_exceptions, + **kwargs, + ) + # If an input failed, add it to the map + for i, inp in zip(remaining_idxs, inputs): + if isinstance(inp, Exception): + failed_inputs_map[i] = inp + inputs = [inp for inp in inputs if not isinstance(inp, Exception)] + # If all inputs have failed, stop processing + if len(failed_inputs_map) == len(configs): + break + + # Reassemble the outputs, inserting Exceptions for failed inputs + inputs_copy = inputs.copy() + inputs = [] + for i in range(len(configs)): + if i in failed_inputs_map: + inputs.append(cast(Input, failed_inputs_map[i])) + else: + inputs.append(inputs_copy.pop(0)) + else: + for step in self.steps: + inputs = await step.abatch( + inputs, + [ + # each step a child run of the corresponding root run + patch_config(config, callbacks=rm.get_child()) + for rm, config in zip(run_managers, configs) + ], + ) # finish the root runs except (KeyboardInterrupt, Exception) as e: await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) - raise + if return_exceptions: + return cast(List[Output], [e for _ in inputs]) + else: + raise else: - await asyncio.gather( - *(rm.on_chain_end(input) for rm, input in zip(run_managers, inputs)) - ) - return cast(List[Output], inputs) + first_exception: Optional[Exception] = None + coros: List[Awaitable[None]] = [] + for run_manager, out in zip(run_managers, inputs): + if isinstance(out, Exception): + first_exception = first_exception or out + coros.append(run_manager.on_chain_error(out)) + else: + coros.append(run_manager.on_chain_end(dumpd(out))) + await asyncio.gather(*coros) + if return_exceptions or first_exception is None: + return cast(List[Output], inputs) + else: + raise first_exception def stream( self, @@ -1555,9 +1859,9 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): config={**self.config, **(config or {}), **kwargs}, ) - def with_retry(self, retry: BaseRetrying) -> Runnable[Input, Output]: + def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__( - bound=self.bound.with_retry(retry), + bound=self.bound.with_retry(**kwargs), kwargs=self.kwargs, config=self.config, ) @@ -1590,6 +1894,8 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: if isinstance(config, list): @@ -1601,12 +1907,19 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): patch_config(self._merge_config(config), deep_copy_locals=True) for _ in range(len(inputs)) ] - return self.bound.batch(inputs, configs, **{**self.kwargs, **kwargs}) + return self.bound.batch( + inputs, + configs, + return_exceptions=return_exceptions, + **{**self.kwargs, **kwargs}, + ) async def abatch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: if isinstance(config, list): @@ -1618,7 +1931,12 @@ class RunnableBinding(Serializable, Runnable[Input, Output]): patch_config(self._merge_config(config), deep_copy_locals=True) for _ in range(len(inputs)) ] - return await self.bound.abatch(inputs, configs, **{**self.kwargs, **kwargs}) + return await self.bound.abatch( + inputs, + configs, + return_exceptions=return_exceptions, + **{**self.kwargs, **kwargs}, + ) def stream( self, diff --git a/libs/langchain/langchain/schema/runnable/retry.py b/libs/langchain/langchain/schema/runnable/retry.py index ce67cca399..cda41605aa 100644 --- a/libs/langchain/langchain/schema/runnable/retry.py +++ b/libs/langchain/langchain/schema/runnable/retry.py @@ -1,97 +1,113 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast -from tenacity import AsyncRetrying, BaseRetrying, RetryCallState, Retrying +from tenacity import ( + AsyncRetrying, + RetryCallState, + RetryError, + Retrying, + retry_if_exception_type, + stop_after_attempt, + wait_exponential_jitter, +) +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.schema.runnable.base import Input, Output, RunnableBinding from langchain.schema.runnable.config import RunnableConfig, patch_config -if TYPE_CHECKING: - from langchain.callbacks.manager import ( - AsyncCallbackManager as AsyncCallbackManagerT, - ) - from langchain.callbacks.manager import CallbackManager as CallbackManagerT - - T = TypeVar("T", CallbackManagerT, AsyncCallbackManagerT) -else: - T = TypeVar("T") +T = TypeVar("T", CallbackManagerForChainRun, AsyncCallbackManagerForChainRun) +U = TypeVar("U") class RunnableRetry(RunnableBinding[Input, Output]): """Retry a Runnable if it fails.""" - retry: BaseRetrying + retry_if_exception_type: Tuple[Type[BaseException]] = (Exception,) + + wait_exponential_jitter: bool = True + + stop_after_attempt: int = 3 def _kwargs_retrying(self) -> Dict[str, Any]: - kwargs: Dict[str, Any] = {} - if self.retry.sleep is not None: - kwargs["sleep"] = self.retry.sleep - if self.retry.stop is not None: - kwargs["stop"] = self.retry.stop - if self.retry.wait is not None: - kwargs["wait"] = self.retry.wait - if self.retry.retry is not None: - kwargs["retry"] = self.retry.retry - if self.retry.before is not None: - kwargs["before"] = self.retry.before - if self.retry.after is not None: - kwargs["after"] = self.retry.after - if self.retry.before_sleep is not None: - kwargs["before_sleep"] = self.retry.before_sleep - if self.retry.reraise is not None: - kwargs["reraise"] = self.retry.reraise - if self.retry.retry_error_cls is not None: - kwargs["retry_error_cls"] = self.retry.retry_error_cls - if self.retry.retry_error_callback is not None: - kwargs["retry_error_callback"] = self.retry.retry_error_callback + kwargs: Dict[str, Any] = dict() + + if self.stop_after_attempt: + kwargs["stop"] = stop_after_attempt(self.stop_after_attempt) + + if self.wait_exponential_jitter: + kwargs["wait"] = wait_exponential_jitter() + + if self.retry_if_exception_type: + kwargs["retry"] = retry_if_exception_type(self.retry_if_exception_type) + return kwargs - def _sync_retrying(self) -> Retrying: - return Retrying(**self._kwargs_retrying()) + def _sync_retrying(self, **kwargs: Any) -> Retrying: + return Retrying(**self._kwargs_retrying(), **kwargs) - def _async_retrying(self) -> AsyncRetrying: - return AsyncRetrying(**self._kwargs_retrying()) + def _async_retrying(self, **kwargs: Any) -> AsyncRetrying: + return AsyncRetrying(**self._kwargs_retrying(), **kwargs) def _patch_config( self, - config: Optional[RunnableConfig], + config: RunnableConfig, + run_manager: T, retry_state: RetryCallState, - cm_cls: Type[T], ) -> RunnableConfig: config = config or {} - return ( - patch_config( - config, - callbacks=cm_cls.configure( - inheritable_callbacks=config.get("callbacks"), - local_tags=["retry:attempt:{}".format(retry_state.attempt_number)], - ), - ) - if retry_state.attempt_number > 1 - else config + return patch_config( + config, + callbacks=run_manager.get_child( + "retry:attempt:{}".format(retry_state.attempt_number) + if retry_state.attempt_number > 1 + else None + ), ) def _patch_config_list( self, - config: Optional[Union[RunnableConfig, List[RunnableConfig]]], + config: List[RunnableConfig], + run_manager: List[T], retry_state: RetryCallState, - cm_cls: Type[T], - ) -> Union[RunnableConfig, List[RunnableConfig]]: - if isinstance(config, list): - return [self._patch_config(c, retry_state, cm_cls) for c in config] + ) -> List[RunnableConfig]: + return [ + self._patch_config(c, rm, retry_state) for c, rm in zip(config, run_manager) + ] - return self._patch_config(config, retry_state, cm_cls) + def _invoke( + self, + input: Input, + run_manager: CallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + for attempt in self._sync_retrying(reraise=True): + with attempt: + result = super().invoke( + input, + self._patch_config(config, run_manager, attempt.retry_state), + ) + if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: + attempt.retry_state.set_result(result) + return result def invoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: - from langchain.callbacks.manager import CallbackManager + return self._call_with_config(self._invoke, input, config, **kwargs) - for attempt in self._sync_retrying(): + async def _ainvoke( + self, + input: Input, + run_manager: AsyncCallbackManagerForChainRun, + config: RunnableConfig, + ) -> Output: + async for attempt in self._async_retrying(reraise=True): with attempt: - result = super().invoke( + result = await super().ainvoke( input, - self._patch_config(config, attempt.retry_state, CallbackManager), - **kwargs + self._patch_config(config, run_manager, attempt.retry_state), ) if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: attempt.retry_state.set_result(result) @@ -100,62 +116,135 @@ class RunnableRetry(RunnableBinding[Input, Output]): async def ainvoke( self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: - from langchain.callbacks.manager import AsyncCallbackManager + return await self._acall_with_config(self._ainvoke, input, config, **kwargs) - async for attempt in self._async_retrying(): - with attempt: - result = await super().ainvoke( - input, - self._patch_config( - config, attempt.retry_state, AsyncCallbackManager - ), - **kwargs - ) - if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: - attempt.retry_state.set_result(result) - return result + def _batch( + self, + inputs: List[Input], + run_manager: List[CallbackManagerForChainRun], + config: List[RunnableConfig], + ) -> List[Union[Output, Exception]]: + results_map: Dict[int, Output] = {} + + def pending(iterable: List[U]) -> List[U]: + return [item for idx, item in enumerate(iterable) if idx not in results_map] + + try: + for attempt in self._sync_retrying(): + with attempt: + # Get the results of the inputs that have not succeeded yet. + result = super().batch( + pending(inputs), + self._patch_config_list( + pending(config), pending(run_manager), attempt.retry_state + ), + return_exceptions=True, + ) + # Register the results of the inputs that have succeeded. + first_exception = None + for i, r in enumerate(result): + if isinstance(r, Exception): + if not first_exception: + first_exception = r + continue + results_map[i] = r + # If any exception occurred, raise it, to retry the failed ones + if first_exception: + raise first_exception + if ( + attempt.retry_state.outcome + and not attempt.retry_state.outcome.failed + ): + attempt.retry_state.set_result(result) + except RetryError as e: + try: + result + except UnboundLocalError: + result = cast(List[Output], [e] * len(inputs)) + + outputs: List[Union[Output, Exception]] = [] + for idx, _ in enumerate(inputs): + if idx in results_map: + outputs.append(results_map[idx]) + else: + outputs.append(result.pop(0)) + return outputs def batch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Any ) -> List[Output]: - from langchain.callbacks.manager import CallbackManager + return self._batch_with_config( + self._batch, inputs, config, return_exceptions=return_exceptions, **kwargs + ) - for attempt in self._sync_retrying(): - with attempt: - result = super().batch( - inputs, - self._patch_config_list( - config, attempt.retry_state, CallbackManager - ), - **kwargs - ) - if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: - attempt.retry_state.set_result(result) - return result + async def _abatch( + self, + inputs: List[Input], + run_manager: List[AsyncCallbackManagerForChainRun], + config: List[RunnableConfig], + ) -> List[Union[Output, Exception]]: + results_map: Dict[int, Output] = {} + + def pending(iterable: List[U]) -> List[U]: + return [item for idx, item in enumerate(iterable) if idx not in results_map] + + try: + async for attempt in self._async_retrying(): + with attempt: + # Get the results of the inputs that have not succeeded yet. + result = await super().abatch( + pending(inputs), + self._patch_config_list( + pending(config), pending(run_manager), attempt.retry_state + ), + return_exceptions=True, + ) + # Register the results of the inputs that have succeeded. + first_exception = None + for i, r in enumerate(result): + if isinstance(r, Exception): + if not first_exception: + first_exception = r + continue + results_map[i] = r + # If any exception occurred, raise it, to retry the failed ones + if first_exception: + raise first_exception + if ( + attempt.retry_state.outcome + and not attempt.retry_state.outcome.failed + ): + attempt.retry_state.set_result(result) + except RetryError as e: + try: + result + except UnboundLocalError: + result = cast(List[Output], [e] * len(inputs)) + + outputs: List[Union[Output, Exception]] = [] + for idx, _ in enumerate(inputs): + if idx in results_map: + outputs.append(results_map[idx]) + else: + outputs.append(result.pop(0)) + return outputs async def abatch( self, inputs: List[Input], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, **kwargs: Any ) -> List[Output]: - from langchain.callbacks.manager import AsyncCallbackManager - - async for attempt in self._async_retrying(): - with attempt: - result = await super().abatch( - inputs, - self._patch_config_list( - config, attempt.retry_state, AsyncCallbackManager - ), - **kwargs - ) - if attempt.retry_state.outcome and not attempt.retry_state.outcome.failed: - attempt.retry_state.set_result(result) - return result + return await self._abatch_with_config( + self._abatch, inputs, config, return_exceptions=return_exceptions, **kwargs + ) # stream() and transform() are not retried because retrying a stream # is not very intuitive. diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index 5277932543..a51d0907ca 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -1,6 +1,5 @@ from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor from typing import ( Any, AsyncIterator, @@ -12,6 +11,7 @@ from typing import ( Optional, TypedDict, Union, + cast, ) from langchain.load.serializable import Serializable @@ -23,7 +23,11 @@ from langchain.schema.runnable.base import ( RunnableSequence, coerce_to_runnable, ) -from langchain.schema.runnable.config import RunnableConfig, get_config_list +from langchain.schema.runnable.config import ( + RunnableConfig, + get_config_list, + get_executor_for_config, +) from langchain.schema.runnable.utils import gather_with_concurrency @@ -122,7 +126,7 @@ class RouterRunnable( inputs: List[RouterInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, - max_concurrency: Optional[int] = None, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: keys = [input["key"] for input in inputs] @@ -130,16 +134,23 @@ class RouterRunnable( if any(key not in self.runnables for key in keys): raise ValueError("One or more keys do not have a corresponding runnable") + def invoke( + runnable: Runnable, input: Input, config: RunnableConfig + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return runnable.invoke(input, config, **kwargs) + except Exception as e: + return e + else: + return runnable.invoke(input, config, **kwargs) + runnables = [self.runnables[key] for key in keys] configs = get_config_list(config, len(inputs)) - with ThreadPoolExecutor(max_workers=max_concurrency) as executor: - return list( - executor.map( - lambda runnable, input, config: runnable.invoke(input, config), - runnables, - actual_inputs, - configs, - ) + with get_executor_for_config(configs[0]) as executor: + return cast( + List[Output], + list(executor.map(invoke, runnables, actual_inputs, configs)), ) async def abatch( @@ -147,7 +158,7 @@ class RouterRunnable( inputs: List[RouterInput], config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, *, - max_concurrency: Optional[int] = None, + return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: keys = [input["key"] for input in inputs] @@ -155,12 +166,23 @@ class RouterRunnable( if any(key not in self.runnables for key in keys): raise ValueError("One or more keys do not have a corresponding runnable") + async def ainvoke( + runnable: Runnable, input: Input, config: RunnableConfig + ) -> Union[Output, Exception]: + if return_exceptions: + try: + return await runnable.ainvoke(input, config, **kwargs) + except Exception as e: + return e + else: + return await runnable.ainvoke(input, config, **kwargs) + runnables = [self.runnables[key] for key in keys] configs = get_config_list(config, len(inputs)) return await gather_with_concurrency( - max_concurrency, + configs[0].get("max_concurrency"), *( - runnable.ainvoke(input, config) + ainvoke(runnable, input, config) for runnable, input, config in zip(runnables, actual_inputs, configs) ), ) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index ca22c6c984..2e0be35ddc 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -6,7 +6,6 @@ import pytest from freezegun import freeze_time from pytest_mock import MockerFixture from syrupy import SnapshotAssertion -from tenacity import RetryError, Retrying, retry_if_exception_type, stop_after_attempt from langchain import PromptTemplate from langchain.callbacks.manager import Callbacks @@ -1444,22 +1443,345 @@ def test_retrying(mocker: MockerFixture) -> None: assert _lambda_mock.call_count == 1 _lambda_mock.reset_mock() - with pytest.raises(RetryError): + with pytest.raises(ValueError): runnable.with_retry( - Retrying( - stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,)) - ) + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), ).invoke(1) - assert _lambda_mock.call_count == 2 + assert _lambda_mock.call_count == 2 # retried _lambda_mock.reset_mock() with pytest.raises(RuntimeError): runnable.with_retry( - Retrying( - stop=stop_after_attempt(2), retry=retry_if_exception_type((ValueError,)) - ) + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), ).invoke(2) + assert _lambda_mock.call_count == 1 # did not retry + _lambda_mock.reset_mock() + + with pytest.raises(ValueError): + runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).batch([1, 2, 0]) + + # 3rd input isn't retried because it succeeded + assert _lambda_mock.call_count == 3 + 2 + _lambda_mock.reset_mock() + + output = runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).batch([1, 2, 0], return_exceptions=True) + + # 3rd input isn't retried because it succeeded + assert _lambda_mock.call_count == 3 + 2 + assert len(output) == 3 + assert isinstance(output[0], ValueError) + assert isinstance(output[1], RuntimeError) + assert output[2] == 0 + _lambda_mock.reset_mock() + + +@pytest.mark.asyncio +async def test_async_retrying(mocker: MockerFixture) -> None: + def _lambda(x: int) -> Union[int, Runnable]: + if x == 1: + raise ValueError("x is 1") + elif x == 2: + raise RuntimeError("x is 2") + else: + return x + + _lambda_mock = mocker.Mock(side_effect=_lambda) + runnable = RunnableLambda(_lambda_mock) + + with pytest.raises(ValueError): + await runnable.ainvoke(1) + assert _lambda_mock.call_count == 1 _lambda_mock.reset_mock() + + with pytest.raises(ValueError): + await runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).ainvoke(1) + + assert _lambda_mock.call_count == 2 # retried + _lambda_mock.reset_mock() + + with pytest.raises(RuntimeError): + await runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).ainvoke(2) + + assert _lambda_mock.call_count == 1 # did not retry + _lambda_mock.reset_mock() + + with pytest.raises(ValueError): + await runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).abatch([1, 2, 0]) + + # 3rd input isn't retried because it succeeded + assert _lambda_mock.call_count == 3 + 2 + _lambda_mock.reset_mock() + + output = await runnable.with_retry( + stop_after_attempt=2, + retry_if_exception_type=(ValueError,), + ).abatch([1, 2, 0], return_exceptions=True) + + # 3rd input isn't retried because it succeeded + assert _lambda_mock.call_count == 3 + 2 + assert len(output) == 3 + assert isinstance(output[0], ValueError) + assert isinstance(output[1], RuntimeError) + assert output[2] == 0 + _lambda_mock.reset_mock() + + +@freeze_time("2023-01-01") +def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None: + class ControlledExceptionRunnable(Runnable[str, str]): + def __init__(self, fail_starts_with: str) -> None: + self.fail_starts_with = fail_starts_with + + def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: + raise NotImplementedError() + + def _batch( + self, + inputs: List[str], + ) -> List: + outputs: List[Any] = [] + for input in inputs: + if input.startswith(self.fail_starts_with): + outputs.append(ValueError()) + else: + outputs.append(input + "a") + return outputs + + def batch( + self, + inputs: List[str], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> List[str]: + return self._batch_with_config( + self._batch, + inputs, + config, + return_exceptions=return_exceptions, + **kwargs, + ) + + chain = ( + ControlledExceptionRunnable("bux") + | ControlledExceptionRunnable("bar") + | ControlledExceptionRunnable("baz") + | ControlledExceptionRunnable("foo") + ) + + assert isinstance(chain, RunnableSequence) + + # Test batch + with pytest.raises(ValueError): + chain.batch(["foo", "bar", "baz", "qux"]) + + spy = mocker.spy(ControlledExceptionRunnable, "batch") + tracer = FakeTracer() + inputs = ["foo", "bar", "baz", "qux"] + outputs = chain.batch(inputs, dict(callbacks=[tracer]), return_exceptions=True) + assert len(outputs) == 4 + assert isinstance(outputs[0], ValueError) + assert isinstance(outputs[1], ValueError) + assert isinstance(outputs[2], ValueError) + assert outputs[3] == "quxaaaa" + assert spy.call_count == 4 + inputs_to_batch = [c[0][1] for c in spy.call_args_list] + assert inputs_to_batch == [ + # inputs to sequence step 0 + # same as inputs to sequence.batch() + ["foo", "bar", "baz", "qux"], + # inputs to sequence step 1 + # == outputs of sequence step 0 as no exceptions were raised + ["fooa", "bara", "baza", "quxa"], + # inputs to sequence step 2 + # 'bar' was dropped as it raised an exception in step 1 + ["fooaa", "bazaa", "quxaa"], + # inputs to sequence step 3 + # 'baz' was dropped as it raised an exception in step 2 + ["fooaaa", "quxaaa"], + ] + parent_runs = sorted( + (r for r in tracer.runs if r.parent_run_id is None), + key=lambda run: inputs.index(run.inputs["input"]), + ) + assert len(parent_runs) == 4 + + parent_run_foo = parent_runs[0] + assert parent_run_foo.inputs["input"] == "foo" + assert parent_run_foo.error == repr(ValueError()) + assert len(parent_run_foo.child_runs) == 4 + assert [r.error for r in parent_run_foo.child_runs] == [ + None, + None, + None, + repr(ValueError()), + ] + + parent_run_bar = parent_runs[1] + assert parent_run_bar.inputs["input"] == "bar" + assert parent_run_bar.error == repr(ValueError()) + assert len(parent_run_bar.child_runs) == 2 + assert [r.error for r in parent_run_bar.child_runs] == [ + None, + repr(ValueError()), + ] + + parent_run_baz = parent_runs[2] + assert parent_run_baz.inputs["input"] == "baz" + assert parent_run_baz.error == repr(ValueError()) + assert len(parent_run_baz.child_runs) == 3 + assert [r.error for r in parent_run_baz.child_runs] == [ + None, + None, + repr(ValueError()), + ] + + parent_run_qux = parent_runs[3] + assert parent_run_qux.inputs["input"] == "qux" + assert parent_run_qux.error is None + assert parent_run_qux.outputs["output"] == "quxaaaa" + assert len(parent_run_qux.child_runs) == 4 + assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None] + + +@pytest.mark.asyncio +@freeze_time("2023-01-01") +async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None: + class ControlledExceptionRunnable(Runnable[str, str]): + def __init__(self, fail_starts_with: str) -> None: + self.fail_starts_with = fail_starts_with + + def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> Any: + raise NotImplementedError() + + async def _abatch( + self, + inputs: List[str], + ) -> List: + outputs: List[Any] = [] + for input in inputs: + if input.startswith(self.fail_starts_with): + outputs.append(ValueError()) + else: + outputs.append(input + "a") + return outputs + + async def abatch( + self, + inputs: List[str], + config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None, + *, + return_exceptions: bool = False, + **kwargs: Any, + ) -> List[str]: + return await self._abatch_with_config( + self._abatch, + inputs, + config, + return_exceptions=return_exceptions, + **kwargs, + ) + + chain = ( + ControlledExceptionRunnable("bux") + | ControlledExceptionRunnable("bar") + | ControlledExceptionRunnable("baz") + | ControlledExceptionRunnable("foo") + ) + + assert isinstance(chain, RunnableSequence) + + # Test abatch + with pytest.raises(ValueError): + await chain.abatch(["foo", "bar", "baz", "qux"]) + + spy = mocker.spy(ControlledExceptionRunnable, "abatch") + tracer = FakeTracer() + inputs = ["foo", "bar", "baz", "qux"] + outputs = await chain.abatch( + inputs, dict(callbacks=[tracer]), return_exceptions=True + ) + assert len(outputs) == 4 + assert isinstance(outputs[0], ValueError) + assert isinstance(outputs[1], ValueError) + assert isinstance(outputs[2], ValueError) + assert outputs[3] == "quxaaaa" + assert spy.call_count == 4 + inputs_to_batch = [c[0][1] for c in spy.call_args_list] + assert inputs_to_batch == [ + # inputs to sequence step 0 + # same as inputs to sequence.batch() + ["foo", "bar", "baz", "qux"], + # inputs to sequence step 1 + # == outputs of sequence step 0 as no exceptions were raised + ["fooa", "bara", "baza", "quxa"], + # inputs to sequence step 2 + # 'bar' was dropped as it raised an exception in step 1 + ["fooaa", "bazaa", "quxaa"], + # inputs to sequence step 3 + # 'baz' was dropped as it raised an exception in step 2 + ["fooaaa", "quxaaa"], + ] + parent_runs = sorted( + (r for r in tracer.runs if r.parent_run_id is None), + key=lambda run: inputs.index(run.inputs["input"]), + ) + assert len(parent_runs) == 4 + + parent_run_foo = parent_runs[0] + assert parent_run_foo.inputs["input"] == "foo" + assert parent_run_foo.error == repr(ValueError()) + assert len(parent_run_foo.child_runs) == 4 + assert [r.error for r in parent_run_foo.child_runs] == [ + None, + None, + None, + repr(ValueError()), + ] + + parent_run_bar = parent_runs[1] + assert parent_run_bar.inputs["input"] == "bar" + assert parent_run_bar.error == repr(ValueError()) + assert len(parent_run_bar.child_runs) == 2 + assert [r.error for r in parent_run_bar.child_runs] == [ + None, + repr(ValueError()), + ] + + parent_run_baz = parent_runs[2] + assert parent_run_baz.inputs["input"] == "baz" + assert parent_run_baz.error == repr(ValueError()) + assert len(parent_run_baz.child_runs) == 3 + assert [r.error for r in parent_run_baz.child_runs] == [ + None, + None, + repr(ValueError()), + ] + + parent_run_qux = parent_runs[3] + assert parent_run_qux.inputs["input"] == "qux" + assert parent_run_qux.error is None + assert parent_run_qux.outputs["output"] == "quxaaaa" + assert len(parent_run_qux.child_runs) == 4 + assert [r.error for r in parent_run_qux.child_runs] == [None, None, None, None]