|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import enum
|
|
|
|
|
from abc import abstractmethod
|
|
|
|
|
from typing import (
|
|
|
|
|
Any,
|
|
|
|
@ -7,7 +8,6 @@ from typing import (
|
|
|
|
|
Dict,
|
|
|
|
|
Iterator,
|
|
|
|
|
List,
|
|
|
|
|
Literal,
|
|
|
|
|
Optional,
|
|
|
|
|
Sequence,
|
|
|
|
|
Type,
|
|
|
|
@ -32,7 +32,7 @@ from langchain.schema.runnable.utils import (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|
|
|
|
bound: RunnableSerializable[Input, Output]
|
|
|
|
|
default: RunnableSerializable[Input, Output]
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
|
arbitrary_types_allowed = True
|
|
|
|
@ -47,19 +47,19 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def InputType(self) -> Type[Input]:
|
|
|
|
|
return self.bound.InputType
|
|
|
|
|
return self.default.InputType
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def OutputType(self) -> Type[Output]:
|
|
|
|
|
return self.bound.OutputType
|
|
|
|
|
return self.default.OutputType
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def input_schema(self) -> Type[BaseModel]:
|
|
|
|
|
return self.bound.input_schema
|
|
|
|
|
return self.default.input_schema
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def output_schema(self) -> Type[BaseModel]:
|
|
|
|
|
return self.bound.output_schema
|
|
|
|
|
return self.default.output_schema
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def _prepare(
|
|
|
|
@ -88,8 +88,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|
|
|
|
configs = get_config_list(config, len(inputs))
|
|
|
|
|
prepared = [self._prepare(c) for c in configs]
|
|
|
|
|
|
|
|
|
|
if all(p is self.bound for p in prepared):
|
|
|
|
|
return self.bound.batch(
|
|
|
|
|
if all(p is self.default for p in prepared):
|
|
|
|
|
return self.default.batch(
|
|
|
|
|
inputs, config, return_exceptions=return_exceptions, **kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -131,8 +131,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
|
|
|
|
|
configs = get_config_list(config, len(inputs))
|
|
|
|
|
prepared = [self._prepare(c) for c in configs]
|
|
|
|
|
|
|
|
|
|
if all(p is self.bound for p in prepared):
|
|
|
|
|
return await self.bound.abatch(
|
|
|
|
|
if all(p is self.default for p in prepared):
|
|
|
|
|
return await self.default.abatch(
|
|
|
|
|
inputs, config, return_exceptions=return_exceptions, **kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -202,10 +202,10 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|
|
|
|
id=spec.id,
|
|
|
|
|
name=spec.name,
|
|
|
|
|
description=spec.description
|
|
|
|
|
or self.bound.__fields__[field_name].field_info.description,
|
|
|
|
|
or self.default.__fields__[field_name].field_info.description,
|
|
|
|
|
annotation=spec.annotation
|
|
|
|
|
or self.bound.__fields__[field_name].annotation,
|
|
|
|
|
default=getattr(self.bound, field_name),
|
|
|
|
|
or self.default.__fields__[field_name].annotation,
|
|
|
|
|
default=getattr(self.default, field_name),
|
|
|
|
|
)
|
|
|
|
|
for field_name, spec in self.fields.items()
|
|
|
|
|
]
|
|
|
|
@ -213,7 +213,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|
|
|
|
def configurable_fields(
|
|
|
|
|
self, **kwargs: ConfigurableField
|
|
|
|
|
) -> RunnableSerializable[Input, Output]:
|
|
|
|
|
return self.bound.configurable_fields(**{**self.fields, **kwargs})
|
|
|
|
|
return self.default.configurable_fields(**{**self.fields, **kwargs})
|
|
|
|
|
|
|
|
|
|
def _prepare(
|
|
|
|
|
self, config: Optional[RunnableConfig] = None
|
|
|
|
@ -227,9 +227,14 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if configurable:
|
|
|
|
|
return self.bound.__class__(**{**self.bound.dict(), **configurable})
|
|
|
|
|
return self.default.__class__(**{**self.default.dict(), **configurable})
|
|
|
|
|
else:
|
|
|
|
|
return self.bound
|
|
|
|
|
return self.default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Before Python 3.11 native StrEnum is not available
|
|
|
|
|
class StrEnum(str, enum.Enum):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|
|
|
@ -237,21 +242,23 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|
|
|
|
|
|
|
|
|
alternatives: Dict[str, RunnableSerializable[Input, Output]]
|
|
|
|
|
|
|
|
|
|
default_key: str = "default"
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
|
|
|
|
alt_keys = self.alternatives.keys()
|
|
|
|
|
which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore
|
|
|
|
|
Literal["default"],
|
|
|
|
|
which_enum = StrEnum( # type: ignore[call-overload]
|
|
|
|
|
self.which.name or self.which.id,
|
|
|
|
|
((v, v) for v in list(self.alternatives.keys()) + [self.default_key]),
|
|
|
|
|
)
|
|
|
|
|
return [
|
|
|
|
|
ConfigurableFieldSpec(
|
|
|
|
|
id=self.which.id,
|
|
|
|
|
name=self.which.name,
|
|
|
|
|
description=self.which.description,
|
|
|
|
|
annotation=Union[which_keys], # type: ignore
|
|
|
|
|
default="default",
|
|
|
|
|
annotation=which_enum,
|
|
|
|
|
default=self.default_key,
|
|
|
|
|
),
|
|
|
|
|
*self.bound.config_specs,
|
|
|
|
|
*self.default.config_specs,
|
|
|
|
|
] + [s for alt in self.alternatives.values() for s in alt.config_specs]
|
|
|
|
|
|
|
|
|
|
def configurable_fields(
|
|
|
|
@ -259,7 +266,7 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|
|
|
|
) -> RunnableSerializable[Input, Output]:
|
|
|
|
|
return self.__class__(
|
|
|
|
|
which=self.which,
|
|
|
|
|
bound=self.bound.configurable_fields(**kwargs),
|
|
|
|
|
default=self.default.configurable_fields(**kwargs),
|
|
|
|
|
alternatives=self.alternatives,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -267,9 +274,9 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
|
|
|
|
|
self, config: Optional[RunnableConfig] = None
|
|
|
|
|
) -> Runnable[Input, Output]:
|
|
|
|
|
config = config or {}
|
|
|
|
|
which = config.get("configurable", {}).get(self.which.id)
|
|
|
|
|
if not which:
|
|
|
|
|
return self.bound
|
|
|
|
|
which = str(config.get("configurable", {}).get(self.which.id, self.default_key))
|
|
|
|
|
if which == self.default_key:
|
|
|
|
|
return self.default
|
|
|
|
|
elif which in self.alternatives:
|
|
|
|
|
return self.alternatives[which]
|
|
|
|
|
else:
|
|
|
|
|