Add RunnableGenerator (#11214)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11204/head^2
Nuno Campos 11 months ago committed by GitHub
commit 4ad0f3de2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -123,6 +123,7 @@ class Runnable(Generic[Input, Output], ABC):
other: Union[ other: Union[
Runnable[Any, Other], Runnable[Any, Other],
Callable[[Any], Other], Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
], ],
) -> RunnableSequence[Input, Other]: ) -> RunnableSequence[Input, Other]:
@ -132,7 +133,8 @@ class Runnable(Generic[Input, Output], ABC):
self, self,
other: Union[ other: Union[
Runnable[Other, Any], Runnable[Other, Any],
Callable[[Any], Other], Callable[[Other], Any],
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
], ],
) -> RunnableSequence[Other, Output]: ) -> RunnableSequence[Other, Output]:
@ -353,7 +355,7 @@ class Runnable(Generic[Input, Output], ABC):
else: else:
# Make a best effort to gather, for any type that supports `+` # Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails. # This method should throw an error if gathering fails.
final += chunk # type: ignore[operator] final = final + chunk # type: ignore[operator]
if got_first_val: if got_first_val:
yield from self.stream(final, config, **kwargs) yield from self.stream(final, config, **kwargs)
@ -379,7 +381,7 @@ class Runnable(Generic[Input, Output], ABC):
else: else:
# Make a best effort to gather, for any type that supports `+` # Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails. # This method should throw an error if gathering fails.
final += chunk # type: ignore[operator] final = final + chunk # type: ignore[operator]
if got_first_val: if got_first_val:
async for output in self.astream(final, config, **kwargs): async for output in self.astream(final, config, **kwargs):
@ -453,6 +455,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input, input: Input,
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
@ -465,7 +468,9 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"), name=config.get("run_name"),
) )
try: try:
output = call_func_with_variable_args(func, input, run_manager, config) output = call_func_with_variable_args(
func, input, run_manager, config, **kwargs
)
except BaseException as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise raise
@ -486,6 +491,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input, input: Input,
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output: ) -> Output:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses.""" with callbacks. Use this method to implement ainvoke() in subclasses."""
@ -499,7 +505,7 @@ class Runnable(Generic[Input, Output], ABC):
) )
try: try:
output = await acall_func_with_variable_args( output = await acall_func_with_variable_args(
func, input, run_manager, config func, input, run_manager, config, **kwargs
) )
except BaseException as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
@ -526,6 +532,7 @@ class Runnable(Generic[Input, Output], ABC):
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
@ -546,7 +553,6 @@ class Runnable(Generic[Input, Output], ABC):
) )
] ]
try: try:
kwargs: Dict[str, Any] = {}
if accepts_config(func): if accepts_config(func):
kwargs["config"] = [ kwargs["config"] = [
patch_config(c, callbacks=rm.get_child()) patch_config(c, callbacks=rm.get_child())
@ -597,6 +603,7 @@ class Runnable(Generic[Input, Output], ABC):
*, *,
return_exceptions: bool = False, return_exceptions: bool = False,
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
@ -619,7 +626,6 @@ class Runnable(Generic[Input, Output], ABC):
) )
) )
try: try:
kwargs: Dict[str, Any] = {}
if accepts_config(func): if accepts_config(func):
kwargs["config"] = [ kwargs["config"] = [
patch_config(c, callbacks=rm.get_child()) patch_config(c, callbacks=rm.get_child())
@ -668,6 +674,7 @@ class Runnable(Generic[Input, Output], ABC):
], ],
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]: ) -> Iterator[Output]:
"""Helper method to transform an Iterator of Input values into an Iterator of """Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks. Output values, with callbacks.
@ -689,7 +696,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"), name=config.get("run_name"),
) )
try: try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer): if accepts_config(transformer):
kwargs["config"] = patch_config( kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child() config, callbacks=run_manager.get_child()
@ -706,7 +712,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output = chunk final_output = chunk
else: else:
try: try:
final_output += chunk # type: ignore[operator] final_output = final_output + chunk # type: ignore
except TypeError: except TypeError:
final_output = None final_output = None
final_output_supported = False final_output_supported = False
@ -716,7 +722,7 @@ class Runnable(Generic[Input, Output], ABC):
final_input = ichunk final_input = ichunk
else: else:
try: try:
final_input += ichunk # type: ignore[operator] final_input = final_input + ichunk # type: ignore
except TypeError: except TypeError:
final_input = None final_input = None
final_input_supported = False final_input_supported = False
@ -746,6 +752,7 @@ class Runnable(Generic[Input, Output], ABC):
], ],
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None, run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]: ) -> AsyncIterator[Output]:
"""Helper method to transform an Async Iterator of Input values into an Async """Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks. Iterator of Output values, with callbacks.
@ -767,7 +774,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"), name=config.get("run_name"),
) )
try: try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer): if accepts_config(transformer):
kwargs["config"] = patch_config( kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child() config, callbacks=run_manager.get_child()
@ -784,7 +790,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output = chunk final_output = chunk
else: else:
try: try:
final_output += chunk # type: ignore[operator] final_output = final_output + chunk # type: ignore
except TypeError: except TypeError:
final_output = None final_output = None
final_output_supported = False final_output_supported = False
@ -794,7 +800,7 @@ class Runnable(Generic[Input, Output], ABC):
final_input = ichunk final_input = ichunk
else: else:
try: try:
final_input += ichunk # type: ignore[operator] final_input = final_input + ichunk # type: ignore[operator]
except TypeError: except TypeError:
final_input = None final_input = None
final_input_supported = False final_input_supported = False
@ -1311,6 +1317,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
other: Union[ other: Union[
Runnable[Any, Other], Runnable[Any, Other],
Callable[[Any], Other], Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
], ],
) -> RunnableSequence[Input, Other]: ) -> RunnableSequence[Input, Other]:
@ -1331,7 +1338,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self, self,
other: Union[ other: Union[
Runnable[Other, Any], Runnable[Other, Any],
Callable[[Any], Other], Callable[[Other], Any],
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
], ],
) -> RunnableSequence[Other, Output]: ) -> RunnableSequence[Other, Output]:
@ -1751,7 +1759,7 @@ class RunnableMapChunk(Dict[str, Any]):
if key not in chunk or chunk[key] is None: if key not in chunk or chunk[key] is None:
chunk[key] = other[key] chunk[key] = other[key]
elif other[key] is not None: elif other[key] is not None:
chunk[key] += other[key] chunk[key] = chunk[key] + other[key]
return chunk return chunk
def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk: def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
@ -1760,7 +1768,7 @@ class RunnableMapChunk(Dict[str, Any]):
if key not in chunk or chunk[key] is None: if key not in chunk or chunk[key] is None:
chunk[key] = self[key] chunk[key] = self[key]
elif self[key] is not None: elif self[key] is not None:
chunk[key] += self[key] chunk[key] = chunk[key] + self[key]
return chunk return chunk
@ -2061,6 +2069,139 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
yield chunk yield chunk
class RunnableGenerator(Runnable[Input, Output]):
    """
    A runnable that runs a generator function.

    The wrapped function receives an iterator of input chunks and yields
    output chunks. A sync generator function backs the sync methods
    (``transform``/``stream``/``invoke``); an async generator function backs
    the async methods (``atransform``/``astream``/``ainvoke``).
    """

    def __init__(
        self,
        transform: Union[
            Callable[[Iterator[Input]], Iterator[Output]],
            Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
        ],
        atransform: Optional[
            Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
        ] = None,
    ) -> None:
        """Wrap a generator function, optionally with an async counterpart.

        Args:
            transform: A generator function or an async generator function.
                An async generator function is stored as the async
                implementation; a sync one as the sync implementation.
            atransform: Optional async generator function used for the async
                methods. Overridden by ``transform`` when ``transform`` is
                itself an async generator function.

        Raises:
            TypeError: If ``transform`` is neither a generator function nor
                an async generator function.
        """
        if atransform is not None:
            self._atransform = atransform

        # Only the attributes that apply are set; the rest of the class uses
        # hasattr/getattr to discover which implementations are available.
        if inspect.isasyncgenfunction(transform):
            self._atransform = transform
        elif inspect.isgeneratorfunction(transform):
            self._transform = transform
        else:
            # Trailing space keeps the two sentences from running together.
            raise TypeError(
                "Expected a generator function type for `transform`. "
                f"Instead got an unsupported type: {type(transform)}"
            )

    @property
    def InputType(self) -> Any:
        """Input type inferred from the wrapped function's first parameter.

        Returns the first type argument of that parameter's annotation
        (e.g. ``int`` for ``Iterator[int]``), or ``Any`` when the parameter
        is missing, unannotated, or the signature cannot be inspected.
        """
        func = getattr(self, "_transform", None) or getattr(self, "_atransform")
        try:
            params = inspect.signature(func).parameters
            first_param = next(iter(params.values()), None)
            if first_param and first_param.annotation != inspect.Parameter.empty:
                return getattr(first_param.annotation, "__args__", (Any,))[0]
            else:
                return Any
        except ValueError:
            return Any

    @property
    def OutputType(self) -> Any:
        """Output type inferred from the wrapped function's return annotation.

        Returns the first type argument of the return annotation (e.g.
        ``int`` for ``Iterator[int]``), or ``Any`` when there is no return
        annotation or the signature cannot be inspected.
        """
        func = getattr(self, "_transform", None) or getattr(self, "_atransform")
        try:
            sig = inspect.signature(func)
            return (
                getattr(sig.return_annotation, "__args__", (Any,))[0]
                if sig.return_annotation != inspect.Signature.empty
                else Any
            )
        except ValueError:
            return Any

    def __eq__(self, other: Any) -> bool:
        """Compare by wrapped function: sync if both have one, else async."""
        if isinstance(other, RunnableGenerator):
            if hasattr(self, "_transform") and hasattr(other, "_transform"):
                return self._transform == other._transform
            elif hasattr(self, "_atransform") and hasattr(other, "_atransform"):
                return self._atransform == other._atransform
            else:
                return False
        else:
            return False

    def __repr__(self) -> str:
        return "RunnableGenerator(...)"

    def transform(
        self,
        input: Iterator[Input],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Output]:
        """Run the sync generator over an iterator of input chunks.

        Raises:
            NotImplementedError: If no sync generator function was provided.
        """
        # Mirror atransform(): fail with a clear error instead of an
        # AttributeError when only an async generator was supplied.
        if not hasattr(self, "_transform"):
            raise NotImplementedError("This runnable does not support sync methods.")

        return self._transform_stream_with_config(
            input, self._transform, config, **kwargs
        )

    def stream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> Iterator[Output]:
        """Stream output chunks for a single input value."""
        # A single input is presented to the generator as a one-item iterator.
        return self.transform(iter([input]), config, **kwargs)

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Run on a single input and combine the streamed chunks with ``+``.

        Returns ``None`` (cast to ``Output``) if the generator yields nothing.
        """
        final = None
        for output in self.stream(input, config, **kwargs):
            if final is None:
                final = output
            else:
                final = final + output
        return cast(Output, final)

    def atransform(
        self,
        input: AsyncIterator[Input],
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Output]:
        """Run the async generator over an async iterator of input chunks.

        Raises:
            NotImplementedError: If no async generator function was provided.
        """
        if not hasattr(self, "_atransform"):
            raise NotImplementedError("This runnable does not support async methods.")

        return self._atransform_stream_with_config(
            input, self._atransform, config, **kwargs
        )

    def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Any,
    ) -> AsyncIterator[Output]:
        """Asynchronously stream output chunks for a single input value."""

        # A single input is presented to the generator as a one-item
        # async iterator.
        async def input_aiter() -> AsyncIterator[Input]:
            yield input

        return self.atransform(input_aiter(), config, **kwargs)

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Async counterpart of ``invoke``: combine streamed chunks with ``+``.

        Returns ``None`` (cast to ``Output``) if the generator yields nothing.
        """
        final = None
        async for output in self.astream(input, config, **kwargs):
            if final is None:
                final = output
            else:
                final = final + output
        return cast(Output, final)
