Allow calls to batch() with 0 length arrays (#10627)

This can happen if eg the input to batch is a list generated dynamically, where a 0-length list might be a valid use case
pull/10651/head
Nuno Campos 1 year ago committed by GitHub
parent a50e62e44b
commit 029b2f6aac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -267,6 +267,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return_exceptions: bool = False,
**kwargs: Any,
) -> List[str]:
if not inputs:
return []
config = get_config_list(config, len(inputs))
max_concurrency = config[0].get("max_concurrency")
@ -306,6 +309,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return_exceptions: bool = False,
**kwargs: Any,
) -> List[str]:
if not inputs:
return []
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async batch, so use default implementation
return await asyncio.get_running_loop().run_in_executor(

@ -114,6 +114,9 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of batch, which calls invoke N times.
Subclasses should override this method if they can batch more efficiently.
"""
if not inputs:
return []
configs = get_config_list(config, len(inputs))
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
@ -144,6 +147,9 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of abatch, which calls ainvoke N times.
Subclasses should override this method if they can batch more efficiently.
"""
if not inputs:
return []
configs = get_config_list(config, len(inputs))
async def ainvoke(
@ -376,6 +382,9 @@ class Runnable(Generic[Input, Output], ABC):
) -> List[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
if not input:
return []
configs = get_config_list(config, len(input))
callback_managers = [get_callback_manager_for_config(c) for c in configs]
run_managers = [
@ -444,6 +453,9 @@ class Runnable(Generic[Input, Output], ABC):
) -> List[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
if not input:
return []
configs = get_config_list(config, len(input))
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
@ -748,6 +760,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
if return_exceptions:
raise NotImplementedError()
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
@ -813,6 +828,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
if return_exceptions:
raise NotImplementedError()
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
@ -1004,6 +1022,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [
@ -1122,6 +1143,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
AsyncCallbackManager,
)
if not inputs:
return []
# setup callbacks
configs = get_config_list(config, len(inputs))
callback_managers = [

@ -97,8 +97,8 @@ def get_config_list(
Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch().
"""
if length < 1:
raise ValueError(f"length must be >= 1, but got {length}")
if length < 0:
raise ValueError(f"length must be >= 0, but got {length}")
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "

@ -129,6 +129,9 @@ class RouterRunnable(
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
if not inputs:
return []
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys):
@ -161,6 +164,9 @@ class RouterRunnable(
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> List[Output]:
if not inputs:
return []
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys):

Loading…
Cancel
Save