You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/libs/core/langchain_core/runnables/router.py

208 lines
6.1 KiB
Python

from __future__ import annotations
from typing import (
Any,
AsyncIterator,
Callable,
Iterator,
List,
Mapping,
Optional,
Union,
cast,
)
from typing_extensions import TypedDict
from langchain_core.runnables.base import (
Input,
Output,
Runnable,
RunnableSerializable,
coerce_to_runnable,
)
from langchain_core.runnables.config import (
RunnableConfig,
get_config_list,
get_executor_for_config,
)
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
gather_with_concurrency,
get_unique_config_specs,
)
class RouterInput(TypedDict):
    """Payload accepted by ``RouterRunnable``.

    Attributes:
        key: Name of the runnable that should handle this request.
        input: Value handed, as-is, to the runnable selected by ``key``.
    """

    key: str
    input: Any
class RouterRunnable(RunnableSerializable[RouterInput, Output]):
    """A runnable that routes to one of a set of runnables based on ``input["key"]``.

    Every call receives a ``RouterInput``. The value under ``"key"`` selects a
    runnable from ``runnables``, and the value under ``"input"`` is passed to
    that runnable. The selected runnable's output is returned unchanged.
    """

    runnables: Mapping[str, Runnable[Any, Output]]

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """Configurable field specs of every routed runnable, de-duplicated."""
        return get_unique_config_specs(
            spec for step in self.runnables.values() for spec in step.config_specs
        )

    def __init__(
        self,
        runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
    ) -> None:
        """Create a router.

        Args:
            runnables: Mapping from routing key to either a runnable or a
                callable; each value is passed through ``coerce_to_runnable``.
        """
        super().__init__(
            runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
        )

    class Config:
        arbitrary_types_allowed = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "schema", "runnable"]

    def _route(self, input: RouterInput) -> tuple[Runnable[Any, Output], Any]:
        """Return the runnable selected by ``input["key"]`` and its payload.

        Raises:
            ValueError: If no runnable is registered under ``input["key"]``.
        """
        # Both fields are read before the membership check so a malformed
        # input still surfaces as a KeyError first, as it always has.
        key = input["key"]
        actual_input = input["input"]
        if key not in self.runnables:
            raise ValueError(f"No runnable associated with key '{key}'")
        return self.runnables[key], actual_input

    def _route_batch(
        self, inputs: List[RouterInput]
    ) -> tuple[List[Runnable[Any, Output]], List[Any]]:
        """Resolve a runnable and payload for every item in ``inputs``.

        Returns:
            Two parallel lists: the selected runnables and their payloads.

        Raises:
            ValueError: If any key has no registered runnable. The message
                lists the unknown keys (each once, in first-seen order).
        """
        keys = [item["key"] for item in inputs]
        actual_inputs = [item["input"] for item in inputs]
        missing = list(
            dict.fromkeys(key for key in keys if key not in self.runnables)
        )
        if missing:
            raise ValueError(
                "One or more keys do not have a corresponding runnable: "
                f"{missing}"
            )
        return [self.runnables[key] for key in keys], actual_inputs

    def invoke(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Output:
        """Invoke the runnable selected by ``input["key"]`` on ``input["input"]``.

        ``kwargs`` is accepted for signature parity with the other entry
        points; it is not forwarded to the selected runnable.

        Raises:
            ValueError: If no runnable is registered under ``input["key"]``.
        """
        runnable, actual_input = self._route(input)
        return runnable.invoke(actual_input, config)

    async def ainvoke(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Output:
        """Async counterpart of ``invoke``; ``kwargs`` is not forwarded.

        Raises:
            ValueError: If no runnable is registered under ``input["key"]``.
        """
        runnable, actual_input = self._route(input)
        return await runnable.ainvoke(actual_input, config)

    def batch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Route and invoke every item in ``inputs``, preserving order.

        Calls run on the executor obtained from the first config. ``kwargs``
        is forwarded to each selected runnable's ``invoke``.

        Args:
            inputs: Items to route.
            config: A single config, or one config per input.
            return_exceptions: If True, an exception raised by an individual
                runnable is returned in that item's position instead of
                propagating.

        Raises:
            ValueError: If any key has no registered runnable (checked before
                any runnable is called).
        """
        if not inputs:
            return []

        runnables, actual_inputs = self._route_batch(inputs)

        def invoke_one(
            runnable: Runnable, input: Input, config: RunnableConfig
        ) -> Union[Output, Exception]:
            if return_exceptions:
                try:
                    return runnable.invoke(input, config, **kwargs)
                except Exception as e:
                    return e
            return runnable.invoke(input, config, **kwargs)

        configs = get_config_list(config, len(inputs))
        with get_executor_for_config(configs[0]) as executor:
            return cast(
                List[Output],
                list(executor.map(invoke_one, runnables, actual_inputs, configs)),
            )

    async def abatch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Optional[Any],
    ) -> List[Output]:
        """Async counterpart of ``batch``.

        Concurrency is bounded by ``max_concurrency`` read from the first
        config. ``kwargs`` is forwarded to each selected runnable's
        ``ainvoke``.

        Raises:
            ValueError: If any key has no registered runnable (checked before
                any runnable is called).
        """
        if not inputs:
            return []

        runnables, actual_inputs = self._route_batch(inputs)

        async def ainvoke_one(
            runnable: Runnable, input: Input, config: RunnableConfig
        ) -> Union[Output, Exception]:
            if return_exceptions:
                try:
                    return await runnable.ainvoke(input, config, **kwargs)
                except Exception as e:
                    return e
            return await runnable.ainvoke(input, config, **kwargs)

        configs = get_config_list(config, len(inputs))
        return await gather_with_concurrency(
            configs[0].get("max_concurrency"),
            *(
                ainvoke_one(runnable, item, item_config)
                for runnable, item, item_config in zip(
                    runnables, actual_inputs, configs
                )
            ),
        )

    def stream(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[Output]:
        """Stream output chunks from the runnable selected by ``input["key"]``.

        Because this is a generator, an unknown key raises ``ValueError`` on
        first iteration rather than at call time. ``kwargs`` is not forwarded.
        """
        runnable, actual_input = self._route(input)
        yield from runnable.stream(actual_input, config)

    async def astream(
        self,
        input: RouterInput,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        """Async counterpart of ``stream``; ``kwargs`` is not forwarded."""
        runnable, actual_input = self._route(input)
        async for output in runnable.astream(actual_input, config):
            yield output