Add RunnableGenerator

pull/11214/head
Nuno Campos 1 year ago
parent ca5293bf54
commit b67db8deaa

@ -453,6 +453,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
@ -465,7 +466,9 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"),
)
try:
output = call_func_with_variable_args(func, input, run_manager, config)
output = call_func_with_variable_args(
func, input, run_manager, config, **kwargs
)
except BaseException as e:
run_manager.on_chain_error(e)
raise
@ -486,6 +489,7 @@ class Runnable(Generic[Input, Output], ABC):
input: Input,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Output:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement ainvoke() in subclasses."""
@ -499,7 +503,7 @@ class Runnable(Generic[Input, Output], ABC):
)
try:
output = await acall_func_with_variable_args(
func, input, run_manager, config
func, input, run_manager, config, **kwargs
)
except BaseException as e:
await run_manager.on_chain_error(e)
@ -526,6 +530,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
return_exceptions: bool = False,
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
@ -546,7 +551,6 @@ class Runnable(Generic[Input, Output], ABC):
)
]
try:
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = [
patch_config(c, callbacks=rm.get_child())
@ -597,6 +601,7 @@ class Runnable(Generic[Input, Output], ABC):
*,
return_exceptions: bool = False,
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> List[Output]:
"""Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses."""
@ -619,7 +624,6 @@ class Runnable(Generic[Input, Output], ABC):
)
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = [
patch_config(c, callbacks=rm.get_child())
@ -668,6 +672,7 @@ class Runnable(Generic[Input, Output], ABC):
],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> Iterator[Output]:
"""Helper method to transform an Iterator of Input values into an Iterator of
Output values, with callbacks.
@ -689,7 +694,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"),
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer):
kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child()
@ -746,6 +750,7 @@ class Runnable(Generic[Input, Output], ABC):
],
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
"""Helper method to transform an Async Iterator of Input values into an Async
Iterator of Output values, with callbacks.
@ -767,7 +772,6 @@ class Runnable(Generic[Input, Output], ABC):
name=config.get("run_name"),
)
try:
kwargs: Dict[str, Any] = {}
if accepts_config(transformer):
kwargs["config"] = patch_config(
config, callbacks=run_manager.get_child()
@ -2061,6 +2065,139 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
yield chunk
class RunnableGenerator(Runnable[Input, Output]):
"""
A runnable that runs a generator function.
"""
def __init__(
self,
transform: Union[
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
],
atransform: Optional[
Callable[[AsyncIterator[Input]], AsyncIterator[Output]]
] = None,
) -> None:
if atransform is not None:
self._atransform = atransform
if inspect.isasyncgenfunction(transform):
self._atransform = transform
elif inspect.isgeneratorfunction(transform):
self._transform = transform
else:
raise TypeError(
"Expected a generator function type for `transform`."
f"Instead got an unsupported type: {type(transform)}"
)
@property
def InputType(self) -> Any:
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
try:
params = inspect.signature(func).parameters
first_param = next(iter(params.values()), None)
if first_param and first_param.annotation != inspect.Parameter.empty:
return first_param.annotation
else:
return Any
except ValueError:
return Any
@property
def OutputType(self) -> Type[Output]:
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
try:
sig = inspect.signature(func)
return (
sig.return_annotation
if sig.return_annotation != inspect.Signature.empty
else Any
)
except ValueError:
return Any
def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"):
return self._transform == other._transform
elif hasattr(self, "_atransform") and hasattr(other, "_atransform"):
return self._atransform == other._atransform
else:
return False
else:
return False
def __repr__(self) -> str:
return "RunnableGenerator(...)"
def transform(
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> Iterator[Output]:
return self._transform_stream_with_config(
input, self._transform, config, **kwargs
)
def stream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> Iterator[Output]:
return self.transform(iter([input]), config, **kwargs)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
final = None
for output in self.stream(input, config, **kwargs):
if final is None:
final = output
else:
final += output
return final
async def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> AsyncIterator[Output]:
if not hasattr(self, "_atransform"):
raise NotImplementedError("This runnable does not support async methods.")
return self._atransform_stream_with_config(
input, self._atransform, config, **kwargs
)
async def astream(
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
) -> AsyncIterator[Output]:
async def input_aiter() -> AsyncIterator[Input]:
yield input
return self.atransform(input_aiter(), config, **kwargs)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
final = None
async for output in self.astream(input, config):
if final is None:
final = output
else:
final += output
return final
class RunnableLambda(Runnable[Input, Output]):
"""
A runnable that runs a callable.
@ -2538,6 +2675,8 @@ RunnableLike = Union[
Runnable[Input, Output],
Callable[[Input], Output],
Callable[[Input], Awaitable[Output]],
Callable[[Iterator[Input]], Iterator[Output]],
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
Mapping[str, Any],
]
@ -2545,6 +2684,8 @@ RunnableLike = Union[
def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
if isinstance(thing, Runnable):
return thing
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
elif callable(thing):
return RunnableLambda(thing)
elif isinstance(thing, dict):

@ -152,9 +152,9 @@ def call_func_with_variable_args(
input: Input,
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):
@ -174,9 +174,9 @@ async def acall_func_with_variable_args(
input: Input,
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Output:
"""Call function that may optionally accept a run_manager and/or config."""
kwargs: Dict[str, Any] = {}
if accepts_config(func):
kwargs["config"] = patch_config(config, callbacks=run_manager.get_child())
if accepts_run_manager(func):

Loading…
Cancel
Save