mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
make runnable dir (#9016)
Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
c7a489ae0d
commit
434a96415b
24
libs/langchain/langchain/schema/runnable/__init__.py
Normal file
24
libs/langchain/langchain/schema/runnable/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
from langchain.schema.runnable.base import (
|
||||
Runnable,
|
||||
RunnableBinding,
|
||||
RunnableLambda,
|
||||
RunnableMap,
|
||||
RunnableSequence,
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.schema.runnable.passthrough import RunnablePassthrough
|
||||
from langchain.schema.runnable.router import RouterInput, RouterRunnable
|
||||
|
||||
__all__ = [
|
||||
"RouterInput",
|
||||
"RouterRunnable",
|
||||
"Runnable",
|
||||
"RunnableBinding",
|
||||
"RunnableConfig",
|
||||
"RunnableMap",
|
||||
"RunnableLambda",
|
||||
"RunnablePassthrough",
|
||||
"RunnableSequence",
|
||||
"RunnableWithFallbacks",
|
||||
]
|
@ -9,7 +9,6 @@ from typing import (
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
@ -19,7 +18,6 @@ from typing import (
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
@ -27,48 +25,15 @@ from typing import (
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackManager, Callbacks
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.schema.runnable.utils import (
|
||||
gather_with_concurrency,
|
||||
)
|
||||
from langchain.utils.aiter import atee, py_anext
|
||||
|
||||
|
||||
async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||
async with semaphore:
|
||||
return await coro
|
||||
|
||||
|
||||
async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
|
||||
if n is None:
|
||||
return await asyncio.gather(*coros)
|
||||
|
||||
semaphore = asyncio.Semaphore(n)
|
||||
|
||||
return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))
|
||||
|
||||
|
||||
class RunnableConfig(TypedDict, total=False):
|
||||
"""Configuration for a Runnable."""
|
||||
|
||||
tags: List[str]
|
||||
"""
|
||||
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
You can use these to filter calls.
|
||||
"""
|
||||
|
||||
metadata: Dict[str, Any]
|
||||
"""
|
||||
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
Keys should be strings, values should be JSON-serializable.
|
||||
"""
|
||||
|
||||
callbacks: Callbacks
|
||||
"""
|
||||
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
||||
"""
|
||||
|
||||
|
||||
Input = TypeVar("Input")
|
||||
# Output type should implement __concat__, as eg str, list, dict do
|
||||
Output = TypeVar("Output")
|
||||
@ -87,7 +52,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
],
|
||||
) -> RunnableSequence[Input, Other]:
|
||||
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
|
||||
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
||||
|
||||
def __ror__(
|
||||
self,
|
||||
@ -97,7 +62,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
|
||||
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
||||
|
||||
""" --- Public API --- """
|
||||
|
||||
@ -150,7 +115,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
coros = map(self.ainvoke, inputs, configs)
|
||||
|
||||
return await _gather_with_concurrency(max_concurrency, *coros)
|
||||
return await gather_with_concurrency(max_concurrency, *coros)
|
||||
|
||||
def stream(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
@ -478,6 +443,14 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_namespace(self) -> List[str]:
|
||||
return self.__class__.__module__.split(".")[:-1]
|
||||
|
||||
@property
|
||||
def runnables(self) -> Iterator[Runnable[Input, Output]]:
|
||||
yield self.runnable
|
||||
@ -506,7 +479,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
try:
|
||||
output = runnable.invoke(
|
||||
input,
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
@ -550,7 +523,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
try:
|
||||
output = await runnable.ainvoke(
|
||||
input,
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
except self.exceptions_to_handle as e:
|
||||
if first_error is None:
|
||||
@ -606,7 +579,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
_patch_config(config, rm.get_child())
|
||||
patch_config(config, rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
max_concurrency=max_concurrency,
|
||||
@ -673,7 +646,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
_patch_config(config, rm.get_child())
|
||||
patch_config(config, rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
max_concurrency=max_concurrency,
|
||||
@ -716,6 +689,10 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_namespace(self) -> List[str]:
|
||||
return self.__class__.__module__.split(".")[:-1]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@ -737,7 +714,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
return RunnableSequence(
|
||||
first=self.first,
|
||||
middle=self.middle + [self.last],
|
||||
last=_coerce_to_runnable(other),
|
||||
last=coerce_to_runnable(other),
|
||||
)
|
||||
|
||||
def __ror__(
|
||||
@ -756,7 +733,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
)
|
||||
else:
|
||||
return RunnableSequence(
|
||||
first=_coerce_to_runnable(other),
|
||||
first=coerce_to_runnable(other),
|
||||
middle=[self.first] + self.middle,
|
||||
last=self.last,
|
||||
)
|
||||
@ -786,7 +763,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
input = step.invoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
# finish the root run
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
@ -825,7 +802,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
input = await step.ainvoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
# finish the root run
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
@ -875,7 +852,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
_patch_config(config, rm.get_child())
|
||||
patch_config(config, rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
max_concurrency=max_concurrency,
|
||||
@ -934,7 +911,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
inputs,
|
||||
[
|
||||
# each step a child run of the corresponding root run
|
||||
_patch_config(config, rm.get_child())
|
||||
patch_config(config, rm.get_child())
|
||||
for rm, config in zip(run_managers, configs)
|
||||
],
|
||||
max_concurrency=max_concurrency,
|
||||
@ -990,7 +967,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
input = step.invoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
run_manager.on_chain_error(e)
|
||||
@ -1002,12 +979,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
try:
|
||||
# stream the first of the last steps with non-streaming input
|
||||
final_pipeline = steps[streaming_start_index].stream(
|
||||
input, _patch_config(config, run_manager.get_child())
|
||||
input, patch_config(config, run_manager.get_child())
|
||||
)
|
||||
# stream the rest of the last steps with streaming input
|
||||
for step in steps[streaming_start_index + 1 :]:
|
||||
final_pipeline = step.transform(
|
||||
final_pipeline, _patch_config(config, run_manager.get_child())
|
||||
final_pipeline, patch_config(config, run_manager.get_child())
|
||||
)
|
||||
for output in final_pipeline:
|
||||
yield output
|
||||
@ -1067,7 +1044,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
input = await step.ainvoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
@ -1079,12 +1056,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
try:
|
||||
# stream the first of the last steps with non-streaming input
|
||||
final_pipeline = steps[streaming_start_index].astream(
|
||||
input, _patch_config(config, run_manager.get_child())
|
||||
input, patch_config(config, run_manager.get_child())
|
||||
)
|
||||
# stream the rest of the last steps with streaming input
|
||||
for step in steps[streaming_start_index + 1 :]:
|
||||
final_pipeline = step.atransform(
|
||||
final_pipeline, _patch_config(config, run_manager.get_child())
|
||||
final_pipeline, patch_config(config, run_manager.get_child())
|
||||
)
|
||||
async for output in final_pipeline:
|
||||
yield output
|
||||
@ -1128,14 +1105,16 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
],
|
||||
],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
steps={key: _coerce_to_runnable(r) for key, r in steps.items()}
|
||||
)
|
||||
super().__init__(steps={key: coerce_to_runnable(r) for key, r in steps.items()})
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_namespace(self) -> List[str]:
|
||||
return self.__class__.__module__.split(".")[:-1]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@ -1168,7 +1147,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
step.invoke,
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
for step in steps.values()
|
||||
]
|
||||
@ -1211,7 +1190,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
||||
step.ainvoke(
|
||||
input,
|
||||
# mark each step as a child run
|
||||
_patch_config(config, run_manager.get_child()),
|
||||
patch_config(config, run_manager.get_child()),
|
||||
)
|
||||
for step in steps.values()
|
||||
)
|
||||
@ -1250,19 +1229,6 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
return self._call_with_config(self.func, input, config)
|
||||
|
||||
|
||||
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
"""
|
||||
A runnable that passes through the input.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
return self._call_with_config(lambda x: x, input, config)
|
||||
|
||||
|
||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
"""
|
||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||
@ -1279,6 +1245,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_namespace(self) -> List[str]:
|
||||
return self.__class__.__module__.split(".")[:-1]
|
||||
|
||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
|
||||
|
||||
@ -1335,160 +1305,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
yield item
|
||||
|
||||
|
||||
class RouterInput(TypedDict):
|
||||
"""A Router input.
|
||||
|
||||
Attributes:
|
||||
key: The key to route on.
|
||||
input: The input to pass to the selected runnable.
|
||||
"""
|
||||
|
||||
key: str
|
||||
input: Any
|
||||
|
||||
|
||||
class RouterRunnable(
|
||||
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
|
||||
):
|
||||
"""
|
||||
A runnable that routes to a set of runnables based on Input['key'].
|
||||
Returns the output of the selected runnable.
|
||||
"""
|
||||
|
||||
runnables: Mapping[str, Runnable[Input, Output]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnables: Mapping[
|
||||
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
|
||||
],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Any, Other],
|
||||
Callable[[Any], Other],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
Mapping[str, Any],
|
||||
],
|
||||
) -> RunnableSequence[RouterInput, Other]:
|
||||
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
|
||||
|
||||
def __ror__(
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Other, Any],
|
||||
Callable[[Any], Other],
|
||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
Mapping[str, Any],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
|
||||
|
||||
def invoke(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
return runnable.invoke(actual_input, config)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
return await runnable.ainvoke(actual_input, config)
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
keys = [input["key"] for input in inputs]
|
||||
actual_inputs = [input["input"] for input in inputs]
|
||||
if any(key not in self.runnables for key in keys):
|
||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
return list(
|
||||
executor.map(
|
||||
lambda runnable, input, config: runnable.invoke(input, config),
|
||||
runnables,
|
||||
actual_inputs,
|
||||
configs,
|
||||
)
|
||||
)
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
keys = [input["key"] for input in inputs]
|
||||
actual_inputs = [input["input"] for input in inputs]
|
||||
if any(key not in self.runnables for key in keys):
|
||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
return await _gather_with_concurrency(
|
||||
max_concurrency,
|
||||
*(
|
||||
runnable.ainvoke(input, config)
|
||||
for runnable, input, config in zip(runnables, actual_inputs, configs)
|
||||
),
|
||||
)
|
||||
|
||||
def stream(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
yield from runnable.stream(actual_input, config)
|
||||
|
||||
async def astream(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
async for output in runnable.astream(actual_input, config):
|
||||
yield output
|
||||
|
||||
|
||||
def _patch_config(
|
||||
def patch_config(
|
||||
config: RunnableConfig, callback_manager: BaseCallbackManager
|
||||
) -> RunnableConfig:
|
||||
config = config.copy()
|
||||
@ -1496,7 +1313,7 @@ def _patch_config(
|
||||
return config
|
||||
|
||||
|
||||
def _coerce_to_runnable(
|
||||
def coerce_to_runnable(
|
||||
thing: Union[
|
||||
Runnable[Input, Output],
|
||||
Callable[[Input], Output],
|
||||
@ -1508,7 +1325,7 @@ def _coerce_to_runnable(
|
||||
elif callable(thing):
|
||||
return RunnableLambda(thing)
|
||||
elif isinstance(thing, dict):
|
||||
runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()}
|
||||
runnables = {key: coerce_to_runnable(r) for key, r in thing.items()}
|
||||
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
|
||||
else:
|
||||
raise TypeError(
|
27
libs/langchain/langchain/schema/runnable/config.py
Normal file
27
libs/langchain/langchain/schema/runnable/config.py
Normal file
@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, TypedDict
|
||||
|
||||
from langchain.callbacks.base import Callbacks
|
||||
|
||||
|
||||
class RunnableConfig(TypedDict, total=False):
|
||||
"""Configuration for a Runnable."""
|
||||
|
||||
tags: List[str]
|
||||
"""
|
||||
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
You can use these to filter calls.
|
||||
"""
|
||||
|
||||
metadata: Dict[str, Any]
|
||||
"""
|
||||
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
Keys should be strings, values should be JSON-serializable.
|
||||
"""
|
||||
|
||||
callbacks: Callbacks
|
||||
"""
|
||||
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
|
||||
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
|
||||
"""
|
23
libs/langchain/langchain/schema/runnable/passthrough.py
Normal file
23
libs/langchain/langchain/schema/runnable/passthrough.py
Normal file
@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable.base import Input, Runnable, RunnableConfig
|
||||
|
||||
|
||||
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
"""
|
||||
A runnable that passes through the input.
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_namespace(self) -> List[str]:
|
||||
return self.__class__.__module__.split(".")[:-1]
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
return self._call_with_config(lambda x: x, input, config)
|
184
libs/langchain/langchain/schema/runnable/router.py
Normal file
184
libs/langchain/langchain/schema/runnable/router.py
Normal file
@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
TypedDict,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable.base import (
|
||||
Input,
|
||||
Other,
|
||||
Output,
|
||||
Runnable,
|
||||
RunnableSequence,
|
||||
coerce_to_runnable,
|
||||
)
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
from langchain.schema.runnable.utils import gather_with_concurrency
|
||||
|
||||
|
||||
class RouterInput(TypedDict):
|
||||
"""A Router input.
|
||||
|
||||
Attributes:
|
||||
key: The key to route on.
|
||||
input: The input to pass to the selected runnable.
|
||||
"""
|
||||
|
||||
key: str
|
||||
input: Any
|
||||
|
||||
|
||||
class RouterRunnable(
|
||||
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
|
||||
):
|
||||
"""
|
||||
A runnable that routes to a set of runnables based on Input['key'].
|
||||
Returns the output of the selected runnable.
|
||||
"""
|
||||
|
||||
runnables: Mapping[str, Runnable[Input, Output]]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runnables: Mapping[
|
||||
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
|
||||
],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
|
||||
)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def lc_namespace(self) -> List[str]:
|
||||
return self.__class__.__module__.split(".")[:-1]
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Any, Other],
|
||||
Callable[[Any], Other],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
||||
Mapping[str, Any],
|
||||
],
|
||||
) -> RunnableSequence[RouterInput, Other]:
|
||||
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
||||
|
||||
def __ror__(
|
||||
self,
|
||||
other: Union[
|
||||
Runnable[Other, Any],
|
||||
Callable[[Any], Other],
|
||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
||||
Mapping[str, Any],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
||||
|
||||
def invoke(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
return runnable.invoke(actual_input, config)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> Output:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
return await runnable.ainvoke(actual_input, config)
|
||||
|
||||
def batch(
|
||||
self,
|
||||
inputs: List[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
keys = [input["key"] for input in inputs]
|
||||
actual_inputs = [input["input"] for input in inputs]
|
||||
if any(key not in self.runnables for key in keys):
|
||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
|
||||
return list(
|
||||
executor.map(
|
||||
lambda runnable, input, config: runnable.invoke(input, config),
|
||||
runnables,
|
||||
actual_inputs,
|
||||
configs,
|
||||
)
|
||||
)
|
||||
|
||||
async def abatch(
|
||||
self,
|
||||
inputs: List[RouterInput],
|
||||
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
|
||||
*,
|
||||
max_concurrency: Optional[int] = None,
|
||||
) -> List[Output]:
|
||||
keys = [input["key"] for input in inputs]
|
||||
actual_inputs = [input["input"] for input in inputs]
|
||||
if any(key not in self.runnables for key in keys):
|
||||
raise ValueError("One or more keys do not have a corresponding runnable")
|
||||
|
||||
runnables = [self.runnables[key] for key in keys]
|
||||
configs = self._get_config_list(config, len(inputs))
|
||||
return await gather_with_concurrency(
|
||||
max_concurrency,
|
||||
*(
|
||||
runnable.ainvoke(input, config)
|
||||
for runnable, input, config in zip(runnables, actual_inputs, configs)
|
||||
),
|
||||
)
|
||||
|
||||
def stream(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> Iterator[Output]:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
yield from runnable.stream(actual_input, config)
|
||||
|
||||
async def astream(
|
||||
self, input: RouterInput, config: Optional[RunnableConfig] = None
|
||||
) -> AsyncIterator[Output]:
|
||||
key = input["key"]
|
||||
actual_input = input["input"]
|
||||
if key not in self.runnables:
|
||||
raise ValueError(f"No runnable associated with key '{key}'")
|
||||
|
||||
runnable = self.runnables[key]
|
||||
async for output in runnable.astream(actual_input, config):
|
||||
yield output
|
18
libs/langchain/langchain/schema/runnable/utils.py
Normal file
18
libs/langchain/langchain/schema/runnable/utils.py
Normal file
@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Coroutine, Union
|
||||
|
||||
|
||||
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
||||
async with semaphore:
|
||||
return await coro
|
||||
|
||||
|
||||
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
|
||||
if n is None:
|
||||
return await asyncio.gather(*coros)
|
||||
|
||||
semaphore = asyncio.Semaphore(n)
|
||||
|
||||
return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))
|
File diff suppressed because one or more lines are too long
@ -839,7 +839,7 @@ def llm_chain_with_fallbacks() -> RunnableSequence:
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_with_fallbacks(
|
||||
runnable: RunnableWithFallbacks, request: Any
|
||||
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
|
||||
) -> None:
|
||||
runnable = request.getfixturevalue(runnable)
|
||||
assert runnable.invoke("hello") == "bar"
|
||||
@ -848,3 +848,4 @@ async def test_llm_with_fallbacks(
|
||||
assert await runnable.ainvoke("hello") == "bar"
|
||||
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||
assert list(await runnable.ainvoke("hello")) == list("bar")
|
||||
assert dumps(runnable, pretty=True) == snapshot
|
||||
|
Loading…
Reference in New Issue
Block a user