From d5aeff706acefbd889d7813414c3d25a981b24e3 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 15 Nov 2023 13:12:57 +0000 Subject: [PATCH] Make it easier to subclass RunnableEach (#13346) --- .../langchain/schema/runnable/base.py | 81 +++++++++++-------- .../langchain/schema/runnable/configurable.py | 36 +++++---- 2 files changed, 68 insertions(+), 49 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7cf6d2af90..0e21347807 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -2535,10 +2535,12 @@ class RunnableLambda(Runnable[Input, Output]): return await super().ainvoke(input, config) -class RunnableEach(RunnableSerializable[List[Input], List[Output]]): +class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): """ A runnable that delegates calls to another runnable with each element of the input sequence. + + Use only if creating a new RunnableEach subclass with different __init__ args. """ bound: Runnable[Input, Output] @@ -2589,38 +2591,6 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): def get_lc_namespace(cls) -> List[str]: return cls.__module__.split(".")[:-1] - def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: - return RunnableEach(bound=self.bound.bind(**kwargs)) - - def with_config( - self, config: Optional[RunnableConfig] = None, **kwargs: Any - ) -> RunnableEach[Input, Output]: - return RunnableEach(bound=self.bound.with_config(config, **kwargs)) - - def with_listeners( - self, - *, - on_start: Optional[Listener] = None, - on_end: Optional[Listener] = None, - on_error: Optional[Listener] = None, - ) -> RunnableEach[Input, Output]: - """ - Bind lifecycle listeners to a Runnable, returning a new Runnable. - - on_start: Called before the runnable starts running, with the Run object. - on_end: Called after the runnable finishes running, with the Run object. - on_error: Called if the runnable throws an error, with the Run object. - - The Run object contains information about the run, including its id, - type, input, output, error, start_time, end_time, and any tags or metadata - added to the run. - """ - return RunnableEach( - bound=self.bound.with_listeners( - on_start=on_start, on_end=on_end, on_error=on_error - ) - ) - def _invoke( self, inputs: List[Input], @@ -2654,9 +2624,50 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): return await self._acall_with_config(self._ainvoke, input, config, **kwargs) +class RunnableEach(RunnableEachBase[Input, Output]): + """ + A runnable that delegates calls to another runnable + with each element of the input sequence. + """ + + def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: + return RunnableEach(bound=self.bound.bind(**kwargs)) + + def with_config( + self, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> RunnableEach[Input, Output]: + return RunnableEach(bound=self.bound.with_config(config, **kwargs)) + + def with_listeners( + self, + *, + on_start: Optional[Listener] = None, + on_end: Optional[Listener] = None, + on_error: Optional[Listener] = None, + ) -> RunnableEach[Input, Output]: + """ + Bind lifecycle listeners to a Runnable, returning a new Runnable. + + on_start: Called before the runnable starts running, with the Run object. + on_end: Called after the runnable finishes running, with the Run object. + on_error: Called if the runnable throws an error, with the Run object. + + The Run object contains information about the run, including its id, + type, input, output, error, start_time, end_time, and any tags or metadata + added to the run. + """ + return RunnableEach( + bound=self.bound.with_listeners( + on_start=on_start, on_end=on_end, on_error=on_error + ) + ) + + class RunnableBindingBase(RunnableSerializable[Input, Output]): """ A runnable that delegates calls to another runnable with a set of kwargs. + + Use only if creating a new RunnableBinding subclass with different __init__ args. """ bound: Runnable[Input, Output] @@ -2879,6 +2890,10 @@ RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig) class RunnableBinding(RunnableBindingBase[Input, Output]): + """ + A runnable that delegates calls to another runnable with a set of kwargs. + """ + def bind(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__( bound=self.bound, diff --git a/libs/langchain/langchain/schema/runnable/configurable.py b/libs/langchain/langchain/schema/runnable/configurable.py index d9591d209c..f1161efba6 100644 --- a/libs/langchain/langchain/schema/runnable/configurable.py +++ b/libs/langchain/langchain/schema/runnable/configurable.py @@ -34,6 +34,7 @@ from langchain.schema.runnable.utils import ( Input, Output, gather_with_concurrency, + get_unique_config_specs, ) @@ -209,22 +210,25 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: - return [ - ConfigurableFieldSpec( - id=spec.id, - name=spec.name, - description=spec.description - or self.default.__fields__[field_name].field_info.description, - annotation=spec.annotation - or self.default.__fields__[field_name].annotation, - default=getattr(self.default, field_name), - ) - if isinstance(spec, ConfigurableField) - else make_options_spec( - spec, self.default.__fields__[field_name].field_info.description - ) - for field_name, spec in self.fields.items() - ] + return get_unique_config_specs( + [ + ConfigurableFieldSpec( + id=spec.id, + name=spec.name, + description=spec.description + or self.default.__fields__[field_name].field_info.description, + annotation=spec.annotation + or self.default.__fields__[field_name].annotation, + default=getattr(self.default, field_name), + ) + if isinstance(spec, ConfigurableField) + else make_options_spec( + spec, self.default.__fields__[field_name].field_info.description + ) + for field_name, spec in self.fields.items() + ] + + list(self.default.config_specs) + ) def configurable_fields( self, **kwargs: AnyConfigurableField