diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7ee4be6960..9b69a9f711 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -123,6 +123,7 @@ class Runnable(Generic[Input, Output], ABC): other: Union[ Runnable[Any, Other], Callable[[Any], Other], + Callable[[Iterator[Any]], Iterator[Other]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSequence[Input, Other]: @@ -132,7 +133,8 @@ class Runnable(Generic[Input, Output], ABC): self, other: Union[ Runnable[Other, Any], - Callable[[Any], Other], + Callable[[Other], Any], + Callable[[Iterator[Other]], Iterator[Any]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], ) -> RunnableSequence[Other, Output]: @@ -353,7 +355,7 @@ class Runnable(Generic[Input, Output], ABC): else: # Make a best effort to gather, for any type that supports `+` # This method should throw an error if gathering fails. - final += chunk # type: ignore[operator] + final = final + chunk # type: ignore[operator] if got_first_val: yield from self.stream(final, config, **kwargs) @@ -379,7 +381,7 @@ class Runnable(Generic[Input, Output], ABC): else: # Make a best effort to gather, for any type that supports `+` # This method should throw an error if gathering fails. - final += chunk # type: ignore[operator] + final = final + chunk # type: ignore[operator] if got_first_val: async for output in self.astream(final, config, **kwargs): @@ -453,6 +455,7 @@ class Runnable(Generic[Input, Output], ABC): input: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" @@ -465,7 +468,9 @@ class Runnable(Generic[Input, Output], ABC): name=config.get("run_name"), ) try: - output = call_func_with_variable_args(func, input, run_manager, config) + output = call_func_with_variable_args( + func, input, run_manager, config, **kwargs + ) except BaseException as e: run_manager.on_chain_error(e) raise @@ -486,6 +491,7 @@ class Runnable(Generic[Input, Output], ABC): input: Input, config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> Output: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement ainvoke() in subclasses.""" @@ -499,7 +505,7 @@ class Runnable(Generic[Input, Output], ABC): ) try: output = await acall_func_with_variable_args( - func, input, run_manager, config + func, input, run_manager, config, **kwargs ) except BaseException as e: await run_manager.on_chain_error(e) @@ -526,6 +532,7 @@ class Runnable(Generic[Input, Output], ABC): *, return_exceptions: bool = False, run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> List[Output]: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" @@ -546,7 +553,6 @@ class Runnable(Generic[Input, Output], ABC): ) ] try: - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = [ patch_config(c, callbacks=rm.get_child()) @@ -597,6 +603,7 @@ class Runnable(Generic[Input, Output], ABC): *, return_exceptions: bool = False, run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> List[Output]: """Helper method to transform an Input value to an Output value, with callbacks. Use this method to implement invoke() in subclasses.""" @@ -619,7 +626,6 @@ class Runnable(Generic[Input, Output], ABC): ) ) try: - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = [ patch_config(c, callbacks=rm.get_child()) @@ -668,6 +674,7 @@ class Runnable(Generic[Input, Output], ABC): ], config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> Iterator[Output]: """Helper method to transform an Iterator of Input values into an Iterator of Output values, with callbacks. @@ -689,7 +696,6 @@ class Runnable(Generic[Input, Output], ABC): name=config.get("run_name"), ) try: - kwargs: Dict[str, Any] = {} if accepts_config(transformer): kwargs["config"] = patch_config( config, callbacks=run_manager.get_child() @@ -706,7 +712,7 @@ class Runnable(Generic[Input, Output], ABC): final_output = chunk else: try: - final_output += chunk # type: ignore[operator] + final_output = final_output + chunk # type: ignore except TypeError: final_output = None final_output_supported = False @@ -716,7 +722,7 @@ class Runnable(Generic[Input, Output], ABC): final_input = ichunk else: try: - final_input += ichunk # type: ignore[operator] + final_input = final_input + ichunk # type: ignore except TypeError: final_input = None final_input_supported = False @@ -746,6 +752,7 @@ class Runnable(Generic[Input, Output], ABC): ], config: Optional[RunnableConfig], run_type: Optional[str] = None, + **kwargs: Optional[Any], ) -> AsyncIterator[Output]: """Helper method to transform an Async Iterator of Input values into an Async Iterator of Output values, with callbacks. @@ -767,7 +774,6 @@ class Runnable(Generic[Input, Output], ABC): name=config.get("run_name"), ) try: - kwargs: Dict[str, Any] = {} if accepts_config(transformer): kwargs["config"] = patch_config( config, callbacks=run_manager.get_child() @@ -784,7 +790,7 @@ class Runnable(Generic[Input, Output], ABC): final_output = chunk else: try: - final_output += chunk # type: ignore[operator] + final_output = final_output + chunk # type: ignore except TypeError: final_output = None final_output_supported = False @@ -794,7 +800,7 @@ class Runnable(Generic[Input, Output], ABC): final_input = ichunk else: try: - final_input += ichunk # type: ignore[operator] + final_input = final_input + ichunk # type: ignore[operator] except TypeError: final_input = None final_input_supported = False @@ -1311,6 +1317,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): other: Union[ Runnable[Any, Other], Callable[[Any], Other], + Callable[[Iterator[Any]], Iterator[Other]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSequence[Input, Other]: @@ -1331,7 +1338,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): self, other: Union[ Runnable[Other, Any], - Callable[[Any], Other], + Callable[[Other], Any], + Callable[[Iterator[Other]], Iterator[Any]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], ) -> RunnableSequence[Other, Output]: @@ -1751,7 +1759,7 @@ class RunnableMapChunk(Dict[str, Any]): if key not in chunk or chunk[key] is None: chunk[key] = other[key] elif other[key] is not None: - chunk[key] += other[key] + chunk[key] = chunk[key] + other[key] return chunk def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk: @@ -1760,7 +1768,7 @@ class RunnableMapChunk(Dict[str, Any]): if key not in chunk or chunk[key] is None: chunk[key] = self[key] elif self[key] is not None: - chunk[key] += self[key] + chunk[key] = chunk[key] + self[key] return chunk @@ -2061,6 +2069,139 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): yield chunk +class RunnableGenerator(Runnable[Input, Output]): + """ + A runnable that runs a generator function. + """ + + def __init__( + self, + transform: Union[ + Callable[[Iterator[Input]], Iterator[Output]], + Callable[[AsyncIterator[Input]], AsyncIterator[Output]], + ], + atransform: Optional[ + Callable[[AsyncIterator[Input]], AsyncIterator[Output]] + ] = None, + ) -> None: + if atransform is not None: + self._atransform = atransform + + if inspect.isasyncgenfunction(transform): + self._atransform = transform + elif inspect.isgeneratorfunction(transform): + self._transform = transform + else: + raise TypeError( + "Expected a generator function type for `transform`." + f"Instead got an unsupported type: {type(transform)}" + ) + + @property + def InputType(self) -> Any: + func = getattr(self, "_transform", None) or getattr(self, "_atransform") + try: + params = inspect.signature(func).parameters + first_param = next(iter(params.values()), None) + if first_param and first_param.annotation != inspect.Parameter.empty: + return getattr(first_param.annotation, "__args__", (Any,))[0] + else: + return Any + except ValueError: + return Any + + @property + def OutputType(self) -> Any: + func = getattr(self, "_transform", None) or getattr(self, "_atransform") + try: + sig = inspect.signature(func) + return ( + getattr(sig.return_annotation, "__args__", (Any,))[0] + if sig.return_annotation != inspect.Signature.empty + else Any + ) + except ValueError: + return Any + + def __eq__(self, other: Any) -> bool: + if isinstance(other, RunnableGenerator): + if hasattr(self, "_transform") and hasattr(other, "_transform"): + return self._transform == other._transform + elif hasattr(self, "_atransform") and hasattr(other, "_atransform"): + return self._atransform == other._atransform + else: + return False + else: + return False + + def __repr__(self) -> str: + return "RunnableGenerator(...)" + + def transform( + self, + input: Iterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Output]: + return self._transform_stream_with_config( + input, self._transform, config, **kwargs + ) + + def stream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> Iterator[Output]: + return self.transform(iter([input]), config, **kwargs) + + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + final = None + for output in self.stream(input, config, **kwargs): + if final is None: + final = output + else: + final = final + output + return cast(Output, final) + + def atransform( + self, + input: AsyncIterator[Input], + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Output]: + if not hasattr(self, "_atransform"): + raise NotImplementedError("This runnable does not support async methods.") + + return self._atransform_stream_with_config( + input, self._atransform, config, **kwargs + ) + + def astream( + self, + input: Input, + config: Optional[RunnableConfig] = None, + **kwargs: Any, + ) -> AsyncIterator[Output]: + async def input_aiter() -> AsyncIterator[Input]: + yield input + + return self.atransform(input_aiter(), config, **kwargs) + + async def ainvoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: + final = None + async for output in self.astream(input, config, **kwargs): + if final is None: + final = output + else: + final = final + output + return cast(Output, final) + + class RunnableLambda(Runnable[Input, Output]): """ A runnable that runs a callable. @@ -2538,6 +2679,8 @@ RunnableLike = Union[ Runnable[Input, Output], Callable[[Input], Output], Callable[[Input], Awaitable[Output]], + Callable[[Iterator[Input]], Iterator[Output]], + Callable[[AsyncIterator[Input]], AsyncIterator[Output]], Mapping[str, Any], ] @@ -2545,8 +2688,10 @@ RunnableLike = Union[ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]: if isinstance(thing, Runnable): return thing + elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing): + return RunnableGenerator(thing) elif callable(thing): - return RunnableLambda(thing) + return RunnableLambda(cast(Callable[[Input], Output], thing)) elif isinstance(thing, dict): runnables: Mapping[str, Runnable[Any, Any]] = { key: coerce_to_runnable(r) for key, r in thing.items() diff --git a/libs/langchain/langchain/schema/runnable/config.py b/libs/langchain/langchain/schema/runnable/config.py index 6ae120ad7f..06d979cff0 100644 --- a/libs/langchain/langchain/schema/runnable/config.py +++ b/libs/langchain/langchain/schema/runnable/config.py @@ -152,9 +152,9 @@ def call_func_with_variable_args( input: Input, run_manager: CallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: """Call function that may optionally accept a run_manager and/or config.""" - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) if accepts_run_manager(func): @@ -174,9 +174,9 @@ async def acall_func_with_variable_args( input: Input, run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Output: """Call function that may optionally accept a run_manager and/or config.""" - kwargs: Dict[str, Any] = {} if accepts_config(func): kwargs["config"] = patch_config(config, callbacks=run_manager.get_child()) if accepts_run_manager(func): diff --git a/libs/langchain/langchain/schema/runnable/router.py b/libs/langchain/langchain/schema/runnable/router.py index 6a43e61d69..f697c0328c 100644 --- a/libs/langchain/langchain/schema/runnable/router.py +++ b/libs/langchain/langchain/schema/runnable/router.py @@ -15,14 +15,7 @@ from typing import ( from typing_extensions import TypedDict from langchain.load.serializable import Serializable -from langchain.schema.runnable.base import ( - Input, - Other, - Output, - Runnable, - RunnableSequence, - coerce_to_runnable, -) +from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable from langchain.schema.runnable.config import ( RunnableConfig, get_config_list, @@ -71,28 +64,6 @@ class RouterRunnable(Serializable, Runnable[RouterInput, Output]): def get_lc_namespace(cls) -> List[str]: return cls.__module__.split(".")[:-1] - def __or__( - self, - other: Union[ - Runnable[Any, Other], - Callable[[Any], Other], - Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]], - Mapping[str, Any], - ], - ) -> RunnableSequence[RouterInput, Other]: - return RunnableSequence(first=self, last=coerce_to_runnable(other)) - - def __ror__( - self, - other: Union[ - Runnable[Other, Any], - Callable[[Any], Other], - Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]], - Mapping[str, Any], - ], - ) -> RunnableSequence[Other, Output]: - return RunnableSequence(first=coerce_to_runnable(other), last=self) - def invoke( self, input: RouterInput, config: Optional[RunnableConfig] = None ) -> Output: diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 52538fdd5c..f632a23e9d 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,6 +1,16 @@ import sys from operator import itemgetter -from typing import Any, Dict, List, Optional, Sequence, Union, cast +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Sequence, + Union, + cast, +) from uuid import UUID import pytest @@ -46,6 +56,7 @@ from langchain.schema.runnable import ( RunnableSequence, RunnableWithFallbacks, ) +from langchain.schema.runnable.base import RunnableGenerator from langchain.tools.base import BaseTool, tool from langchain.tools.json.tool import JsonListKeysTool, JsonSpec @@ -2809,3 +2820,81 @@ async def test_tool_from_runnable() -> None: "title": "PromptInput", "type": "object", } + + +@pytest.mark.asyncio +async def test_runnable_gen() -> None: + """Test that a generator can be used as a runnable.""" + + def gen(input: Iterator[Any]) -> Iterator[int]: + yield 1 + yield 2 + yield 3 + + runnable = RunnableGenerator(gen) + + assert runnable.input_schema.schema() == {"title": "RunnableGeneratorInput"} + assert runnable.output_schema.schema() == { + "title": "RunnableGeneratorOutput", + "type": "integer", + } + + assert runnable.invoke(None) == 6 + assert list(runnable.stream(None)) == [1, 2, 3] + assert runnable.batch([None, None]) == [6, 6] + + async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: + yield 1 + yield 2 + yield 3 + + arunnable = RunnableGenerator(agen) + + assert await arunnable.ainvoke(None) == 6 + assert [p async for p in arunnable.astream(None)] == [1, 2, 3] + assert await arunnable.abatch([None, None]) == [6, 6] + + +@pytest.mark.asyncio +async def test_runnable_gen_transform() -> None: + """Test that a generator can be used as a runnable.""" + + def gen_indexes(length_iter: Iterator[int]) -> Iterator[int]: + for i in range(next(length_iter)): + yield i + + async def agen_indexes(length_iter: AsyncIterator[int]) -> AsyncIterator[int]: + async for length in length_iter: + for i in range(length): + yield i + + def plus_one(input: Iterator[int]) -> Iterator[int]: + for i in input: + yield i + 1 + + async def aplus_one(input: AsyncIterator[int]) -> AsyncIterator[int]: + async for i in input: + yield i + 1 + + chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one + achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one + + assert chain.input_schema.schema() == { + "title": "RunnableGeneratorInput", + "type": "integer", + } + assert chain.output_schema.schema() == { + "title": "RunnableGeneratorOutput", + "type": "integer", + } + assert achain.input_schema.schema() == { + "title": "RunnableGeneratorInput", + "type": "integer", + } + assert achain.output_schema.schema() == { + "title": "RunnableGeneratorOutput", + "type": "integer", + } + + assert list(chain.stream(3)) == [1, 2, 3] + assert [p async for p in achain.astream(4)] == [1, 2, 3, 4]