Add RunnableGenerator (#11214)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11204/head^2
Nuno Campos 10 months ago committed by GitHub
commit 4ad0f3de2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -123,6 +123,7 @@ class Runnable(Generic[Input, Output], ABC):
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
@ -132,7 +133,8 @@ class Runnable(Generic[Input, Output], ABC):
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Callable[[Other], Any],
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
@ -353,7 +355,7 @@ class Runnable(Generic[Input, Output], ABC):
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
final = final + chunk # type: ignore[operator]
if got_first_val:
yield from self.stream(final, config, **kwargs)
@ -379,7 +381,7 @@ class Runnable(Generic[Input, Output], ABC):
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
final = final + chunk # type: ignore[operator]
if got_first_val:
async for output in self.astream(final, config, **kwargs):
@ -453,6 +455,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
@ -465,7 +468,9 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"),
)
try:
output = call_func_with_variable_args(func, input, run_manager, config)
output = call_func_with_variable_args(
func, input, run_manager, config, **kwargs
)
except BaseException as e:
run_manager.on_chain_error(e)
raise
@ -486,6 +491,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
@ -499,7 +505,7 @@ class Runnable(Generic[Input, Output], ABC):
)
try:
output = await acall_func_with_variable_args(
func, input, run_manager, config
func, input, run_manager, config, **kwargs
)
except BaseException as e:
await run_manager.on_chain_error(e)
@ -526,6 +532,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
return_exceptions: bool = False,
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
@ -546,7 +553,6 @@ class Runnable(Generic[Input, Output], ABC):
)
]
try:
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = [
patch_config(c, callbacks=rm.get_child())
@ -597,6 +603,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
return_exceptions: bool = False,
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
@ -619,7 +626,6 @@ class Runnable(Generic[Input, Output], ABC):
)
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = [
patch_config(c, callbacks=rm.get_child())
@ -668,6 +674,7 @@ class Runnable(Generic[Input, Output], ABC):
],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
@ -689,7 +696,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"),
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer):
kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child()
@ -706,7 +712,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output = chunk
else:
try:
final_output += chunk # type: ignore[operator]
final_output = final_output + chunk # type: ignore
except TypeError:
final_output = None
final_output_supported = False
@ -716,7 +722,7 @@ class Runnable(Generic[Input, Output], ABC):
final_input = ichunk
else:
try:
final_input += ichunk # type: ignore[operator]
final_input = final_input + ichunk # type: ignore
except TypeError:
final_input = None
final_input_supported = False
@ -746,6 +752,7 @@ class Runnable(Generic[Input, Output], ABC):
],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
"""Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks.
@ -767,7 +774,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"),
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer):
kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child()
@ -784,7 +790,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output = chunk
else:
try:
final_output += chunk # type: ignore[operator]
final_output = final_output + chunk # type: ignore
except TypeError:
final_output = None
final_output_supported = False
@ -794,7 +800,7 @@ class Runnable(Generic[Input, Output], ABC):
final_input = ichunk
else:
try:
final_input += ichunk # type: ignore[operator]
final_input = final_input + ichunk # type: ignore[operator]
except TypeError:
final_input = None
final_input_supported = False
@ -1311,6 +1317,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
@ -1331,7 +1338,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Callable[[Other], Any],
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
@ -1751,7 +1759,7 @@ class RunnableMapChunk(Dict[str, Any]):
if key not in chunk or chunk[key] is None:
chunk[key] = other[key]
elif other[key] is not None:
chunk[key] += other[key]
chunk[key] = chunk[key] + other[key]
return chunk
def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
@ -1760,7 +1768,7 @@ class RunnableMapChunk(Dict[str, Any]):
if key not in chunk or chunk[key] is None:
chunk[key] = self[key]
elif self[key] is not None:
chunk[key] += self[key]
chunk[key] = chunk[key] + self[key]
return chunk
@ -2061,6 +2069,139 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
yield chunk
class RunnableGenerator(Runnable[Input, Output]):
    """
    A runnable that runs a generator function.

    The wrapped function receives an iterator of input chunks and yields
    output chunks. A sync generator function backs ``transform``/``stream``/
    ``invoke``; an async generator function backs ``atransform``/``astream``/
    ``ainvoke``. Either one or both may be supplied.
    """

    def __init__(
        self,
        transform: Union[
            Callable[[Iterator[Input]], Iterator[Output]],
            Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
        ],
        atransform: Optional[
            Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
        ] = None,
    ) -> None:
        """Store the generator function(s) to run.

        Args:
            transform: A generator function or an async generator function.
                An async generator function is stored as the async
                implementation (replacing ``atransform`` if both are given).
            atransform: Optional async generator function used for the
                async methods.

        Raises:
            TypeError: If ``transform`` is neither a generator function nor
                an async generator function.
        """
        if atransform is not None:
            self._atransform = atransform

        if inspect.isasyncgenfunction(transform):
            self._atransform = transform
        elif inspect.isgeneratorfunction(transform):
            self._transform = transform
        else:
            raise TypeError(
                "Expected a generator function type for `transform`. "
                f"Instead got an unsupported type: {type(transform)}"
            )

    @property
    def InputType(self) -> Any:
        """Element type taken from the first parameter's annotation.

        Returns the first type argument of the annotation (e.g. ``int`` for
        ``Iterator[int]``), or ``Any`` when there is no annotation, the
        annotation has no type arguments, or the signature cannot be read.
        """
        # Prefer the sync function; fall back to the async one.
        func = getattr(self, "_transform", None) or getattr(self, "_atransform")
        try:
            params = inspect.signature(func).parameters
            first_param = next(iter(params.values()), None)
            if first_param and first_param.annotation != inspect.Parameter.empty:
                return getattr(first_param.annotation, "__args__", (Any,))[0]
            else:
                return Any
        except ValueError:
            return Any

    @property
    def OutputType(self) -> Any:
        """Element type taken from the return annotation.

        Returns the first type argument of the return annotation, or ``Any``
        when there is no annotation, the annotation has no type arguments,
        or the signature cannot be read.
        """
        # Prefer the sync function; fall back to the async one.
        func = getattr(self, "_transform", None) or getattr(self, "_atransform")
        try:
            sig = inspect.signature(func)
            return (
                getattr(sig.return_annotation, "__args__", (Any,))[0]
                if sig.return_annotation != inspect.Signature.empty
                else Any
            )
        except ValueError:
            return Any

    def __eq__(self, other: Any) -> bool:
        """Compare by wrapped function.

        Two instances are equal when both have a sync function and those are
        equal, or otherwise when both have an async function and those are
        equal. Anything else compares unequal.
        """
        if isinstance(other, RunnableGenerator):
            if hasattr(self, "_transform") and hasattr(other, "_transform"):
                return self._transform == other._transform
            elif hasattr(self, "_atransform") and hasattr(other, "_atransform"):
                return self._atransform == other._atransform
            else:
                return False
        else:
            return False

    def __repr__(self) -> str:
        return "RunnableGenerator(...)"

    def transform(
        self,
        input: Iterator[Input],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Output]:
        """Run the sync generator function over an iterator of input chunks.

        Raises:
            NotImplementedError: If only an async generator function was
                provided at construction time.
        """
        # Mirror the guard in `atransform`: without it, an async-only instance
        # fails here with an opaque AttributeError on `self._transform`.
        if not hasattr(self, "_transform"):
            raise NotImplementedError("This runnable does not support sync methods.")

        return self._transform_stream_with_config(
            input, self._transform, config, **kwargs
        )

    def stream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Output]:
        """Stream output chunks for a single input value."""
        # The generator function expects an iterator, so wrap the single input.
        return self.transform(iter([input]), config, **kwargs)

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Run on a single input and combine all streamed chunks with ``+``.

        Returns ``None`` (cast to ``Output``) when nothing is streamed.
        """
        final = None
        for output in self.stream(input, config, **kwargs):
            if final is None:
                final = output
            else:
                final = final + output
        return cast(Output, final)

    def atransform(
        self,
        input: AsyncIterator[Input],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Output]:
        """Run the async generator function over an async iterator of chunks.

        Raises:
            NotImplementedError: If no async generator function was provided
                at construction time.
        """
        if not hasattr(self, "_atransform"):
            raise NotImplementedError("This runnable does not support async methods.")

        return self._atransform_stream_with_config(
            input, self._atransform, config, **kwargs
        )

    def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Output]:
        """Asynchronously stream output chunks for a single input value."""

        # The async generator function expects an async iterator, so wrap the
        # single input in a one-item async generator.
        async def input_aiter() -> AsyncIterator[Input]:
            yield input

        return self.atransform(input_aiter(), config, **kwargs)

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Async counterpart of ``invoke``: combine streamed chunks with ``+``.

        Returns ``None`` (cast to ``Output``) when nothing is streamed.
        """
        final = None
        async for output in self.astream(input, config, **kwargs):
            if final is None:
                final = output
            else:
                final = final + output
        return cast(Output, final)
class RunnableLambda(Runnable[Input, Output]):
"""
A runnable that runs a callable.
@ -2538,6 +2679,8 @@ RunnableLike = Union[
Runnable[Input, Output],
Callable[[Input], Output],
Callable[[Input], Awaitable[Output]],
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
Mapping[str, Any],
]
@ -2545,8 +2688,10 @@ RunnableLike = Union[
def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
if isinstance(thing, Runnable):
return thing
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
elif callable(thing):
return RunnableLambda(thing)
return RunnableLambda(cast(Callable[[Input], Output], thing))
elif isinstance(thing, dict):
runnables: Mapping[str, Runnable[Any, Any]] = {
key: coerce_to_runnable(r) for key, r in thing.items()

@ -152,9 +152,9 @@ def call_func_with_variable_args(
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):
@ -174,9 +174,9 @@ async def acall_func_with_variable_args(
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):

@ -15,14 +15,7 @@ from typing import (
from typing_extensions import TypedDict
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import (
Input,
Other,
Output,
Runnable,
RunnableSequence,
coerce_to_runnable,
)
from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable
from langchain.schema.runnable.config import (
RunnableConfig,
get_config_list,
@ -71,28 +64,6 @@ class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def __or__(
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
Mapping[str, Any],
],
) -> RunnableSequence[RouterInput, Other]:
return RunnableSequence(first=self, last=coerce_to_runnable(other))
def __ror__(
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
Mapping[str, Any],
],
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=coerce_to_runnable(other), last=self)
def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output:

@ -1,6 +1,16 @@
import sys
from operator import itemgetter
from typing import Any, Dict, List, Optional, Sequence, Union, cast
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
cast,
)
from uuid import UUID
import pytest
@ -46,6 +56,7 @@ from langchain.schema.runnable import (
RunnableSequence,
RunnableWithFallbacks,
)
from langchain.schema.runnable.base import RunnableGenerator
from langchain.tools.base import BaseTool, tool
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
@ -2809,3 +2820,81 @@ async def test_tool_from_runnable() -> None:
"title": "PromptInput",
"type": "object",
}
@pytest.mark.asyncio
async def test_runnable_gen() -> None:
    """Exercise RunnableGenerator with a sync and an async generator function."""

    def gen(input: Iterator[Any]) -> Iterator[int]:
        # Ignores its input and always emits 1, 2, 3.
        for value in (1, 2, 3):
            yield value

    sync_runnable = RunnableGenerator(gen)

    # Schemas are derived from the generator function's annotations.
    expected_input_schema = {"title": "RunnableGeneratorInput"}
    expected_output_schema = {
        "title": "RunnableGeneratorOutput",
        "type": "integer",
    }
    assert sync_runnable.input_schema.schema() == expected_input_schema
    assert sync_runnable.output_schema.schema() == expected_output_schema

    # invoke/batch add the streamed chunks together; stream exposes them as-is.
    assert sync_runnable.invoke(None) == 6
    assert list(sync_runnable.stream(None)) == [1, 2, 3]
    assert sync_runnable.batch([None, None]) == [6, 6]

    async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
        # Async twin of `gen`.
        for value in (1, 2, 3):
            yield value

    async_runnable = RunnableGenerator(agen)

    assert await async_runnable.ainvoke(None) == 6
    streamed = [chunk async for chunk in async_runnable.astream(None)]
    assert streamed == [1, 2, 3]
    assert await async_runnable.abatch([None, None]) == [6, 6]
