diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 83a83346ab..43af5d31a1 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -61,15 +61,17 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): chains and cannot return as rich of an output as `__call__`. """ - @property - def input_schema(self) -> Type[BaseModel]: + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "ChainInput", **{k: (Any, None) for k in self.input_keys} ) - @property - def output_schema(self) -> Type[BaseModel]: + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "ChainOutput", **{k: (Any, None) for k in self.output_keys} diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index 63cf836d46..ea28e99add 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -10,6 +10,7 @@ from langchain.callbacks.manager import ( from langchain.chains.base import Chain from langchain.docstore.document import Document from langchain.pydantic_v1 import BaseModel, Field, create_model +from langchain.schema.runnable.config import RunnableConfig from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter @@ -28,15 +29,17 @@ class BaseCombineDocumentsChain(Chain, ABC): input_key: str = "input_documents" #: :meta private: output_key: str = "output_text" #: :meta private: - @property - def input_schema(self) -> Type[BaseModel]: + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: return create_model( "CombineDocumentsInput", **{self.input_key: (List[Document], None)}, # type: ignore[call-overload] ) - @property - def output_schema(self) -> Type[BaseModel]: + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: return create_model( "CombineDocumentsOutput", **{self.output_key: (str, None)}, # type: ignore[call-overload] @@ -167,16 +170,18 @@ class AnalyzeDocumentChain(Chain): """ return self.combine_docs_chain.output_keys - @property - def input_schema(self) -> Type[BaseModel]: + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: return create_model( "AnalyzeDocumentChain", **{self.input_key: (str, None)}, # type: ignore[call-overload] ) - @property - def output_schema(self) -> Type[BaseModel]: - return self.combine_docs_chain.output_schema + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self.combine_docs_chain.get_output_schema(config) def _call( self, diff --git a/libs/langchain/langchain/chains/combine_documents/map_reduce.py b/libs/langchain/langchain/chains/combine_documents/map_reduce.py index f593f66db0..fcbb721e1b 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/map_reduce.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -10,6 +10,7 @@ from langchain.chains.combine_documents.reduce import ReduceDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator +from langchain.schema.runnable.config import RunnableConfig class MapReduceDocumentsChain(BaseCombineDocumentsChain): @@ -98,8 +99,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): return_intermediate_steps: bool = False """Return the results of the map steps in the output.""" - @property - def output_schema(self) -> type[BaseModel]: + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: if self.return_intermediate_steps: return create_model( "MapReduceDocumentsOutput", @@ -109,7 +111,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): }, # type: ignore[call-overload] ) - return super().output_schema + return super().get_output_schema(config) @property def output_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/combine_documents/map_rerank.py b/libs/langchain/langchain/chains/combine_documents/map_rerank.py index 4af56bc6ca..717222ed95 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_rerank.py +++ b/libs/langchain/langchain/chains/combine_documents/map_rerank.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -10,6 +10,7 @@ from langchain.chains.llm import LLMChain from langchain.docstore.document import Document from langchain.output_parsers.regex import RegexParser from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator +from langchain.schema.runnable.config import RunnableConfig class MapRerankDocumentsChain(BaseCombineDocumentsChain): @@ -77,8 +78,9 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): extra = Extra.forbid arbitrary_types_allowed = True - @property - def output_schema(self) -> type[BaseModel]: + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: schema: Dict[str, Any] = { self.output_key: (str, None), } diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index 97528bd8e5..5347b29a61 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -22,6 +22,7 @@ from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.schema import BasePromptTemplate, BaseRetriever, Document from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import BaseMessage +from langchain.schema.runnable.config import RunnableConfig from langchain.schema.vectorstore import VectorStore # Depending on the memory type and configuration, the chat history format may differ. @@ -95,8 +96,9 @@ class BaseConversationalRetrievalChain(Chain): """Input keys.""" return ["question", "chat_history"] - @property - def input_schema(self) -> Type[BaseModel]: + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: return InputType @property diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index 6f976beb70..58d7a38660 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -45,8 +45,9 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): return Union[StringPromptValue, ChatPromptValueConcrete] - @property - def input_schema(self) -> Type[BaseModel]: + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "PromptInput", diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 3e3f73f171..8f50b5a922 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -162,6 +162,12 @@ class Runnable(Generic[Input, Output], ABC): @property def input_schema(self) -> Type[BaseModel]: + """The type of input this runnable accepts specified as a pydantic model.""" + return self.get_input_schema() + + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: """The type of input this runnable accepts specified as a pydantic model.""" root_type = self.InputType @@ -174,6 +180,12 @@ class Runnable(Generic[Input, Output], ABC): @property def output_schema(self) -> Type[BaseModel]: + """The type of output this runnable produces specified as a pydantic model.""" + return self.get_output_schema() + + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: """The type of output this runnable produces specified as a pydantic model.""" root_type = self.OutputType @@ -1044,13 +1056,15 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def OutputType(self) -> Type[Output]: return self.last.OutputType - @property - def input_schema(self) -> Type[BaseModel]: - return self.first.input_schema + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self.first.get_input_schema(config) - @property - def output_schema(self) -> Type[BaseModel]: - return self.last.output_schema + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self.last.get_output_schema(config) @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: @@ -1551,10 +1565,11 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): return Any - @property - def input_schema(self) -> Type[BaseModel]: + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: if all( - s.input_schema.schema().get("type", "object") == "object" + s.get_input_schema(config).schema().get("type", "object") == "object" for s in self.steps.values() ): # This is correct, but pydantic typings/mypy don't think so. @@ -1563,15 +1578,16 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): **{ k: (v.annotation, v.default) for step in self.steps.values() - for k, v in step.input_schema.__fields__.items() + for k, v in step.get_input_schema(config).__fields__.items() if k != "__root__" }, ) - return super().input_schema + return super().get_input_schema(config) - @property - def output_schema(self) -> Type[BaseModel]: + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "RunnableParallelOutput", @@ -2040,8 +2056,9 @@ class RunnableLambda(Runnable[Input, Output]): except ValueError: return Any - @property - def input_schema(self) -> Type[BaseModel]: + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: """The pydantic schema for the input to this runnable.""" func = getattr(self, "func", None) or getattr(self, "afunc") @@ -2066,7 +2083,7 @@ class RunnableLambda(Runnable[Input, Output]): **{key: (Any, None) for key in dict_keys}, # type: ignore ) - return super().input_schema + return super().get_input_schema(config) @property def OutputType(self) -> Any: @@ -2215,12 +2232,13 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): def InputType(self) -> Any: return List[self.bound.InputType] # type: ignore[name-defined] - @property - def input_schema(self) -> Type[BaseModel]: + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: return create_model( "RunnableEachInput", __root__=( - List[self.bound.input_schema], # type: ignore[name-defined] + List[self.bound.get_input_schema(config)], # type: ignore None, ), ) @@ -2229,12 +2247,14 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): def OutputType(self) -> Type[List[Output]]: return List[self.bound.OutputType] # type: ignore[name-defined] - @property - def output_schema(self) -> Type[BaseModel]: + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + schema = self.bound.get_output_schema(config) return create_model( "RunnableEachOutput", __root__=( - List[self.bound.output_schema], # type: ignore[name-defined] + List[schema], # type: ignore None, ), ) @@ -2332,13 +2352,15 @@ class RunnableBinding(RunnableSerializable[Input, Output]): def OutputType(self) -> Type[Output]: return self.bound.OutputType - @property - def input_schema(self) -> Type[BaseModel]: - return self.bound.input_schema + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self.bound.get_input_schema(merge_configs(self.config, config)) - @property - def output_schema(self) -> Type[BaseModel]: - return self.bound.output_schema + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self.bound.get_output_schema(merge_configs(self.config, config)) @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: diff --git a/libs/langchain/langchain/schema/runnable/branch.py b/libs/langchain/langchain/schema/runnable/branch.py index 18e7aa1767..efbdcfcccb 100644 --- a/libs/langchain/langchain/schema/runnable/branch.py +++ b/libs/langchain/langchain/schema/runnable/branch.py @@ -130,8 +130,9 @@ class RunnableBranch(RunnableSerializable[Input, Output]): """The namespace of a RunnableBranch is the namespace of its default branch.""" return cls.__module__.split(".")[:-1] - @property - def input_schema(self) -> Type[BaseModel]: + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: runnables = ( [self.default] + [r for _, r in self.branches] @@ -139,10 +140,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]): ) for runnable in runnables: - if runnable.input_schema.schema().get("type") is not None: - return runnable.input_schema + if runnable.get_input_schema(config).schema().get("type") is not None: + return runnable.get_input_schema(config) - return super().input_schema + return super().get_input_schema(config) @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: diff --git a/libs/langchain/langchain/schema/runnable/configurable.py b/libs/langchain/langchain/schema/runnable/configurable.py index 44102f5cb4..2f1c6f7706 100644 --- a/libs/langchain/langchain/schema/runnable/configurable.py +++ b/libs/langchain/langchain/schema/runnable/configurable.py @@ -60,13 +60,15 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def OutputType(self) -> Type[Output]: return self.default.OutputType - @property - def input_schema(self) -> Type[BaseModel]: - return self.default.input_schema + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self._prepare(config).get_input_schema(config) - @property - def output_schema(self) -> Type[BaseModel]: - return self.default.output_schema + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self._prepare(config).get_output_schema(config) @abstractmethod def _prepare( diff --git a/libs/langchain/langchain/schema/runnable/fallbacks.py b/libs/langchain/langchain/schema/runnable/fallbacks.py index 60ab497b62..4689fa56e7 100644 --- a/libs/langchain/langchain/schema/runnable/fallbacks.py +++ b/libs/langchain/langchain/schema/runnable/fallbacks.py @@ -53,13 +53,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): def OutputType(self) -> Type[Output]: return self.runnable.OutputType - @property - def input_schema(self) -> Type[BaseModel]: - return self.runnable.input_schema + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self.runnable.get_input_schema(config) - @property - def output_schema(self) -> Type[BaseModel]: - return self.runnable.output_schema + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + return self.runnable.get_output_schema(config) @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 6fb59c8e1c..6f369d753c 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -268,19 +268,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def get_lc_namespace(cls) -> List[str]: return cls.__module__.split(".")[:-1] - @property - def input_schema(self) -> Type[BaseModel]: - map_input_schema = self.mapper.input_schema + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + map_input_schema = self.mapper.get_input_schema(config) if not map_input_schema.__custom_root_type__: # ie. it's a dict return map_input_schema - return super().input_schema + return super().get_input_schema(config) - @property - def output_schema(self) -> Type[BaseModel]: - map_input_schema = self.mapper.input_schema - map_output_schema = self.mapper.output_schema + def get_output_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: + map_input_schema = self.mapper.get_input_schema(config) + map_output_schema = self.mapper.get_output_schema(config) if ( not map_input_schema.__custom_root_type__ and not map_output_schema.__custom_root_type__ @@ -295,7 +297,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): }, ) - return super().output_schema + return super().get_output_schema(config) @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 0eac196a94..2d3cc7ac10 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -187,8 +187,9 @@ class ChildTool(BaseTool): # --- Runnable --- - @property - def input_schema(self) -> Type[BaseModel]: + def get_input_schema( + self, config: Optional[RunnableConfig] = None + ) -> Type[BaseModel]: """The tool's input schema.""" if self.args_schema is not None: return self.args_schema diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index a31e234ec6..9b08523395 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -800,6 +800,17 @@ def test_configurable_fields() -> None: text="Hello, John! John!" ) + assert prompt_configurable.with_config( + configurable={"prompt_template": "Hello {name} in {lang}"} + ).input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": { + "lang": {"title": "Lang", "type": "string"}, + "name": {"title": "Name", "type": "string"}, + }, + } + chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser() assert chain_configurable.invoke({"name": "John"}) == "a" @@ -834,13 +845,27 @@ def test_configurable_fields() -> None: assert ( chain_configurable.with_config( configurable={ - "prompt_template": "A very good morning to you, {name}!", + "prompt_template": "A very good morning to you, {name} {lang}!", "llm_responses": ["c"], } - ).invoke({"name": "John"}) + ).invoke({"name": "John", "lang": "en"}) == "c" ) + assert chain_configurable.with_config( + configurable={ + "prompt_template": "A very good morning to you, {name} {lang}!", + "llm_responses": ["c"], + } + ).input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": { + "lang": {"title": "Lang", "type": "string"}, + "name": {"title": "Name", "type": "string"}, + }, + } + chain_with_map_configurable: Runnable = prompt_configurable | { "llm1": fake_llm_configurable | StrOutputParser(), "llm2": fake_llm_configurable | StrOutputParser(),