class RunnableLambda(Runnable[Input, Output]): class RunnableLambda(Runnable[Input, Output]):
""" """
A runnable that runs a callable. A runnable that runs a callable.
@ -2538,6 +2679,8 @@ RunnableLike = Union[
Runnable[Input, Output], Runnable[Input, Output],
Callable[[Input], Output], Callable[[Input], Output],
Callable[[Input], Awaitable[Output]], Callable[[Input], Awaitable[Output]],
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
Mapping[str, Any], Mapping[str, Any],
] ]
@ -2545,8 +2688,10 @@ RunnableLike = Union[
def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
if isinstance(thing, Runnable): if isinstance(thing, Runnable):
return thing return thing
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
elif callable(thing): elif callable(thing):
return RunnableLambda(thing) return RunnableLambda(cast(Callable[[Input], Output], thing))
elif isinstance(thing, dict): elif isinstance(thing, dict):
runnables: Mapping[str, Runnable[Any, Any]] = { runnables: Mapping[str, Runnable[Any, Any]] = {
key: coerce_to_runnable(r) for key, r in thing.items() key: coerce_to_runnable(r) for key, r in thing.items()

@ -152,9 +152,9 @@ def call_func_with_variable_args(
input: Input, input: Input,
run_manager: CallbackManagerForChainRun, run_manager: CallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any,
) -> Output: ) -> Output:
"""Call function that may optionally accept a run_manager and/or config.""" """Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func): if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func): if accepts_run_manager(func):
@ -174,9 +174,9 @@ async def acall_func_with_variable_args(
input: Input, input: Input,
run_manager: AsyncCallbackManagerForChainRun, run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig, config: RunnableConfig,
**kwargs: Any,
) -> Output: ) -> Output:
"""Call function that may optionally accept a run_manager and/or config.""" """Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func): if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func): if accepts_run_manager(func):