@pytest.mark.asyncio
async def test_runnable_gen_transform() -> None:
    """Check that generator runnables compose with `|` and stream end to end."""

    def gen_indexes(length_iter: Iterator[int]) -> Iterator[int]:
        # Only the first length received is consumed.
        yield from range(next(length_iter))

    async def agen_indexes(length_iter: AsyncIterator[int]) -> AsyncIterator[int]:
        async for length in length_iter:
            for index in range(length):
                yield index

    def plus_one(input: Iterator[int]) -> Iterator[int]:
        for number in input:
            yield number + 1

    async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]:
        async for number in input:
            yield number + 1

    chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
    achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one

    integer_input_schema = {
        "title": "RunnableGeneratorInput",
        "type": "integer",
    }
    integer_output_schema = {
        "title": "RunnableGeneratorOutput",
        "type": "integer",
    }

    # Both chains should report the same integer-in / integer-out schemas.
    assert chain.input_schema.schema() == integer_input_schema
    assert chain.output_schema.schema() == integer_output_schema
    assert achain.input_schema.schema() == integer_input_schema
    assert achain.output_schema.schema() == integer_output_schema

    assert list(chain.stream(3)) == [1, 2, 3]
    async_results = [chunk async for chunk in achain.astream(4)]
    assert async_results == [1, 2, 3, 4]

Loading…
Cancel
Save