Add option to prefix config keys in configurable_alts (#13714)

pull/13321/head^2
Nuno Campos 7 months ago committed by GitHub
parent 4ce5254442
commit 8a3e0c9afa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -13,7 +13,7 @@ tests:
poetry run pytest $(TEST_FILE)
test_watch:
poetry run ptw --snapshot-update --now . -- -x tests/unit_tests
poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests
######################

@ -1204,7 +1204,9 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
def configurable_alternatives(
self,
which: ConfigurableField,
*,
default_key: str = "default",
prefix_keys: bool = False,
**kwargs: Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]],
) -> RunnableSerializable[Input, Output]:
from langchain_core.runnables.configurable import (
@ -1212,7 +1214,11 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
)
return RunnableConfigurableAlternatives(
which=which, default=self, alternatives=kwargs, default_key=default_key
which=which,
default=self,
alternatives=kwargs,
default_key=default_key,
prefix_keys=prefix_keys,
)

@ -220,6 +220,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
annotation=spec.annotation
or self.default.__fields__[field_name].annotation,
default=getattr(self.default, field_name),
is_shared=spec.is_shared,
)
if isinstance(spec, ConfigurableField)
else make_options_spec(
@ -298,6 +299,12 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
]
default_key: str = "default"
"""The enum value to use for the default option. Defaults to "default"."""
prefix_keys: bool
"""Whether to prefix configurable fields of each alternative with a namespace
of the form <which.id>==<alternative_key>, eg. a key named "temperature" used by
the alternative named "gpt3" becomes "model==gpt3/temperature"."""
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
@ -313,21 +320,37 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
),
)
_enums_for_spec[self.which] = cast(Type[StrEnum], which_enum)
return [
ConfigurableFieldSpec(
id=self.which.id,
name=self.which.name,
description=self.which.description,
annotation=which_enum,
default=self.default_key,
),
*self.default.config_specs,
] + [
s
for alt in self.alternatives.values()
if isinstance(alt, RunnableSerializable)
for s in alt.config_specs
]
return get_unique_config_specs(
# which alternative
[
ConfigurableFieldSpec(
id=self.which.id,
name=self.which.name,
description=self.which.description,
annotation=which_enum,
default=self.default_key,
is_shared=self.which.is_shared,
),
]
# config specs of the default option
+ (
[
prefix_config_spec(s, f"{self.which.id}=={self.default_key}")
for s in self.default.config_specs
]
if self.prefix_keys
else self.default.config_specs
)
# config specs of the alternatives
+ [
prefix_config_spec(s, f"{self.which.id}=={alt_key}")
if self.prefix_keys
else s
for alt_key, alt in self.alternatives.items()
if isinstance(alt, RunnableSerializable)
for s in alt.config_specs
]
)
def configurable_fields(
self, **kwargs: AnyConfigurableField
@ -355,6 +378,23 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
raise ValueError(f"Unknown alternative: {which}")
def prefix_config_spec(
spec: ConfigurableFieldSpec, prefix: str
) -> ConfigurableFieldSpec:
return (
ConfigurableFieldSpec(
id=f"{prefix}/{spec.id}",
name=spec.name,
description=spec.description,
annotation=spec.annotation,
default=spec.default,
is_shared=spec.is_shared,
)
if not spec.is_shared
else spec
)
def make_options_spec(
spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption],
description: Optional[str],
@ -377,6 +417,7 @@ def make_options_spec(
description=spec.description or description,
annotation=enum,
default=spec.default,
is_shared=spec.is_shared,
)
else:
return ConfigurableFieldSpec(
@ -385,4 +426,5 @@ def make_options_spec(
description=spec.description or description,
annotation=Sequence[enum], # type: ignore[valid-type]
default=spec.default,
is_shared=spec.is_shared,
)

@ -169,6 +169,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
name="Session ID",
description="Unique identifier for a session.",
default="",
is_shared=True,
),
]
)

