diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 8f50b5a922..818bcf3c38 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1013,7 +1013,7 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): self, which: ConfigurableField, default_key: str = "default", - **kwargs: Runnable[Input, Output], + **kwargs: Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], ) -> RunnableSerializable[Input, Output]: from langchain.schema.runnable.configurable import ( RunnableConfigurableAlternatives, diff --git a/libs/langchain/langchain/schema/runnable/configurable.py b/libs/langchain/langchain/schema/runnable/configurable.py index 2f1c6f7706..d9591d209c 100644 --- a/libs/langchain/langchain/schema/runnable/configurable.py +++ b/libs/langchain/langchain/schema/runnable/configurable.py @@ -6,6 +6,7 @@ from abc import abstractmethod from typing import ( Any, AsyncIterator, + Callable, Dict, Iterator, List, @@ -287,7 +288,10 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): which: ConfigurableField - alternatives: Dict[str, RunnableSerializable[Input, Output]] + alternatives: Dict[ + str, + Union[Runnable[Input, Output], Callable[[], Runnable[Input, Output]]], + ] default_key: str = "default" @@ -314,7 +318,12 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): default=self.default_key, ), *self.default.config_specs, - ] + [s for alt in self.alternatives.values() for s in alt.config_specs] + ] + [ + s + for alt in self.alternatives.values() + if isinstance(alt, RunnableSerializable) + for s in alt.config_specs + ] def configurable_fields( self, **kwargs: AnyConfigurableField @@ -333,7 +342,11 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]): if which == self.default_key: return self.default elif which in self.alternatives: - return self.alternatives[which] + alt = self.alternatives[which] + if isinstance(alt, Runnable): + return alt + else: + return alt() else: raise ValueError(f"Unknown alternative: {which}") diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 9b08523395..5bc55cca31 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,4 +1,5 @@ import sys +from functools import partial from operator import itemgetter from typing import ( Any, @@ -925,6 +926,17 @@ def test_configurable_fields() -> None: ).invoke({"name": "John"}) == {"llm1": "c", "llm2": "c", "llm3": "d"} +def test_configurable_alts_factory() -> None: + fake_llm = FakeListLLM(responses=["a"]).configurable_alternatives( + ConfigurableField(id="llm", name="LLM"), + chat=partial(FakeListLLM, responses=["b"]), + ) + + assert fake_llm.invoke("...") == "a" + + assert fake_llm.with_config(configurable={"llm": "chat"}).invoke("...") == "b" + + def test_configurable_fields_example() -> None: fake_chat = FakeListChatModel(responses=["b"]).configurable_fields( responses=ConfigurableFieldMultiOption(