From 79011f835ffae9557fea46f8ae77723a024f7662 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 5 Oct 2023 18:40:00 +0100 Subject: [PATCH] Remove str() from RunnableConfigurableAlternatives (#11446) --- libs/langchain/langchain/schema/output_parser.py | 5 +++-- libs/langchain/langchain/schema/prompt_template.py | 4 ++-- libs/langchain/langchain/schema/runnable/base.py | 14 +++++++------- .../langchain/schema/runnable/configurable.py | 2 +- .../langchain/schema/runnable/passthrough.py | 4 ++-- 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index c675dfe49b..1bf65c9122 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -10,6 +10,7 @@ from typing import ( Iterator, List, Optional, + Type, TypeVar, Union, ) @@ -71,7 +72,7 @@ class BaseGenerationOutputParser( return Union[str, AnyMessage] @property - def OutputType(self) -> type[T]: + def OutputType(self) -> Type[T]: # even though mypy complains this isn't valid, # it is good enough for pydantic to build the schema from return T # type: ignore[misc] @@ -154,7 +155,7 @@ class BaseOutputParser( return Union[str, AnyMessage] @property - def OutputType(self) -> type[T]: + def OutputType(self) -> Type[T]: for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] type_args = get_args(cls) if type_args and len(type_args) == 1: diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index c9dabd572a..9ac8164bb6 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -3,7 +3,7 @@ from __future__ import annotations import json from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union import yaml @@ -46,7 +46,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): return Union[StringPromptValue, ChatPromptValueConcrete] @property - def input_schema(self) -> type[BaseModel]: + def input_schema(self) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "PromptInput", diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 6bfef3ed44..f99de2d176 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1459,7 +1459,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]): return Any @property - def input_schema(self) -> type[BaseModel]: + def input_schema(self) -> Type[BaseModel]: if all( s.input_schema.schema().get("type", "object") == "object" for s in self.steps.values() @@ -1478,7 +1478,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]): return super().input_schema @property - def output_schema(self) -> type[BaseModel]: + def output_schema(self) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "RunnableMapOutput", @@ -2065,7 +2065,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): return List[self.bound.InputType] # type: ignore[name-defined] @property - def input_schema(self) -> type[BaseModel]: + def input_schema(self) -> Type[BaseModel]: return create_model( "RunnableEachInput", __root__=( @@ -2075,11 +2075,11 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): ) @property - def OutputType(self) -> type[List[Output]]: + def OutputType(self) -> Type[List[Output]]: return List[self.bound.OutputType] # type: ignore[name-defined] @property - def output_schema(self) -> type[BaseModel]: + def output_schema(self) -> Type[BaseModel]: return create_model( "RunnableEachOutput", __root__=( @@ -2152,11 +2152,11 @@ class RunnableBinding(RunnableSerializable[Input, Output]): arbitrary_types_allowed = True @property - def InputType(self) -> type[Input]: + def InputType(self) -> Type[Input]: return self.bound.InputType @property - def OutputType(self) -> type[Output]: + def OutputType(self) -> Type[Output]: return self.bound.OutputType @property diff --git a/libs/langchain/langchain/schema/runnable/configurable.py b/libs/langchain/langchain/schema/runnable/configurable.py index 0f7b07cdda..7933455eaa 100644 --- a/libs/langchain/langchain/schema/runnable/configurable.py +++ b/libs/langchain/langchain/schema/runnable/configurable.py @@ -274,7 +274,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): self, config: Optional[RunnableConfig] = None ) -> Runnable[Input, Output]: config = config or {} - which = str(config.get("configurable", {}).get(self.which.id, self.default_key)) + which = config.get("configurable", {}).get(self.which.id, self.default_key) if which == self.default_key: return self.default elif which in self.alternatives: diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index ce893b6a52..79b743c5b4 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -133,7 +133,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): return cls.__module__.split(".")[:-1] @property - def input_schema(self) -> type[BaseModel]: + def input_schema(self) -> Type[BaseModel]: map_input_schema = self.mapper.input_schema if not map_input_schema.__custom_root_type__: # ie. it's a dict @@ -142,7 +142,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): return super().input_schema @property - def output_schema(self) -> type[BaseModel]: + def output_schema(self) -> Type[BaseModel]: map_input_schema = self.mapper.input_schema map_output_schema = self.mapper.output_schema if (