From b67db8deaa148c6a303d9b8998ead48c2096cc17 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 11:44:07 +0100 Subject: [PATCH 1/3] Add RunnableGenerator --- .../langchain/schema/runnable/base.py | 153 +++++++++++++++++- .../langchain/schema/runnable/config.py | 4 +- 2 files changed, 149 insertions(+), 8 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7ee4be6960..9d45d86551 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -453,6 +453,7 @@ class Runnable(Generic[Input, Output], ABC): input: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" @@ -465,7 +466,9 @@ class Runnable(Generic[Input, Output], ABC): name=config.get("run_name"), ) try: - output = call_func_with_variable_args(func, input, run_manager, config) + output = call_func_with_variable_args( + func, input, run_manager, config, **kwargs + ) except BaseException as e: run_manager.on_chain_error(e) raise @@ -486,6 +489,7 @@ class Runnable(Generic[Input, Output], ABC): input: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" @@ -499,7 +503,7 @@ class Runnable(Generic[Input, Output], ABC): ) try: output = await acall_func_with_variable_args( - func, input, run_manager, config + func, input, run_manager, config, **kwargs ) except BaseException as e: await run_manager.on_chain_error(e) @@ -526,6 +530,7 @@ class Runnable(Generic[Input, Output], ABC): *, return_exceptions: bool = False, run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> List[Output]: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" @@ -546,7 +551,6 @@ class Runnable(Generic[Input, Output], ABC): ) ] try: - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = [ patch_config(c, callbacks=rm.get_child()) @@ -597,6 +601,7 @@ class Runnable(Generic[Input, Output], ABC): *, return_exceptions: bool = False, run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> List[Output]: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" @@ -619,7 +624,6 @@ class Runnable(Generic[Input, Output], ABC): ) ) try: - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = [ patch_config(c, callbacks=rm.get_child()) @@ -668,6 +672,7 @@ class Runnable(Generic[Input, Output], ABC): ], config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> Iterator[Output]: """Helper method to transform an Iterator of Input values into an Iterator of Output values, with callbacks. @@ -689,7 +694,6 @@ class Runnable(Generic[Input, Output], ABC): name=config.get("run_name"), ) try: - kwargs: Dict[str, Any] = {} if accepts_config(transformer): kwargs["config"] = patch_config( config, callbacks=run_manager.get_child() @@ -746,6 +750,7 @@ class Runnable(Generic[Input, Output], ABC): ], config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Output]: """Helper method to transform an Async Iterator of Input values into an Async Iterator of Output values, with callbacks. @@ -767,7 +772,6 @@ class Runnable(Generic[Input, Output], ABC): name=config.get("run_name"), ) try: - kwargs: Dict[str, Any] = {} if accepts_config(transformer): kwargs["config"] = patch_config( config, callbacks=run_manager.get_child() @@ -2061,6 +2065,139 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): yield chunk +class RunnableGenerator(Runnable[Input, Output]): + """ + A runnable that runs a generator function. + """ + + def __init__( + self, + transform: Union[ + Callable[[Iterator[Input]], Iterator[Output]], + Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + ], + atransform: Optional[ + Callable[[AsyncIterator[Input]], AsyncIterator[Output]] + ] = None, + ) -> None: + if atransform is not None: + self._atransform = atransform + + if inspect.isasyncgenfunction(transform): + self._atransform = transform + elif inspect.isgeneratorfunction(transform): + self._transform = transform + else: + raise TypeError( + "Expected a generator function type for `transform`." + f"Instead got an unsupported type: {type(transform)}" + ) + + @property + def InputType(self) -> Any: + func = getattr(self, "_transform", None) or getattr(self, "_atransform") + try: + params = inspect.signature(func).parameters + first_param = next(iter(params.values()), None) + if first_param and first_param.annotation != inspect.Parameter.empty: + return first_param.annotation + else: + return Any + except ValueError: + return Any + + @property + def OutputType(self) -> Type[Output]: + func = getattr(self, "_transform", None) or getattr(self, "_atransform") + try: + sig = inspect.signature(func) + return ( + sig.return_annotation + if sig.return_annotation != inspect.Signature.empty + else Any + ) + except ValueError: + return Any + + def __eq__(self, other: Any) -> bool: + if isinstance(other, RunnableGenerator): + if hasattr(self, "_transform") and hasattr(other, "_transform"): + return self._transform == other._transform + elif hasattr(self, "_atransform") and hasattr(other, "_atransform"): + return self._atransform == other._atransform + else: + return False + else: + return False + + def __repr__(self) -> str: + return "RunnableGenerator(...)" + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> Iterator[Output]: + return self._transform_stream_with_config( + input, self._transform, config, **kwargs + ) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> Iterator[Output]: + return self.transform(iter([input]), config, **kwargs) + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + final = None + for output in self.stream(input, config, **kwargs): + if final is None: + final = output + else: + final += output + return final + + async def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> AsyncIterator[Output]: + if not hasattr(self, "_atransform"): + raise NotImplementedError("This runnable does not support async methods.") + + return self._atransform_stream_with_config( + input, self._atransform, config, **kwargs + ) + + async def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any | None, + ) -> AsyncIterator[Output]: + async def input_aiter() -> AsyncIterator[Input]: + yield input + + return self.atransform(input_aiter(), config, **kwargs) + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None + ) -> Output: + final = None + async for output in self.astream(input, config): + if final is None: + final = output + else: + final += output + return final + + class RunnableLambda(Runnable[Input, Output]): """ A runnable that runs a callable. @@ -2538,6 +2675,8 @@ RunnableLike = Union[ Runnable[Input, Output], Callable[[Input], Output], Callable[[Input], Awaitable[Output]], + Callable[[Iterator[Input]], Iterator[Output]], + Callable[[AsyncIterator[Input]], AsyncIterator[Output]], Mapping[str, Any], ] @@ -2545,6 +2684,8 @@ RunnableLike = Union[ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: if isinstance(thing, Runnable): return thing + elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing): + return RunnableGenerator(thing) elif callable(thing): return RunnableLambda(thing) elif isinstance(thing, dict): diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 6ae120ad7f..06d979cff0 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -152,9 +152,9 @@ def call_func_with_variable_args( input: Input, run_manager: CallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: """Call function that may optionally accept a run_manager and/or config.""" - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) if accepts_run_manager(func): @@ -174,9 +174,9 @@ async def acall_func_with_variable_args( input: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: """Call function that may optionally accept a run_manager and/or config.""" - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) if accepts_run_manager(func): From 0318cdd33c3ac85b4b26094a356c3084cf800efe Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 12:25:19 +0100 Subject: [PATCH 2/3] Add tests --- .../langchain/schema/runnable/base.py | 8 +- .../schema/runnable/test_runnable.py | 91 ++++++++++++++++++- 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 9d45d86551..558495173b 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -2100,7 +2100,7 @@ class RunnableGenerator(Runnable[Input, Output]): params = inspect.signature(func).parameters first_param = next(iter(params.values()), None) if first_param and first_param.annotation != inspect.Parameter.empty: - return first_param.annotation + return getattr(first_param.annotation, "__args__", (Any,))[0] else: return Any except ValueError: @@ -2112,7 +2112,7 @@ class RunnableGenerator(Runnable[Input, Output]): try: sig = inspect.signature(func) return ( - sig.return_annotation + getattr(sig.return_annotation, "__args__", (Any,))[0] if sig.return_annotation != inspect.Signature.empty else Any ) @@ -2162,7 +2162,7 @@ class RunnableGenerator(Runnable[Input, Output]): final += output return final - async def atransform( + def atransform( self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None, @@ -2175,7 +2175,7 @@ class RunnableGenerator(Runnable[Input, Output]): input, self._atransform, config, **kwargs ) - async def astream( + def astream( self, input: Input, config: Optional[RunnableConfig] = None, diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index b72144102a..316a9ecad6 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,7 +1,18 @@ import sys from operator import itemgetter -from typing import Any, Dict, List, Optional, Sequence, Union, cast +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Sequence, + Union, + cast, +) from uuid import UUID +from langchain.schema.runnable.base import RunnableGenerator import pytest from freezegun import freeze_time @@ -2809,3 +2820,81 @@ async def test_tool_from_runnable() -> None: "title": "PromptInput", "type": "object", } + + +@pytest.mark.asyncio +async def test_runnable_gen() -> None: + """Test that a generator can be used as a runnable.""" + + def gen(input: Iterator[Any]) -> Iterator[int]: + yield 1 + yield 2 + yield 3 + + runnable = RunnableGenerator(gen) + + assert runnable.input_schema.schema() == {"title": "RunnableGeneratorInput"} + assert runnable.output_schema.schema() == { + "title": "RunnableGeneratorOutput", + "type": "integer", + } + + assert runnable.invoke(None) == 6 + assert list(runnable.stream(None)) == [1, 2, 3] + assert runnable.batch([None, None]) == [6, 6] + + async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: + yield 1 + yield 2 + yield 3 + + arunnable = RunnableGenerator(agen) + + assert await arunnable.ainvoke(None) == 6 + assert [p async for p in arunnable.astream(None)] == [1, 2, 3] + assert await arunnable.abatch([None, None]) == [6, 6] + + +@pytest.mark.asyncio +async def test_runnable_gen_transform() -> None: + """Test that a generator can be used as a runnable.""" + + def gen_indexes(length_iter: Iterator[int]) -> Iterator[int]: + for i in range(next(length_iter)): + yield i + + async def agen_indexes(length_iter: AsyncIterator[int]) -> AsyncIterator[int]: + async for length in length_iter: + for i in range(length): + yield i + + def plus_one(input: Iterator[int]) -> Iterator[int]: + for i in input: + yield i + 1 + + async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]: + async for i in input: + yield i + 1 + + chain = RunnableGenerator(gen_indexes, agen_indexes) | plus_one + achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one + + assert chain.input_schema.schema() == { + "title": "RunnableGeneratorInput", + "type": "integer", + } + assert chain.output_schema.schema() == { + "title": "RunnableGeneratorOutput", + "type": "integer", + } + assert achain.input_schema.schema() == { + "title": "RunnableGeneratorInput", + "type": "integer", + } + assert achain.output_schema.schema() == { + "title": "RunnableGeneratorOutput", + "type": "integer", + } + + assert list(chain.stream(3)) == [1, 2, 3] + assert [p async for p in achain.astream(4)] == [1, 2, 3, 4] From 2387647d30bb1aef0a7edace6a1aa15b7529d652 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 14:11:03 +0100 Subject: [PATCH 3/3] Lint --- .../langchain/schema/runnable/base.py | 48 ++++++++++--------- .../langchain/schema/runnable/router.py | 31 +----------- .../schema/runnable/test_runnable.py | 4 +- 3 files changed, 29 insertions(+), 54 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 558495173b..9b69a9f711 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -123,6 +123,7 @@ class Runnable(Generic[Input, Output], ABC): other: Union[ Runnable[Any, Other], Callable[[Any], Other], + Callable[[Iterator[Any]], Iterator[Other]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSequence[Input, Other]: @@ -132,7 +133,8 @@ class Runnable(Generic[Input, Output], ABC): self, other: Union[ Runnable[Other, Any], - Callable[[Any], Other], + Callable[[Other], Any], + Callable[[Iterator[Other]], Iterator[Any]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], ) -> RunnableSequence[Other, Output]: @@ -353,7 +355,7 @@ class Runnable(Generic[Input, Output], ABC): else: # Make a best effort to gather, for any type that supports `+` # This method should throw an error if gathering fails. - final += chunk # type: ignore[operator] + final = final + chunk # type: ignore[operator] if got_first_val: yield from self.stream(final, config, **kwargs) @@ -379,7 +381,7 @@ class Runnable(Generic[Input, Output], ABC): else: # Make a best effort to gather, for any type that supports `+` # This method should throw an error if gathering fails. - final += chunk # type: ignore[operator] + final = final + chunk # type: ignore[operator] if got_first_val: async for output in self.astream(final, config, **kwargs): @@ -710,7 +712,7 @@ class Runnable(Generic[Input, Output], ABC): final_output = chunk else: try: - final_output += chunk # type: ignore[operator] + final_output = final_output + chunk # type: ignore except TypeError: final_output = None final_output_supported = False @@ -720,7 +722,7 @@ class Runnable(Generic[Input, Output], ABC): final_input = ichunk else: try: - final_input += ichunk # type: ignore[operator] + final_input = final_input + ichunk # type: ignore except TypeError: final_input = None final_input_supported = False @@ -788,7 +790,7 @@ class Runnable(Generic[Input, Output], ABC): final_output = chunk else: try: - final_output += chunk # type: ignore[operator] + final_output = final_output + chunk # type: ignore except TypeError: final_output = None final_output_supported = False @@ -798,7 +800,7 @@ class Runnable(Generic[Input, Output], ABC): final_input = ichunk else: try: - final_input += ichunk # type: ignore[operator] + final_input = final_input + ichunk # type: ignore[operator] except TypeError: final_input = None final_input_supported = False @@ -1315,6 +1317,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): other: Union[ Runnable[Any, Other], Callable[[Any], Other], + Callable[[Iterator[Any]], Iterator[Other]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSequence[Input, Other]: @@ -1335,7 +1338,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, other: Union[ Runnable[Other, Any], - Callable[[Any], Other], + Callable[[Other], Any], + Callable[[Iterator[Other]], Iterator[Any]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], ) -> RunnableSequence[Other, Output]: @@ -1755,7 +1759,7 @@ class RunnableMapChunk(Dict[str, Any]): if key not in chunk or chunk[key] is None: chunk[key] = other[key] elif other[key] is not None: - chunk[key] += other[key] + chunk[key] = chunk[key] + other[key] return chunk def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk: @@ -1764,7 +1768,7 @@ class RunnableMapChunk(Dict[str, Any]): if key not in chunk or chunk[key] is None: chunk[key] = self[key] elif self[key] is not None: - chunk[key] += self[key] + chunk[key] = chunk[key] + self[key] return chunk @@ -2107,7 +2111,7 @@ class RunnableGenerator(Runnable[Input, Output]): return Any @property - def OutputType(self) -> Type[Output]: + def OutputType(self) -> Any: func = getattr(self, "_transform", None) or getattr(self, "_atransform") try: sig = inspect.signature(func) @@ -2137,7 +2141,7 @@ class RunnableGenerator(Runnable[Input, Output]): self, input: Iterator[Input], config: Optional[RunnableConfig] = None, - **kwargs: Any | None, + **kwargs: Any, ) -> Iterator[Output]: return self._transform_stream_with_config( input, self._transform, config, **kwargs @@ -2147,7 +2151,7 @@ class RunnableGenerator(Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None, - **kwargs: Any | None, + **kwargs: Any, ) -> Iterator[Output]: return self.transform(iter([input]), config, **kwargs) @@ -2159,14 +2163,14 @@ class RunnableGenerator(Runnable[Input, Output]): if final is None: final = output else: - final += output - return final + final = final + output + return cast(Output, final) def atransform( self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None, - **kwargs: Any | None, + **kwargs: Any, ) -> AsyncIterator[Output]: if not hasattr(self, "_atransform"): raise NotImplementedError("This runnable does not support async methods.") @@ -2179,7 +2183,7 @@ class RunnableGenerator(Runnable[Input, Output]): self, input: Input, config: Optional[RunnableConfig] = None, - **kwargs: Any | None, + **kwargs: Any, ) -> AsyncIterator[Output]: async def input_aiter() -> AsyncIterator[Input]: yield input @@ -2187,15 +2191,15 @@ class RunnableGenerator(Runnable[Input, Output]): return self.atransform(input_aiter(), config, **kwargs) async def ainvoke( - self, input: Input, config: Optional[RunnableConfig] = None + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any ) -> Output: final = None - async for output in self.astream(input, config): + async for output in self.astream(input, config, **kwargs): if final is None: final = output else: - final += output - return final + final = final + output + return cast(Output, final) class RunnableLambda(Runnable[Input, Output]): @@ -2687,7 +2691,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing): return RunnableGenerator(thing) elif callable(thing): - return RunnableLambda(thing) + return RunnableLambda(cast(Callable[[Input], Output], thing)) elif isinstance(thing, dict): runnables: Mapping[str, Runnable[Any, Any]] = { key: coerce_to_runnable(r) for key, r in thing.items() diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index 6a43e61d69..f697c0328c 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -15,14 +15,7 @@ from typing import ( from typing_extensions import TypedDict from langchain.load.serializable import Serializable -from langchain.schema.runnable.base import ( - Input, - Other, - Output, - Runnable, - RunnableSequence, - coerce_to_runnable, -) +from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable from langchain.schema.runnable.config import ( RunnableConfig, get_config_list, @@ -71,28 +64,6 @@ class RouterRunnable(Serializable, Runnable[RouterInput, Output]): def get_lc_namespace(cls) -> List[str]: return cls.__module__.split(".")[:-1] - def __or__( - self, - other: Union[ - Runnable[Any, Other], - Callable[[Any], Other], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], - Mapping[str, Any], - ], - ) -> RunnableSequence[RouterInput, Other]: - return RunnableSequence(first=self, last=coerce_to_runnable(other)) - - def __ror__( - self, - other: Union[ - Runnable[Other, Any], - Callable[[Any], Other], - Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], - Mapping[str, Any], - ], - ) -> RunnableSequence[Other, Output]: - return RunnableSequence(first=coerce_to_runnable(other), last=self) - def invoke( self, input: RouterInput, config: Optional[RunnableConfig] = None ) -> Output: diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 316a9ecad6..4a63f92ff2 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -12,7 +12,6 @@ from typing import ( cast, ) from uuid import UUID -from langchain.schema.runnable.base import RunnableGenerator import pytest from freezegun import freeze_time @@ -57,6 +56,7 @@ from langchain.schema.runnable import ( RunnableSequence, RunnableWithFallbacks, ) +from langchain.schema.runnable.base import RunnableGenerator from langchain.tools.base import BaseTool, tool from langchain.tools.json.tool import JsonListKeysTool, JsonSpec @@ -2876,7 +2876,7 @@ async def test_runnable_gen_transform() -> None: async for i in input: yield i + 1 - chain = RunnableGenerator(gen_indexes, agen_indexes) | plus_one + chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one assert chain.input_schema.schema() == {