diff --git a/libs/core/Makefile b/libs/core/Makefile index 56ce4fb846..139b9045f5 100644 --- a/libs/core/Makefile +++ b/libs/core/Makefile @@ -13,7 +13,7 @@ tests: poetry run pytest $(TEST_FILE) test_watch: - poetry run ptw --snapshot-update --now . -- -x tests/unit_tests + poetry run ptw --snapshot-update --now . -- -vv -x tests/unit_tests ###################### diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 094cb99ee6..0963718b02 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -1204,7 +1204,9 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): def configurable_alternatives( self, which: ConfigurableField, + *, default_key: str = "default", + prefix_keys: bool = False, **kwargs: Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], ) -> RunnableSerializable[Input, Output]: from langchain_core.runnables.configurable import ( @@ -1212,7 +1214,11 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): ) return RunnableConfigurableAlternatives( - which=which, default=self, alternatives=kwargs, default_key=default_key + which=which, + default=self, + alternatives=kwargs, + default_key=default_key, + prefix_keys=prefix_keys, ) diff --git a/libs/core/langchain_core/runnables/configurable.py b/libs/core/langchain_core/runnables/configurable.py index 7d95d15703..1c9756b793 100644 --- a/libs/core/langchain_core/runnables/configurable.py +++ b/libs/core/langchain_core/runnables/configurable.py @@ -220,6 +220,7 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]): annotation=spec.annotation or self.default.__fields__[field_name].annotation, default=getattr(self.default, field_name), + is_shared=spec.is_shared, ) if isinstance(spec, ConfigurableField) else make_options_spec( @@ -298,6 +299,12 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): ] default_key: str = "default" + """The enum value to use for the default option. Defaults to "default".""" + + prefix_keys: bool + """Whether to prefix configurable fields of each alternative with a namespace + of the form ==, eg. a key named "temperature" used by + the alternative named "gpt3" becomes "model==gpt3/temperature".""" @property def config_specs(self) -> List[ConfigurableFieldSpec]: @@ -313,21 +320,37 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): ), ) _enums_for_spec[self.which] = cast(Type[StrEnum], which_enum) - return [ - ConfigurableFieldSpec( - id=self.which.id, - name=self.which.name, - description=self.which.description, - annotation=which_enum, - default=self.default_key, - ), - *self.default.config_specs, - ] + [ - s - for alt in self.alternatives.values() - if isinstance(alt, RunnableSerializable) - for s in alt.config_specs - ] + return get_unique_config_specs( + # which alternative + [ + ConfigurableFieldSpec( + id=self.which.id, + name=self.which.name, + description=self.which.description, + annotation=which_enum, + default=self.default_key, + is_shared=self.which.is_shared, + ), + ] + # config specs of the default option + + ( + [ + prefix_config_spec(s, f"{self.which.id}=={self.default_key}") + for s in self.default.config_specs + ] + if self.prefix_keys + else self.default.config_specs + ) + # config specs of the alternatives + + [ + prefix_config_spec(s, f"{self.which.id}=={alt_key}") + if self.prefix_keys + else s + for alt_key, alt in self.alternatives.items() + if isinstance(alt, RunnableSerializable) + for s in alt.config_specs + ] + ) def configurable_fields( self, **kwargs: AnyConfigurableField @@ -355,6 +378,23 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): raise ValueError(f"Unknown alternative: {which}") +def prefix_config_spec( + spec: ConfigurableFieldSpec, prefix: str +) -> ConfigurableFieldSpec: + return ( + ConfigurableFieldSpec( + id=f"{prefix}/{spec.id}", + name=spec.name, + description=spec.description, + annotation=spec.annotation, + default=spec.default, + is_shared=spec.is_shared, + ) + if not spec.is_shared + else spec + ) + + def make_options_spec( spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption], description: Optional[str], @@ -377,6 +417,7 @@ def make_options_spec( description=spec.description or description, annotation=enum, default=spec.default, + is_shared=spec.is_shared, ) else: return ConfigurableFieldSpec( @@ -385,4 +426,5 @@ def make_options_spec( description=spec.description or description, annotation=Sequence[enum], # type: ignore[valid-type] default=spec.default, + is_shared=spec.is_shared, ) diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 1ae0d8a5c7..52c04c0e10 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -169,6 +169,7 @@ class RunnableWithMessageHistory(RunnableBindingBase): name="Session ID", description="Unique identifier for a session.", default="", + is_shared=True, ), ] ) diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index aafd9d5945..cd7652bff3 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -257,6 +257,7 @@ class ConfigurableField(NamedTuple): name: Optional[str] = None description: Optional[str] = None annotation: Optional[Any] = None + is_shared: bool = False def __hash__(self) -> int: return hash((self.id, self.annotation)) @@ -271,6 +272,7 @@ class ConfigurableFieldSingleOption(NamedTuple): name: Optional[str] = None description: Optional[str] = None + is_shared: bool = False def __hash__(self) -> int: return hash((self.id, tuple(self.options.keys()), self.default)) @@ -285,6 +287,7 @@ class ConfigurableFieldMultiOption(NamedTuple): name: Optional[str] = None description: Optional[str] = None + is_shared: bool = False def __hash__(self) -> int: return hash((self.id, tuple(self.options.keys()), tuple(self.default))) @@ -299,12 +302,13 @@ class ConfigurableFieldSpec(NamedTuple): """A field that can be configured by the user. It is a specification of a field.""" id: str - name: Optional[str] - description: Optional[str] - - default: Any annotation: Any + name: Optional[str] = None + description: Optional[str] = None + default: Any = None + is_shared: bool = False + def get_unique_config_specs( specs: Iterable[ConfigurableFieldSpec], diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 6db36b5f57..89522ef3f4 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -1020,6 +1020,118 @@ def test_configurable_alts_factory() -> None: assert fake_llm.with_config(configurable={"llm": "chat"}).invoke("...") == "b" +def test_configurable_fields_prefix_keys() -> None: + fake_chat = FakeListChatModel(responses=["b"]).configurable_fields( + responses=ConfigurableFieldMultiOption( + id="responses", + name="Chat Responses", + options={ + "hello": "A good morning to you!", + "bye": "See you later!", + "helpful": "How can I help you?", + }, + default=["hello", "bye"], + ), + # (sleep is a configurable field in FakeListChatModel) + sleep=ConfigurableField( + id="chat_sleep", + is_shared=True, + ), + ) + fake_llm = ( + FakeListLLM(responses=["a"]) + .configurable_fields( + responses=ConfigurableField( + id="responses", + name="LLM Responses", + description="A list of fake responses for this LLM", + ) + ) + .configurable_alternatives( + ConfigurableField(id="llm", name="LLM"), + chat=fake_chat | StrOutputParser(), + prefix_keys=True, + ) + ) + prompt = PromptTemplate.from_template("Hello, {name}!").configurable_fields( + template=ConfigurableFieldSingleOption( + id="prompt_template", + name="Prompt Template", + description="The prompt template for this chain", + options={ + "hello": "Hello, {name}!", + "good_morning": "A very good morning to you, {name}!", + }, + default="hello", + ) + ) + + chain = prompt | fake_llm + + assert chain.config_schema().schema() == { + "title": "RunnableSequenceConfig", + "type": "object", + "properties": {"configurable": {"$ref": "#/definitions/Configurable"}}, + "definitions": { + "LLM": { + "title": "LLM", + "description": "An enumeration.", + "enum": ["chat", "default"], + "type": "string", + }, + "Chat_Responses": { + "title": "Chat Responses", + "description": "An enumeration.", + "enum": ["hello", "bye", "helpful"], + "type": "string", + }, + "Prompt_Template": { + "title": "Prompt Template", + "description": "An enumeration.", + "enum": ["hello", "good_morning"], + "type": "string", + }, + "Configurable": { + "title": "Configurable", + "type": "object", + "properties": { + "prompt_template": { + "title": "Prompt Template", + "description": "The prompt template for this chain", + "default": "hello", + "allOf": [{"$ref": "#/definitions/Prompt_Template"}], + }, + "llm": { + "title": "LLM", + "default": "default", + "allOf": [{"$ref": "#/definitions/LLM"}], + }, + # not prefixed because marked as shared + "chat_sleep": { + "title": "Chat Sleep", + "type": "number", + }, + # prefixed for "chat" option + "llm==chat/responses": { + "title": "Chat Responses", + "default": ["hello", "bye"], + "type": "array", + "items": {"$ref": "#/definitions/Chat_Responses"}, + }, + # prefixed for "default" option + "llm==default/responses": { + "title": "LLM Responses", + "description": "A list of fake responses for this LLM", + "default": ["a"], + "type": "array", + "items": {"type": "string"}, + }, + }, + }, + }, + } + + def test_configurable_fields_example() -> None: fake_chat = FakeListChatModel(responses=["b"]).configurable_fields( responses=ConfigurableFieldMultiOption(