make runnable dir (#9016)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
Bagatur 2023-08-10 00:56:37 -07:00 committed by GitHub
parent c7a489ae0d
commit 434a96415b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 590 additions and 235 deletions

View File

@ -0,0 +1,24 @@
from langchain.schema.runnable.base import (
Runnable,
RunnableBinding,
RunnableLambda,
RunnableMap,
RunnableSequence,
RunnableWithFallbacks,
)
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.passthrough import RunnablePassthrough
from langchain.schema.runnable.router import RouterInput, RouterRunnable
# Names exported by this package; every entry is imported from one of the
# runnable submodules above (base, config, passthrough, router).
__all__ = [
"RouterInput",
"RouterRunnable",
"Runnable",
"RunnableBinding",
"RunnableConfig",
"RunnableMap",
"RunnableLambda",
"RunnablePassthrough",
"RunnableSequence",
"RunnableWithFallbacks",
]

View File

@ -9,7 +9,6 @@ from typing import (
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
Generic,
Iterator,
@ -19,7 +18,6 @@ from typing import (
Sequence,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
cast,
@ -27,48 +25,15 @@ from typing import (
from pydantic import Field
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.callbacks.base import BaseCallbackManager
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.utils import (
gather_with_concurrency,
)
from langchain.utils.aiter import atee, py_anext
async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
async with semaphore:
return await coro
async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
if n is None:
return await asyncio.gather(*coros)
semaphore = asyncio.Semaphore(n)
return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))
class RunnableConfig(TypedDict, total=False):
"""Configuration for a Runnable."""
tags: List[str]
"""
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
You can use these to filter calls.
"""
metadata: Dict[str, Any]
"""
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
Keys should be strings, values should be JSON-serializable.
"""
callbacks: Callbacks
"""
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
"""
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
@ -87,7 +52,7 @@ class Runnable(Generic[Input, Output], ABC):
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
],
) -> RunnableSequence[Input, Other]:
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
return RunnableSequence(first=self, last=coerce_to_runnable(other))
def __ror__(
self,
@ -97,7 +62,7 @@ class Runnable(Generic[Input, Output], ABC):
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
],
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
return RunnableSequence(first=coerce_to_runnable(other), last=self)
""" --- Public API --- """
@ -150,7 +115,7 @@ class Runnable(Generic[Input, Output], ABC):
configs = self._get_config_list(config, len(inputs))
coros = map(self.ainvoke, inputs, configs)
return await _gather_with_concurrency(max_concurrency, *coros)
return await gather_with_concurrency(max_concurrency, *coros)
def stream(
self, input: Input, config: Optional[RunnableConfig] = None
@ -478,6 +443,14 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
@property
def runnables(self) -> Iterator[Runnable[Input, Output]]:
yield self.runnable
@ -506,7 +479,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
try:
output = runnable.invoke(
input,
_patch_config(config, run_manager.get_child()),
patch_config(config, run_manager.get_child()),
)
except self.exceptions_to_handle as e:
if first_error is None:
@ -550,7 +523,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
try:
output = await runnable.ainvoke(
input,
_patch_config(config, run_manager.get_child()),
patch_config(config, run_manager.get_child()),
)
except self.exceptions_to_handle as e:
if first_error is None:
@ -606,7 +579,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
@ -673,7 +646,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
@ -716,6 +689,10 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
def lc_serializable(self) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
class Config:
arbitrary_types_allowed = True
@ -737,7 +714,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
return RunnableSequence(
first=self.first,
middle=self.middle + [self.last],
last=_coerce_to_runnable(other),
last=coerce_to_runnable(other),
)
def __ror__(
@ -756,7 +733,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
)
else:
return RunnableSequence(
first=_coerce_to_runnable(other),
first=coerce_to_runnable(other),
middle=[self.first] + self.middle,
last=self.last,
)
@ -786,7 +763,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
patch_config(config, run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
@ -825,7 +802,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
patch_config(config, run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
@ -875,7 +852,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
@ -934,7 +911,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
@ -990,7 +967,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
patch_config(config, run_manager.get_child()),
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
@ -1002,12 +979,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try:
# stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].stream(
input, _patch_config(config, run_manager.get_child())
input, patch_config(config, run_manager.get_child())
)
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.transform(
final_pipeline, _patch_config(config, run_manager.get_child())
final_pipeline, patch_config(config, run_manager.get_child())
)
for output in final_pipeline:
yield output
@ -1067,7 +1044,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
patch_config(config, run_manager.get_child()),
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
@ -1079,12 +1056,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try:
# stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].astream(
input, _patch_config(config, run_manager.get_child())
input, patch_config(config, run_manager.get_child())
)
# stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.atransform(
final_pipeline, _patch_config(config, run_manager.get_child())
final_pipeline, patch_config(config, run_manager.get_child())
)
async for output in final_pipeline:
yield output
@ -1128,14 +1105,16 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
],
],
) -> None:
super().__init__(
steps={key: _coerce_to_runnable(r) for key, r in steps.items()}
)
super().__init__(steps={key: coerce_to_runnable(r) for key, r in steps.items()})
@property
def lc_serializable(self) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
class Config:
arbitrary_types_allowed = True
@ -1168,7 +1147,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.invoke,
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
patch_config(config, run_manager.get_child()),
)
for step in steps.values()
]
@ -1211,7 +1190,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.ainvoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
patch_config(config, run_manager.get_child()),
)
for step in steps.values()
)
@ -1250,19 +1229,6 @@ class RunnableLambda(Runnable[Input, Output]):
return self._call_with_config(self.func, input, config)
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
"""
A runnable that passes through the input.
"""
@property
def lc_serializable(self) -> bool:
return True
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(lambda x: x, input, config)
class RunnableBinding(Serializable, Runnable[Input, Output]):
"""
A runnable that delegates calls to another runnable with a set of kwargs.
@ -1279,6 +1245,10 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
def lc_serializable(self) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(bound=self.bound, kwargs={**self.kwargs, **kwargs})
@ -1335,160 +1305,7 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
yield item
class RouterInput(TypedDict):
"""A Router input.
Attributes:
key: The key to route on.
input: The input to pass to the selected runnable.
"""
key: str
input: Any
class RouterRunnable(
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
"""
A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable.
"""
runnables: Mapping[str, Runnable[Input, Output]]
def __init__(
self,
runnables: Mapping[
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
],
) -> None:
super().__init__(
runnables={key: _coerce_to_runnable(r) for key, r in runnables.items()}
)
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
return True
def __or__(
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
Mapping[str, Any],
],
) -> RunnableSequence[RouterInput, Other]:
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
def __ror__(
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
Mapping[str, Any],
],
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
return runnable.invoke(actual_input, config)
async def ainvoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
return await runnable.ainvoke(actual_input, config)
def batch(
self,
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs))
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(
executor.map(
lambda runnable, input, config: runnable.invoke(input, config),
runnables,
actual_inputs,
configs,
)
)
async def abatch(
self,
inputs: List[RouterInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys):
raise ValueError("One or more keys do not have a corresponding runnable")
runnables = [self.runnables[key] for key in keys]
configs = self._get_config_list(config, len(inputs))
return await _gather_with_concurrency(
max_concurrency,
*(
runnable.ainvoke(input, config)
for runnable, input, config in zip(runnables, actual_inputs, configs)
),
)
def stream(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
yield from runnable.stream(actual_input, config)
async def astream(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
key = input["key"]
actual_input = input["input"]
if key not in self.runnables:
raise ValueError(f"No runnable associated with key '{key}'")
runnable = self.runnables[key]
async for output in runnable.astream(actual_input, config):
yield output
def _patch_config(
def patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager
) -> RunnableConfig:
config = config.copy()
@ -1496,7 +1313,7 @@ def _patch_config(
return config
def _coerce_to_runnable(
def coerce_to_runnable(
thing: Union[
Runnable[Input, Output],
Callable[[Input], Output],
@ -1508,7 +1325,7 @@ def _coerce_to_runnable(
elif callable(thing):
return RunnableLambda(thing)
elif isinstance(thing, dict):
runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()}
runnables = {key: coerce_to_runnable(r) for key, r in thing.items()}
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
else:
raise TypeError(

View File

@ -0,0 +1,27 @@
from __future__ import annotations
from typing import Any, Dict, List, TypedDict
from langchain.callbacks.base import Callbacks
class RunnableConfig(TypedDict, total=False):
    """Per-call settings for a Runnable.

    Declared with ``total=False``, so every key is optional.
    """

    tags: List[str]
    """
    Tags attached to this call and to any sub-calls (eg. a Chain calling an LLM).
    These can be used to filter calls.
    """

    metadata: Dict[str, Any]
    """
    Metadata attached to this call and to any sub-calls (eg. a Chain calling
    an LLM). Keys should be strings, values should be JSON-serializable.
    """

    callbacks: Callbacks
    """
    Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
    Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
    """

View File

@ -0,0 +1,23 @@
from __future__ import annotations
from typing import List, Optional
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Runnable, RunnableConfig
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
    """
    Identity runnable: whatever it is invoked with is what it returns.
    """

    @property
    def lc_serializable(self) -> bool:
        """Always ``True``: this class opts in to serialization."""
        return True

    @property
    def lc_namespace(self) -> List[str]:
        """Dotted module path of this class, split into parts, minus the last one."""
        module_parts = self.__class__.__module__.split(".")
        return module_parts[:-1]

    def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
        """Return ``input`` unchanged, routing the call through ``_call_with_config``."""
        return self._call_with_config(lambda x: x, input, config)

View File

@ -0,0 +1,184 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncIterator,
Callable,
Generic,
Iterator,
List,
Mapping,
Optional,
TypedDict,
Union,
)
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import (
Input,
Other,
Output,
Runnable,
RunnableSequence,
coerce_to_runnable,
)
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.runnable.utils import gather_with_concurrency
class RouterInput(TypedDict):
    """Payload handed to a router.

    Attributes:
        key: The key to route on.
        input: The value forwarded to the runnable selected by ``key``.
    """

    key: str
    input: Any
class RouterRunnable(
    Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
    """
    Dispatches each call to one of several runnables, chosen by ``input["key"]``.

    The selected runnable is called with ``input["input"]`` and its output is
    returned as-is.
    """

    runnables: Mapping[str, Runnable[Input, Output]]

    def __init__(
        self,
        runnables: Mapping[
            str, Union[Runnable[Input, Output], Callable[[Input], Output]]
        ],
    ) -> None:
        """Store ``runnables`` after passing every value through ``coerce_to_runnable``."""
        coerced = {}
        for name, candidate in runnables.items():
            coerced[name] = coerce_to_runnable(candidate)
        super().__init__(runnables=coerced)

    class Config:
        arbitrary_types_allowed = True

    @property
    def lc_serializable(self) -> bool:
        """Always ``True``: this class opts in to serialization."""
        return True

    @property
    def lc_namespace(self) -> List[str]:
        """Dotted module path of this class, split into parts, minus the last one."""
        module_parts = self.__class__.__module__.split(".")
        return module_parts[:-1]

    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
            Mapping[str, Any],
        ],
    ) -> RunnableSequence[RouterInput, Other]:
        """Build ``self | other``: a sequence running this router first."""
        successor = coerce_to_runnable(other)
        return RunnableSequence(first=self, last=successor)

    def __ror__(
        self,
        other: Union[
            Runnable[Other, Any],
            Callable[[Any], Other],
            Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
            Mapping[str, Any],
        ],
    ) -> RunnableSequence[Other, Output]:
        """Build ``other | self``: a sequence running this router last."""
        predecessor = coerce_to_runnable(other)
        return RunnableSequence(first=predecessor, last=self)

    def invoke(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Output:
        """Invoke the runnable registered under ``input["key"]``.

        Raises:
            ValueError: If no runnable is registered under that key.
        """
        route, payload = input["key"], input["input"]
        if route not in self.runnables:
            raise ValueError(f"No runnable associated with key '{route}'")
        return self.runnables[route].invoke(payload, config)

    async def ainvoke(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Output:
        """Async counterpart of ``invoke``.

        Raises:
            ValueError: If no runnable is registered under that key.
        """
        route, payload = input["key"], input["input"]
        if route not in self.runnables:
            raise ValueError(f"No runnable associated with key '{route}'")
        return await self.runnables[route].ainvoke(payload, config)

    def batch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Route every input and invoke the selected runnables on a thread pool.

        All keys are checked before anything is invoked. ``max_concurrency``
        is passed to the pool as ``max_workers``.

        Raises:
            ValueError: If any input's key has no registered runnable.
        """
        routes = [item["key"] for item in inputs]
        payloads = [item["input"] for item in inputs]
        if not all(route in self.runnables for route in routes):
            raise ValueError("One or more keys do not have a corresponding runnable")

        targets = [self.runnables[route] for route in routes]
        configs = self._get_config_list(config, len(inputs))

        def _run(
            target: Runnable[Input, Output], payload: Input, cfg: RunnableConfig
        ) -> Output:
            return target.invoke(payload, cfg)

        with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            return list(executor.map(_run, targets, payloads, configs))

    async def abatch(
        self,
        inputs: List[RouterInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        max_concurrency: Optional[int] = None,
    ) -> List[Output]:
        """Async counterpart of ``batch``, gathered via ``gather_with_concurrency``.

        Raises:
            ValueError: If any input's key has no registered runnable.
        """
        routes = [item["key"] for item in inputs]
        payloads = [item["input"] for item in inputs]
        if not all(route in self.runnables for route in routes):
            raise ValueError("One or more keys do not have a corresponding runnable")

        targets = [self.runnables[route] for route in routes]
        configs = self._get_config_list(config, len(inputs))
        pending = [
            target.ainvoke(payload, cfg)
            for target, payload, cfg in zip(targets, payloads, configs)
        ]
        return await gather_with_concurrency(max_concurrency, *pending)

    def stream(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> Iterator[Output]:
        """Yield the output stream of the runnable registered under ``input["key"]``.

        This is a generator, so the key is only checked once iteration starts.

        Raises:
            ValueError: If no runnable is registered under that key.
        """
        route, payload = input["key"], input["input"]
        if route not in self.runnables:
            raise ValueError(f"No runnable associated with key '{route}'")
        yield from self.runnables[route].stream(payload, config)

    async def astream(
        self, input: RouterInput, config: Optional[RunnableConfig] = None
    ) -> AsyncIterator[Output]:
        """Async counterpart of ``stream``.

        Raises:
            ValueError: If no runnable is registered under that key.
        """
        route, payload = input["key"], input["input"]
        if route not in self.runnables:
            raise ValueError(f"No runnable associated with key '{route}'")
        async for chunk in self.runnables[route].astream(payload, config):
            yield chunk

View File

@ -0,0 +1,18 @@
from __future__ import annotations
import asyncio
from typing import Any, Coroutine, Union
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
    """Await ``coro`` while holding ``semaphore`` and return its result."""
    async with semaphore:
        return await coro


async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
    """Await all ``coros``, letting at most ``n`` of them run at the same time.

    Args:
        n: Maximum number of coroutines allowed past the semaphore at once,
            or ``None`` to gather them without any limit.
        *coros: The coroutines to await.

    Returns:
        The list produced by ``asyncio.gather`` for the given coroutines.

    Raises:
        ValueError: If ``n`` is given but is smaller than 1.
    """
    if n is None:
        return await asyncio.gather(*coros)

    if n < 1:
        # A semaphore with zero permits never lets any coroutine through, so
        # the gather below would hang forever; fail loudly instead.
        raise ValueError("n must be None or a positive integer")

    semaphore = asyncio.Semaphore(n)
    return await asyncio.gather(*(gated_coro(semaphore, c) for c in coros))

File diff suppressed because one or more lines are too long

View File

@ -839,7 +839,7 @@ def llm_chain_with_fallbacks() -> RunnableSequence:
)
@pytest.mark.asyncio
async def test_llm_with_fallbacks(
runnable: RunnableWithFallbacks, request: Any
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
) -> None:
runnable = request.getfixturevalue(runnable)
assert runnable.invoke("hello") == "bar"
@ -848,3 +848,4 @@ async def test_llm_with_fallbacks(
assert await runnable.ainvoke("hello") == "bar"
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar")
assert dumps(runnable, pretty=True) == snapshot