diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 43af5d31a1..83a83346ab 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -61,17 +61,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): chains and cannot return as rich of an output as `__call__`. """ - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def input_schema(self) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "ChainInput", **{k: (Any, None) for k in self.input_keys} ) - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def output_schema(self) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "ChainOutput", **{k: (Any, None) for k in self.output_keys} diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index ea28e99add..63cf836d46 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -10,7 +10,6 @@ from langchain.callbacks.manager import ( from langchain.chains.base import Chain from langchain.docstore.document import Document from langchain.pydantic_v1 import BaseModel, Field, create_model -from langchain.schema.runnable.config import RunnableConfig from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter @@ -29,17 +28,15 @@ class BaseCombineDocumentsChain(Chain, ABC): input_key: str = "input_documents" #: :meta private: output_key: str = "output_text" #: :meta private: - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def input_schema(self) -> Type[BaseModel]: return create_model( "CombineDocumentsInput", **{self.input_key: (List[Document], None)}, # type: ignore[call-overload] ) - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def output_schema(self) -> Type[BaseModel]: return create_model( "CombineDocumentsOutput", **{self.output_key: (str, None)}, # type: ignore[call-overload] @@ -170,18 +167,16 @@ class AnalyzeDocumentChain(Chain): """ return self.combine_docs_chain.output_keys - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def input_schema(self) -> Type[BaseModel]: return create_model( "AnalyzeDocumentChain", **{self.input_key: (str, None)}, # type: ignore[call-overload] ) - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self.combine_docs_chain.get_output_schema(config) + @property + def output_schema(self) -> Type[BaseModel]: + return self.combine_docs_chain.output_schema def _call( self, diff --git a/libs/langchain/langchain/chains/combine_documents/map_reduce.py b/libs/langchain/langchain/chains/combine_documents/map_reduce.py index fcbb721e1b..f593f66db0 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_reduce.py +++ b/libs/langchain/langchain/chains/combine_documents/map_reduce.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -10,7 +10,6 @@ from langchain.chains.combine_documents.reduce import ReduceDocumentsChain from langchain.chains.llm import LLMChain from langchain.docstore.document import Document from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator -from langchain.schema.runnable.config import RunnableConfig class MapReduceDocumentsChain(BaseCombineDocumentsChain): @@ -99,9 +98,8 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): return_intermediate_steps: bool = False """Return the results of the map steps in the output.""" - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def output_schema(self) -> type[BaseModel]: if self.return_intermediate_steps: return create_model( "MapReduceDocumentsOutput", @@ -111,7 +109,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): }, # type: ignore[call-overload] ) - return super().get_output_schema(config) + return super().output_schema @property def output_keys(self) -> List[str]: diff --git a/libs/langchain/langchain/chains/combine_documents/map_rerank.py b/libs/langchain/langchain/chains/combine_documents/map_rerank.py index 717222ed95..4af56bc6ca 100644 --- a/libs/langchain/langchain/chains/combine_documents/map_rerank.py +++ b/libs/langchain/langchain/chains/combine_documents/map_rerank.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast from langchain.callbacks.manager import Callbacks from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -10,7 +10,6 @@ from langchain.chains.llm import LLMChain from langchain.docstore.document import Document from langchain.output_parsers.regex import RegexParser from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator -from langchain.schema.runnable.config import RunnableConfig class MapRerankDocumentsChain(BaseCombineDocumentsChain): @@ -78,9 +77,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain): extra = Extra.forbid arbitrary_types_allowed = True - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def output_schema(self) -> type[BaseModel]: schema: Dict[str, Any] = { self.output_key: (str, None), } diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index 5347b29a61..97528bd8e5 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -22,7 +22,6 @@ from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator from langchain.schema import BasePromptTemplate, BaseRetriever, Document from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import BaseMessage -from langchain.schema.runnable.config import RunnableConfig from langchain.schema.vectorstore import VectorStore # Depending on the memory type and configuration, the chat history format may differ. @@ -96,9 +95,8 @@ class BaseConversationalRetrievalChain(Chain): """Input keys.""" return ["question", "chat_history"] - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def input_schema(self) -> Type[BaseModel]: return InputType @property diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index 58d7a38660..6f976beb70 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -45,9 +45,8 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): return Union[StringPromptValue, ChatPromptValueConcrete] - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def input_schema(self) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "PromptInput", diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 8f50b5a922..3e3f73f171 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -162,12 +162,6 @@ class Runnable(Generic[Input, Output], ABC): @property def input_schema(self) -> Type[BaseModel]: - """The type of input this runnable accepts specified as a pydantic model.""" - return self.get_input_schema() - - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: """The type of input this runnable accepts specified as a pydantic model.""" root_type = self.InputType @@ -180,12 +174,6 @@ class Runnable(Generic[Input, Output], ABC): @property def output_schema(self) -> Type[BaseModel]: - """The type of output this runnable produces specified as a pydantic model.""" - return self.get_output_schema() - - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: """The type of output this runnable produces specified as a pydantic model.""" root_type = self.OutputType @@ -1056,15 +1044,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def OutputType(self) -> Type[Output]: return self.last.OutputType - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self.first.get_input_schema(config) + @property + def input_schema(self) -> Type[BaseModel]: + return self.first.input_schema - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self.last.get_output_schema(config) + @property + def output_schema(self) -> Type[BaseModel]: + return self.last.output_schema @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: @@ -1565,11 +1551,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): return Any - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def input_schema(self) -> Type[BaseModel]: if all( - s.get_input_schema(config).schema().get("type", "object") == "object" + s.input_schema.schema().get("type", "object") == "object" for s in self.steps.values() ): # This is correct, but pydantic typings/mypy don't think so. @@ -1578,16 +1563,15 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): **{ k: (v.annotation, v.default) for step in self.steps.values() - for k, v in step.get_input_schema(config).__fields__.items() + for k, v in step.input_schema.__fields__.items() if k != "__root__" }, ) - return super().get_input_schema(config) + return super().input_schema - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def output_schema(self) -> Type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] "RunnableParallelOutput", @@ -2056,9 +2040,8 @@ class RunnableLambda(Runnable[Input, Output]): except ValueError: return Any - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def input_schema(self) -> Type[BaseModel]: """The pydantic schema for the input to this runnable.""" func = getattr(self, "func", None) or getattr(self, "afunc") @@ -2083,7 +2066,7 @@ class RunnableLambda(Runnable[Input, Output]): **{key: (Any, None) for key in dict_keys}, # type: ignore ) - return super().get_input_schema(config) + return super().input_schema @property def OutputType(self) -> Any: @@ -2232,13 +2215,12 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): def InputType(self) -> Any: return List[self.bound.InputType] # type: ignore[name-defined] - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def input_schema(self) -> Type[BaseModel]: return create_model( "RunnableEachInput", __root__=( - List[self.bound.get_input_schema(config)], # type: ignore + List[self.bound.input_schema], # type: ignore[name-defined] None, ), ) @@ -2247,14 +2229,12 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): def OutputType(self) -> Type[List[Output]]: return List[self.bound.OutputType] # type: ignore[name-defined] - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - schema = self.bound.get_output_schema(config) + @property + def output_schema(self) -> Type[BaseModel]: return create_model( "RunnableEachOutput", __root__=( - List[schema], # type: ignore + List[self.bound.output_schema], # type: ignore[name-defined] None, ), ) @@ -2352,15 +2332,13 @@ class RunnableBinding(RunnableSerializable[Input, Output]): def OutputType(self) -> Type[Output]: return self.bound.OutputType - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self.bound.get_input_schema(merge_configs(self.config, config)) + @property + def input_schema(self) -> Type[BaseModel]: + return self.bound.input_schema - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self.bound.get_output_schema(merge_configs(self.config, config)) + @property + def output_schema(self) -> Type[BaseModel]: + return self.bound.output_schema @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: diff --git a/libs/langchain/langchain/schema/runnable/branch.py b/libs/langchain/langchain/schema/runnable/branch.py index efbdcfcccb..18e7aa1767 100644 --- a/libs/langchain/langchain/schema/runnable/branch.py +++ b/libs/langchain/langchain/schema/runnable/branch.py @@ -130,9 +130,8 @@ class RunnableBranch(RunnableSerializable[Input, Output]): """The namespace of a RunnableBranch is the namespace of its default branch.""" return cls.__module__.split(".")[:-1] - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def input_schema(self) -> Type[BaseModel]: runnables = ( [self.default] + [r for _, r in self.branches] @@ -140,10 +139,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]): ) for runnable in runnables: - if runnable.get_input_schema(config).schema().get("type") is not None: - return runnable.get_input_schema(config) + if runnable.input_schema.schema().get("type") is not None: + return runnable.input_schema - return super().get_input_schema(config) + return super().input_schema @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: diff --git a/libs/langchain/langchain/schema/runnable/configurable.py b/libs/langchain/langchain/schema/runnable/configurable.py index 2f1c6f7706..44102f5cb4 100644 --- a/libs/langchain/langchain/schema/runnable/configurable.py +++ b/libs/langchain/langchain/schema/runnable/configurable.py @@ -60,15 +60,13 @@ class DynamicRunnable(RunnableSerializable[Input, Output]): def OutputType(self) -> Type[Output]: return self.default.OutputType - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self._prepare(config).get_input_schema(config) + @property + def input_schema(self) -> Type[BaseModel]: + return self.default.input_schema - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self._prepare(config).get_output_schema(config) + @property + def output_schema(self) -> Type[BaseModel]: + return self.default.output_schema @abstractmethod def _prepare( diff --git a/libs/langchain/langchain/schema/runnable/fallbacks.py b/libs/langchain/langchain/schema/runnable/fallbacks.py index 4689fa56e7..60ab497b62 100644 --- a/libs/langchain/langchain/schema/runnable/fallbacks.py +++ b/libs/langchain/langchain/schema/runnable/fallbacks.py @@ -53,15 +53,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): def OutputType(self) -> Type[Output]: return self.runnable.OutputType - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self.runnable.get_input_schema(config) + @property + def input_schema(self) -> Type[BaseModel]: + return self.runnable.input_schema - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - return self.runnable.get_output_schema(config) + @property + def output_schema(self) -> Type[BaseModel]: + return self.runnable.output_schema @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 6f369d753c..6fb59c8e1c 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -268,21 +268,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): def get_lc_namespace(cls) -> List[str]: return cls.__module__.split(".")[:-1] - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - map_input_schema = self.mapper.get_input_schema(config) + @property + def input_schema(self) -> Type[BaseModel]: + map_input_schema = self.mapper.input_schema if not map_input_schema.__custom_root_type__: # ie. it's a dict return map_input_schema - return super().get_input_schema(config) + return super().input_schema - def get_output_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: - map_input_schema = self.mapper.get_input_schema(config) - map_output_schema = self.mapper.get_output_schema(config) + @property + def output_schema(self) -> Type[BaseModel]: + map_input_schema = self.mapper.input_schema + map_output_schema = self.mapper.output_schema if ( not map_input_schema.__custom_root_type__ and not map_output_schema.__custom_root_type__ @@ -297,7 +295,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): }, ) - return super().get_output_schema(config) + return super().output_schema @property def config_specs(self) -> Sequence[ConfigurableFieldSpec]: diff --git a/libs/langchain/langchain/tools/base.py b/libs/langchain/langchain/tools/base.py index 2d3cc7ac10..0eac196a94 100644 --- a/libs/langchain/langchain/tools/base.py +++ b/libs/langchain/langchain/tools/base.py @@ -187,9 +187,8 @@ class ChildTool(BaseTool): # --- Runnable --- - def get_input_schema( - self, config: Optional[RunnableConfig] = None - ) -> Type[BaseModel]: + @property + def input_schema(self) -> Type[BaseModel]: """The tool's input schema.""" if self.args_schema is not None: return self.args_schema diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 9b08523395..a31e234ec6 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -800,17 +800,6 @@ def test_configurable_fields() -> None: text="Hello, John! John!" ) - assert prompt_configurable.with_config( - configurable={"prompt_template": "Hello {name} in {lang}"} - ).input_schema.schema() == { - "title": "PromptInput", - "type": "object", - "properties": { - "lang": {"title": "Lang", "type": "string"}, - "name": {"title": "Name", "type": "string"}, - }, - } - chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser() assert chain_configurable.invoke({"name": "John"}) == "a" @@ -845,27 +834,13 @@ def test_configurable_fields() -> None: assert ( chain_configurable.with_config( configurable={ - "prompt_template": "A very good morning to you, {name} {lang}!", + "prompt_template": "A very good morning to you, {name}!", "llm_responses": ["c"], } - ).invoke({"name": "John", "lang": "en"}) + ).invoke({"name": "John"}) == "c" ) - assert chain_configurable.with_config( - configurable={ - "prompt_template": "A very good morning to you, {name} {lang}!", - "llm_responses": ["c"], - } - ).input_schema.schema() == { - "title": "PromptInput", - "type": "object", - "properties": { - "lang": {"title": "Lang", "type": "string"}, - "name": {"title": "Name", "type": "string"}, - }, - } - chain_with_map_configurable: Runnable = prompt_configurable | { "llm1": fake_llm_configurable | StrOutputParser(), "llm2": fake_llm_configurable | StrOutputParser(),