From 029b2f6aac0fdc45d3521e388ea4dfbedc8f428c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 15 Sep 2023 17:37:27 +0100 Subject: [PATCH] Allow calls to batch() with 0 length arrays (#10627) This can happen if, e.g., the input to batch() is a list generated dynamically, where a 0-length list might be a valid use case. --- libs/langchain/langchain/llms/base.py | 6 +++++ .../langchain/schema/runnable/base.py | 24 +++++++++++++++++++ .../langchain/schema/runnable/config.py | 4 ++-- .../langchain/schema/runnable/router.py | 6 +++++ 4 files changed, 38 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 3724db869f..5d6e074b8d 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -267,6 +267,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): return_exceptions: bool = False, **kwargs: Any, ) -> List[str]: + if not inputs: + return [] + config = get_config_list(config, len(inputs)) max_concurrency = config[0].get("max_concurrency") @@ -306,6 +309,9 @@ class BaseLLM(BaseLanguageModel[str], ABC): return_exceptions: bool = False, **kwargs: Any, ) -> List[str]: + if not inputs: + return [] + if type(self)._agenerate == BaseLLM._agenerate: # model doesn't implement async batch, so use default implementation return await asyncio.get_running_loop().run_in_executor( diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 71a80b20ed..7771ed02ba 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -114,6 +114,9 @@ class Runnable(Generic[Input, Output], ABC): Default implementation of batch, which calls invoke N times. Subclasses should override this method if they can batch more efficiently. 
""" + if not inputs: + return [] + configs = get_config_list(config, len(inputs)) def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]: @@ -144,6 +147,9 @@ class Runnable(Generic[Input, Output], ABC): Default implementation of abatch, which calls ainvoke N times. Subclasses should override this method if they can batch more efficiently. """ + if not inputs: + return [] + configs = get_config_list(config, len(inputs)) async def ainvoke( @@ -376,6 +382,9 @@ class Runnable(Generic[Input, Output], ABC): ) -> List[Output]: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" + if not input: + return [] + configs = get_config_list(config, len(input)) callback_managers = [get_callback_manager_for_config(c) for c in configs] run_managers = [ @@ -444,6 +453,9 @@ class Runnable(Generic[Input, Output], ABC): ) -> List[Output]: """Helper method to transform an Input value to an Output value, with callbacks. 
Use this method to implement invoke() in subclasses.""" + if not input: + return [] + configs = get_config_list(config, len(input)) callback_managers = [get_async_callback_manager_for_config(c) for c in configs] run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( @@ -748,6 +760,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): if return_exceptions: raise NotImplementedError() + if not inputs: + return [] + # setup callbacks configs = get_config_list(config, len(inputs)) callback_managers = [ @@ -813,6 +828,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]): if return_exceptions: raise NotImplementedError() + if not inputs: + return [] + # setup callbacks configs = get_config_list(config, len(inputs)) callback_managers = [ @@ -1004,6 +1022,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) -> List[Output]: from langchain.callbacks.manager import CallbackManager + if not inputs: + return [] + # setup callbacks configs = get_config_list(config, len(inputs)) callback_managers = [ @@ -1122,6 +1143,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): AsyncCallbackManager, ) + if not inputs: + return [] + # setup callbacks configs = get_config_list(config, len(inputs)) callback_managers = [ diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 987a2c7d2f..5c89ff150a 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -97,8 +97,8 @@ def get_config_list( Helper method to get a list of configs from a single config or a list of configs, useful for subclasses overriding batch() or abatch(). 
""" - if length < 1: - raise ValueError(f"length must be >= 1, but got {length}") + if length < 0: + raise ValueError(f"length must be >= 0, but got {length}") if isinstance(config, list) and len(config) != length: raise ValueError( f"config must be a list of the same length as inputs, " diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index a51d0907ca..4f723d7ed3 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -129,6 +129,9 @@ class RouterRunnable( return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: + if not inputs: + return [] + keys = [input["key"] for input in inputs] actual_inputs = [input["input"] for input in inputs] if any(key not in self.runnables for key in keys): @@ -161,6 +164,9 @@ class RouterRunnable( return_exceptions: bool = False, **kwargs: Optional[Any], ) -> List[Output]: + if not inputs: + return [] + keys = [input["key"] for input in inputs] actual_inputs = [input["input"] for input in inputs] if any(key not in self.runnables for key in keys):