@ -15,14 +15,7 @@ from typing import (
from typing_extensions import TypedDict from typing_extensions import TypedDict
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import ( from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable
Input,
Other,
Output,
Runnable,
RunnableSequence,
coerce_to_runnable,
)
from langchain.schema.runnable.config import ( from langchain.schema.runnable.config import (
RunnableConfig, RunnableConfig,
get_config_list, get_config_list,
@ -71,28 +64,6 @@ class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] return cls.__module__.split(".")[:-1]
def __or__(
    self,
    other: Union[
        Runnable[Any, Other],
        Callable[[Any], Other],
        Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
        Mapping[str, Any],
    ],
) -> RunnableSequence[RouterInput, Other]:
    """Build the sequence for ``self | other``.

    ``other`` is coerced to a runnable and becomes the last step; this
    router is the first step.
    """
    downstream = coerce_to_runnable(other)
    return RunnableSequence(first=self, last=downstream)
def __ror__(
    self,
    other: Union[
        Runnable[Other, Any],
        # A left-hand callable consumes the sequence input (`Other`); its
        # result feeds this router, so the return type is unconstrained.
        Callable[[Other], Any],
        Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
        Mapping[str, Any],
    ],
) -> RunnableSequence[Other, Output]:
    """Build the sequence for ``other | self``.

    ``other`` is coerced to a runnable and becomes the first step; this
    router is the last step.
    """
    return RunnableSequence(first=coerce_to_runnable(other), last=self)
def invoke( def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output: ) -> Output:

@ -1,6 +1,16 @@
import sys import sys
from operator import itemgetter from operator import itemgetter
from typing import Any, Dict, List, Optional, Sequence, Union, cast from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Union,
cast,
)
from uuid import UUID from uuid import UUID
import pytest import pytest
@ -46,6 +56,7 @@ from langchain.schema.runnable import (
RunnableSequence, RunnableSequence,
RunnableWithFallbacks, RunnableWithFallbacks,
) )
from langchain.schema.runnable.base import RunnableGenerator
from langchain.tools.base import BaseTool, tool from langchain.tools.base import BaseTool, tool
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
@ -2809,3 +2820,81 @@ async def test_tool_from_runnable() -> None:
"title": "PromptInput", "title": "PromptInput",
"type": "object", "type": "object",
} }
@pytest.mark.asyncio
async def test_runnable_gen() -> None:
    """Check that sync and async generator functions work as runnables."""
    expected_chunks = [1, 2, 3]
    expected_total = 6

    def gen(input: Iterator[Any]) -> Iterator[int]:
        yield from (1, 2, 3)

    sync_runnable = RunnableGenerator(gen)

    # Schemas come from the generator's annotations: Any in, int out.
    input_schema = sync_runnable.input_schema.schema()
    output_schema = sync_runnable.output_schema.schema()
    assert input_schema == {"title": "RunnableGeneratorInput"}
    assert output_schema == {
        "title": "RunnableGeneratorOutput",
        "type": "integer",
    }

    # invoke/batch add the streamed chunks together; stream yields them as-is.
    assert sync_runnable.invoke(None) == expected_total
    assert list(sync_runnable.stream(None)) == expected_chunks
    assert sync_runnable.batch([None, None]) == [expected_total, expected_total]

    async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]:
        for value in (1, 2, 3):
            yield value

    async_runnable = RunnableGenerator(agen)

    assert await async_runnable.ainvoke(None) == expected_total
    streamed = [chunk async for chunk in async_runnable.astream(None)]
    assert streamed == expected_chunks
    assert await async_runnable.abatch([None, None]) == [
        expected_total,
        expected_total,
    ]
@pytest.mark.asyncio
async def test_runnable_gen_transform() -> None:
    """Check that generator runnables compose with `|` and stream through."""

    def gen_indexes(length_iter: Iterator[int]) -> Iterator[int]:
        # Only the first length received is used.
        yield from range(next(length_iter))

    async def agen_indexes(length_iter: AsyncIterator[int]) -> AsyncIterator[int]:
        async for length in length_iter:
            for index in range(length):
                yield index

    def plus_one(input: Iterator[int]) -> Iterator[int]:
        yield from (value + 1 for value in input)

    async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]:
        async for value in input:
            yield value + 1

    chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
    achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one

    expected_input_schema = {
        "title": "RunnableGeneratorInput",
        "type": "integer",
    }
    expected_output_schema = {
        "title": "RunnableGeneratorOutput",
        "type": "integer",
    }

    # Both pipelines expose the same integer-in / integer-out schemas.
    for pipeline in (chain, achain):
        assert pipeline.input_schema.schema() == expected_input_schema
        assert pipeline.output_schema.schema() == expected_output_schema

    assert list(chain.stream(3)) == [1, 2, 3]
    async_results = [item async for item in achain.astream(4)]
    assert async_results == [1, 2, 3, 4]

Loading…
Cancel
Save