@ -257,6 +257,7 @@ class ConfigurableField(NamedTuple):
name: Optional[str] = None
description: Optional[str] = None
annotation: Optional[Any] = None
is_shared: bool = False
def __hash__(self) -> int:
return hash((self.id, self.annotation))
@ -271,6 +272,7 @@ class ConfigurableFieldSingleOption(NamedTuple):
name: Optional[str] = None
description: Optional[str] = None
is_shared: bool = False
def __hash__(self) -> int:
return hash((self.id, tuple(self.options.keys()), self.default))
@ -285,6 +287,7 @@ class ConfigurableFieldMultiOption(NamedTuple):
name: Optional[str] = None
description: Optional[str] = None
is_shared: bool = False
def __hash__(self) -> int:
return hash((self.id, tuple(self.options.keys()), tuple(self.default)))
@ -299,12 +302,13 @@ class ConfigurableFieldSpec(NamedTuple):
"""A field that can be configured by the user. It is a specification of a field."""
id: str
name: Optional[str]
description: Optional[str]
default: Any
annotation: Any
name: Optional[str] = None
description: Optional[str] = None
default: Any = None
is_shared: bool = False
def get_unique_config_specs(
specs: Iterable[ConfigurableFieldSpec],

@ -1020,6 +1020,118 @@ def test_configurable_alts_factory() -> None:
assert fake_llm.with_config(configurable={"llm": "chat"}).invoke("...") == "b"
def test_configurable_fields_prefix_keys() -> None:
fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
responses=ConfigurableFieldMultiOption(
id="responses",
name="Chat Responses",
options={
"hello": "A good morning to you!",
"bye": "See you later!",
"helpful": "How can I help you?",
},
default=["hello", "bye"],
),
# (sleep is a configurable field in FakeListChatModel)
sleep=ConfigurableField(
id="chat_sleep",
is_shared=True,
),
)
fake_llm = (
FakeListLLM(responses=["a"])
.configurable_fields(
responses=ConfigurableField(
id="responses",
name="LLM Responses",
description="A list of fake responses for this LLM",
)
)
.configurable_alternatives(
ConfigurableField(id="llm", name="LLM"),
chat=fake_chat | StrOutputParser(),
prefix_keys=True,
)
)
prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields(
template=ConfigurableFieldSingleOption(
id="prompt_template",
name="Prompt Template",
description="The prompt template for this chain",
options={
"hello": "Hello, {name}!",
"good_morning": "A very good morning to you, {name}!",
},
default="hello",
)
)
chain = prompt | fake_llm
assert chain.config_schema().schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"LLM": {
"title": "LLM",
"description": "An enumeration.",
"enum": ["chat", "default"],
"type": "string",
},
"Chat_Responses": {
"title": "Chat Responses",
"description": "An enumeration.",
"enum": ["hello", "bye", "helpful"],
"type": "string",
},
"Prompt_Template": {
"title": "Prompt Template",
"description": "An enumeration.",
"enum": ["hello", "good_morning"],
"type": "string",
},
"Configurable": {
"title": "Configurable",
"type": "object",
"properties": {
"prompt_template": {
"title": "Prompt Template",
"description": "The prompt template for this chain",
"default": "hello",
"allOf": [{"$ref": "#/definitions/Prompt_Template"}],
},
"llm": {
"title": "LLM",
"default": "default",
"allOf": [{"$ref": "#/definitions/LLM"}],
},
# not prefixed because marked as shared
"chat_sleep": {
"title": "Chat Sleep",
"type": "number",
},
# prefixed for "chat" option
"llm==chat/responses": {
"title": "Chat Responses",
"default": ["hello", "bye"],
"type": "array",
"items": {"$ref": "#/definitions/Chat_Responses"},
},
# prefixed for "default" option
"llm==default/responses": {
"title": "LLM Responses",
"description": "A list of fake responses for this LLM",
"default": ["a"],
"type": "array",
"items": {"type": "string"},
},
},
},
},
}
def test_configurable_fields_example() -> None:
fake_chat = FakeListChatModel(responses=["b"]).configurable_fields(
responses=ConfigurableFieldMultiOption(

Loading…
Cancel
Save