Make it easier to subclass runnable binding with custom init args (#13189)

pull/13052/head^2
Nuno Campos 11 months ago committed by GitHub
parent 7f1964b264
commit 8d6faf5665
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,9 +1,9 @@
from typing import Any, Optional
from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.base import Input, Output, RunnableBindingBase
class HubRunnable(RunnableBinding[Input, Output]):
class HubRunnable(RunnableBindingBase[Input, Output]):
"""
An instance of a runnable stored in the LangChain Hub.
"""

@ -5,7 +5,8 @@ from typing_extensions import TypedDict
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.messages import BaseMessage
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
from langchain.schema.runnable import RouterRunnable, Runnable
from langchain.schema.runnable.base import RunnableBindingBase
class OpenAIFunction(TypedDict):
@ -19,7 +20,7 @@ class OpenAIFunction(TypedDict):
"""The parameters to the function."""
class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]):
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):
"""A runnable that routes to the selected function."""
functions: Optional[List[OpenAIFunction]]

@ -2581,11 +2581,6 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@ -2659,7 +2654,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
class RunnableBinding(RunnableSerializable[Input, Output]):
class RunnableBindingBase(RunnableSerializable[Input, Output]):
"""
A runnable that delegates calls to another runnable with a set of kwargs.
"""
@ -2749,11 +2744,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.bound.config_schema(include=include)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@ -2762,93 +2752,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
config=self.config,
kwargs={**self.kwargs, **kwargs},
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)
def with_listeners(
self,
*,
on_start: Optional[Listener] = None,
on_end: Optional[Listener] = None,
on_error: Optional[Listener] = None,
) -> Runnable[Input, Output]:
"""
Bind lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Called before the runnable starts running, with the Run object.
on_end: Called after the runnable finishes running, with the Run object.
on_error: Called if the runnable throws an error, with the Run object.
The Run object contains information about the run, including its id,
type, input, output, error, start_time, end_time, and any tags or metadata
added to the run.
"""
from langchain.callbacks.tracers.root_listeners import RootListenersTracer
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=self.config,
config_factories=[
lambda config: {
"callbacks": [
RootListenersTracer(
config=config,
on_start=on_start,
on_end=on_end,
on_error=on_error,
)
],
}
],
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)
def with_types(
self,
input_type: Optional[Union[Type[Input], BaseModel]] = None,
output_type: Optional[Union[Type[Output], BaseModel]] = None,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=self.config,
custom_input_type=input_type
if input_type is not None
else self.custom_input_type,
custom_output_type=output_type
if output_type is not None
else self.custom_output_type,
)
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound.with_retry(**kwargs),
kwargs=self.kwargs,
config=self.config,
)
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = merge_configs(self.config, *configs)
return merge_configs(config, *(f(config) for f in self.config_factories))
@ -2972,7 +2875,97 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
yield item
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)
class RunnableBinding(RunnableBindingBase[Input, Output]):
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
config=self.config,
kwargs={**self.kwargs, **kwargs},
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)
def with_config(
self,
config: Optional[RunnableConfig] = None,
# Sadly Unpack is not well supported by mypy so this will have to be untyped
**kwargs: Any,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)
def with_listeners(
self,
*,
on_start: Optional[Listener] = None,
on_end: Optional[Listener] = None,
on_error: Optional[Listener] = None,
) -> Runnable[Input, Output]:
"""
Bind lifecycle listeners to a Runnable, returning a new Runnable.
on_start: Called before the runnable starts running, with the Run object.
on_end: Called after the runnable finishes running, with the Run object.
on_error: Called if the runnable throws an error, with the Run object.
The Run object contains information about the run, including its id,
type, input, output, error, start_time, end_time, and any tags or metadata
added to the run.
"""
from langchain.callbacks.tracers.root_listeners import RootListenersTracer
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=self.config,
config_factories=[
lambda config: {
"callbacks": [
RootListenersTracer(
config=config,
on_start=on_start,
on_end=on_end,
on_error=on_error,
)
],
}
],
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)
def with_types(
self,
input_type: Optional[Union[Type[Input], BaseModel]] = None,
output_type: Optional[Union[Type[Output], BaseModel]] = None,
) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound,
kwargs=self.kwargs,
config=self.config,
custom_input_type=input_type
if input_type is not None
else self.custom_input_type,
custom_output_type=output_type
if output_type is not None
else self.custom_output_type,
)
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
return self.__class__(
bound=self.bound.with_retry(**kwargs),
kwargs=self.kwargs,
config=self.config,
)
RunnableLike = Union[
Runnable[Input, Output],

@ -119,11 +119,6 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
for spec in step.config_specs
)
def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
return self.runnable.config_schema(include=include)
@classmethod
def is_lc_serializable(cls) -> bool:
return True

@ -21,7 +21,7 @@ from tenacity import (
wait_exponential_jitter,
)
from langchain.schema.runnable.base import Input, Output, RunnableBinding
from langchain.schema.runnable.base import Input, Output, RunnableBindingBase
from langchain.schema.runnable.config import RunnableConfig, patch_config
if TYPE_CHECKING:
@ -34,7 +34,7 @@ if TYPE_CHECKING:
U = TypeVar("U")
class RunnableRetry(RunnableBinding[Input, Output]):
class RunnableRetry(RunnableBindingBase[Input, Output]):
"""Retry a Runnable if it fails.
A RunnableRetry helps can be used to add retry logic to any object

Loading…
Cancel
Save