From b0893c7c6a4954151acb80c42f373494e7880179 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 4 Oct 2023 16:32:41 +0100 Subject: [PATCH] Use an enum for configurable_alternatives to make the generated json schema nicer (#11350) --- .../langchain/schema/runnable/base.py | 20 +++---- .../langchain/schema/runnable/configurable.py | 59 +++++++++++-------- .../langchain/schema/runnable/fallbacks.py | 4 +- .../schema/runnable/test_runnable.py | 25 ++++---- 4 files changed, 56 insertions(+), 52 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 77d9b472ec..57a93d8f7c 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -129,9 +129,7 @@ class Runnable(Generic[Input, Output], ABC): def config_specs(self) -> Sequence[ConfigurableFieldSpec]: return [] - def config_schema( - self, *, include: Optional[Sequence[str]] = None - ) -> Type[BaseModel]: + def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]: class _Config: arbitrary_types_allowed = True @@ -150,7 +148,7 @@ class Runnable(Generic[Input, Output], ABC): for spec in config_specs }, ) - if config_specs + if config_specs and "configurable" in include else None ) @@ -161,7 +159,7 @@ class Runnable(Generic[Input, Output], ABC): **{ field_name: (field_type, None) for field_name, field_type in RunnableConfig.__annotations__.items() - if field_name in include + if field_name in [i for i in include if i != "configurable"] }, ) @@ -873,7 +871,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): "available keys are {self.__fields__.keys()}" ) - return RunnableConfigurableFields(bound=self, fields=kwargs) + return RunnableConfigurableFields(default=self, fields=kwargs) def configurable_alternatives( self, @@ -885,7 +883,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): ) return RunnableConfigurableAlternatives( - which=which, bound=self, alternatives=kwargs + which=which, default=self, alternatives=kwargs ) @@ -2051,9 +2049,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): def config_specs(self) -> Sequence[ConfigurableFieldSpec]: return self.bound.config_specs - def config_schema( - self, *, include: Optional[Sequence[str]] = None - ) -> Type[BaseModel]: + def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]: return self.bound.config_schema(include=include) @classmethod @@ -2132,9 +2128,7 @@ class RunnableBinding(RunnableSerializable[Input, Output]): def config_specs(self) -> Sequence[ConfigurableFieldSpec]: return self.bound.config_specs - def config_schema( - self, *, include: Optional[Sequence[str]] = None - ) -> Type[BaseModel]: + def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]: return self.bound.config_schema(include=include) @classmethod diff --git a/libs/langchain/langchain/schema/runnable/configurable.py b/libs/langchain/langchain/schema/runnable/configurable.py index 1afdcc0ef9..e58af4a82b 100644 --- a/libs/langchain/langchain/schema/runnable/configurable.py +++ b/libs/langchain/langchain/schema/runnable/configurable.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum from abc import abstractmethod from typing import ( Any, @@ -7,7 +8,6 @@ from typing import ( Dict, Iterator, List, - Literal, Optional, Sequence, Type, @@ -32,7 +32,7 @@ from langchain.schema.runnable.utils import ( class DynamicRunnable(RunnableSerializable[Input, Output]): - bound: RunnableSerializable[Input, Output] + default: RunnableSerializable[Input, Output] class Config: arbitrary_types_allowed = True @@ -47,19 +47,19 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): @property def InputType(self) -> Type[Input]: - return self.bound.InputType + return self.default.InputType @property def OutputType(self) -> Type[Output]: - return self.bound.OutputType + return self.default.OutputType @property def input_schema(self) -> Type[BaseModel]: - return self.bound.input_schema + return self.default.input_schema @property def output_schema(self) -> Type[BaseModel]: - return self.bound.output_schema + return self.default.output_schema @abstractmethod def _prepare( @@ -88,8 +88,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): configs = get_config_list(config, len(inputs)) prepared = [self._prepare(c) for c in configs] - if all(p is self.bound for p in prepared): - return self.bound.batch( + if all(p is self.default for p in prepared): + return self.default.batch( inputs, config, return_exceptions=return_exceptions, **kwargs ) @@ -131,8 +131,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): configs = get_config_list(config, len(inputs)) prepared = [self._prepare(c) for c in configs] - if all(p is self.bound for p in prepared): - return await self.bound.abatch( + if all(p is self.default for p in prepared): + return await self.default.abatch( inputs, config, return_exceptions=return_exceptions, **kwargs ) @@ -202,10 +202,10 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): id=spec.id, name=spec.name, description=spec.description - or self.bound.__fields__[field_name].field_info.description, + or self.default.__fields__[field_name].field_info.description, annotation=spec.annotation - or self.bound.__fields__[field_name].annotation, - default=getattr(self.bound, field_name), + or self.default.__fields__[field_name].annotation, + default=getattr(self.default, field_name), ) for field_name, spec in self.fields.items() ] @@ -213,7 +213,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): def configurable_fields( self, **kwargs: ConfigurableField ) -> RunnableSerializable[Input, Output]: - return self.bound.configurable_fields(**{**self.fields, **kwargs}) + return self.default.configurable_fields(**{**self.fields, **kwargs}) def _prepare( self, config: Optional[RunnableConfig] = None @@ -227,9 +227,14 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): } if configurable: - return self.bound.__class__(**{**self.bound.dict(), **configurable}) + return self.default.__class__(**{**self.default.dict(), **configurable}) else: - return self.bound + return self.default + + +# Before Python 3.11 native StrEnum is not available +class StrEnum(str, enum.Enum): + pass class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): @@ -237,21 +242,23 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): alternatives: Dict[str, RunnableSerializable[Input, Output]] + default_key: str = "default" + @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: - alt_keys = self.alternatives.keys() - which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore - Literal["default"], + which_enum = StrEnum( # type: ignore[call-overload] + self.which.name or self.which.id, + ((v, v) for v in list(self.alternatives.keys()) + [self.default_key]), ) return [ ConfigurableFieldSpec( id=self.which.id, name=self.which.name, description=self.which.description, - annotation=Union[which_keys], # type: ignore - default="default", + annotation=which_enum, + default=self.default_key, ), - *self.bound.config_specs, + *self.default.config_specs, ] + [s for alt in self.alternatives.values() for s in alt.config_specs] def configurable_fields( @@ -259,7 +266,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): ) -> RunnableSerializable[Input, Output]: return self.__class__( which=self.which, - bound=self.bound.configurable_fields(**kwargs), + default=self.default.configurable_fields(**kwargs), alternatives=self.alternatives, ) @@ -267,9 +274,9 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): self, config: Optional[RunnableConfig] = None ) -> Runnable[Input, Output]: config = config or {} - which = config.get("configurable", {}).get(self.which.id) - if not which: - return self.bound + which = str(config.get("configurable", {}).get(self.which.id, self.default_key)) + if which == self.default_key: + return self.default elif which in self.alternatives: return self.alternatives[which] else: diff --git a/libs/langchain/langchain/schema/runnable/fallbacks.py b/libs/langchain/langchain/schema/runnable/fallbacks.py index 239800e1bb..60ab497b62 100644 --- a/libs/langchain/langchain/schema/runnable/fallbacks.py +++ b/libs/langchain/langchain/schema/runnable/fallbacks.py @@ -69,9 +69,7 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): for spec in step.config_specs ) - def config_schema( - self, *, include: Optional[Sequence[str]] = None - ) -> Type[BaseModel]: + def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]: return self.runnable.config_schema(include=include) @classmethod diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index c389b42648..1fae13287f 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -563,7 +563,7 @@ def test_configurable_fields() -> None: assert fake_llm_configurable.invoke("...") == "a" - assert fake_llm_configurable.config_schema().schema() == { + assert fake_llm_configurable.config_schema(include=["configurable"]).schema() == { "title": "RunnableConfigurableFieldsConfig", "type": "object", "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, @@ -606,7 +606,7 @@ def test_configurable_fields() -> None: text="Hello, John!" ) - assert prompt_configurable.config_schema().schema() == { + assert prompt_configurable.config_schema(include=["configurable"]).schema() == { "title": "RunnableConfigurableFieldsConfig", "type": "object", "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, @@ -638,7 +638,7 @@ def test_configurable_fields() -> None: assert chain_configurable.invoke({"name": "John"}) == "a" - assert chain_configurable.config_schema().schema() == { + assert chain_configurable.config_schema(include=["configurable"]).schema() == { "title": "RunnableSequenceConfig", "type": "object", "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, @@ -690,7 +690,9 @@ def test_configurable_fields() -> None: "llm3": "a", } - assert chain_with_map_configurable.config_schema().schema() == { + assert chain_with_map_configurable.config_schema( + include=["configurable"] + ).schema() == { "title": "RunnableSequenceConfig", "type": "object", "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, @@ -760,11 +762,17 @@ def test_configurable_fields_example() -> None: assert chain_configurable.invoke({"name": "John"}) == "a" - assert chain_configurable.config_schema().schema() == { + assert chain_configurable.config_schema(include=["configurable"]).schema() == { "title": "RunnableSequenceConfig", "type": "object", "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, "definitions": { + "LLM": { + "title": "LLM", + "description": "An enumeration.", + "enum": ["chat", "default"], + "type": "string", + }, "Configurable": { "title": "Configurable", "type": "object", @@ -772,10 +780,7 @@ def test_configurable_fields_example() -> None: "llm": { "title": "LLM", "default": "default", - "anyOf": [ - {"enum": ["chat"], "type": "string"}, - {"enum": ["default"], "type": "string"}, - ], + "allOf": [{"$ref": "#/definitions/LLM"}], }, "llm_responses": { "title": "LLM Responses", @@ -791,7 +796,7 @@ def test_configurable_fields_example() -> None: "type": "string", }, }, - } + }, }, }