Remove str() from RunnableConfigurableAlternatives (#11446)

pull/6605/head
Nuno Campos 10 months ago committed by GitHub
parent 656480feb6
commit 79011f835f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -10,6 +10,7 @@ from typing import (
Iterator,
List,
Optional,
Type,
TypeVar,
Union,
)
@ -71,7 +72,7 @@ class BaseGenerationOutputParser(
return Union[str, AnyMessage]
@property
def OutputType(self) -> type[T]:
def OutputType(self) -> Type[T]:
# even though mypy complains this isn't valid,
# it is good enough for pydantic to build the schema from
return T # type: ignore[misc]
@ -154,7 +155,7 @@ class BaseOutputParser(
return Union[str, AnyMessage]
@property
def OutputType(self) -> type[T]:
def OutputType(self) -> Type[T]:
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 1:

@ -3,7 +3,7 @@ from __future__ import annotations
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
import yaml
@ -46,7 +46,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
return Union[StringPromptValue, ChatPromptValueConcrete]
@property
def input_schema(self) -> type[BaseModel]:
def input_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"PromptInput",

@ -1459,7 +1459,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
return Any
@property
def input_schema(self) -> type[BaseModel]:
def input_schema(self) -> Type[BaseModel]:
if all(
s.input_schema.schema().get("type", "object") == "object"
for s in self.steps.values()
@ -1478,7 +1478,7 @@ class RunnableMap(RunnableSerializable[Input, Dict[str, Any]]):
return super().input_schema
@property
def output_schema(self) -> type[BaseModel]:
def output_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableMapOutput",
@ -2065,7 +2065,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
return List[self.bound.InputType] # type: ignore[name-defined]
@property
def input_schema(self) -> type[BaseModel]:
def input_schema(self) -> Type[BaseModel]:
return create_model(
"RunnableEachInput",
__root__=(
@ -2075,11 +2075,11 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
)
@property
def OutputType(self) -> type[List[Output]]:
def OutputType(self) -> Type[List[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined]
@property
def output_schema(self) -> type[BaseModel]:
def output_schema(self) -> Type[BaseModel]:
return create_model(
"RunnableEachOutput",
__root__=(
@ -2152,11 +2152,11 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
arbitrary_types_allowed = True
@property
def InputType(self) -> type[Input]:
def InputType(self) -> Type[Input]:
return self.bound.InputType
@property
def OutputType(self) -> type[Output]:
def OutputType(self) -> Type[Output]:
return self.bound.OutputType
@property

@ -274,7 +274,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]:
config = config or {}
which = str(config.get("configurable", {}).get(self.which.id, self.default_key))
which = config.get("configurable", {}).get(self.which.id, self.default_key)
if which == self.default_key:
return self.default
elif which in self.alternatives:

@ -133,7 +133,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> type[BaseModel]:
def input_schema(self) -> Type[BaseModel]:
map_input_schema = self.mapper.input_schema
if not map_input_schema.__custom_root_type__:
# ie. it's a dict
@ -142,7 +142,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
return super().input_schema
@property
def output_schema(self) -> type[BaseModel]:
def output_schema(self) -> Type[BaseModel]:
map_input_schema = self.mapper.input_schema
map_output_schema = self.mapper.output_schema
if (

Loading…
Cancel
Save