Revert "nc/runnable-dynamic-schemas-from-config" (#12037)

This reverts commit a46eef64a7.
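
For context, the reverted commit had replaced the `input_schema`/`output_schema` properties on `Runnable` (and its subclasses) with config-aware `get_input_schema(config)`/`get_output_schema(config)` methods, so that a runnable's reported schema could depend on the `RunnableConfig` passed at invocation time. This revert restores the property-based API throughout. A minimal sketch of the two shapes, using hypothetical stand-in classes (plain pydantic v1, which `langchain.pydantic_v1` re-exports):

```python
from typing import Any, Optional, Type

from pydantic import BaseModel, create_model


class ConfigAware:
    # The shape removed by this revert: the schema may vary with config.
    def get_input_schema(self, config: Optional[dict] = None) -> Type[BaseModel]:
        return create_model("MyInput", question=(Any, None))


class PropertyBased:
    # The shape restored by this revert: one static schema per instance.
    @property
    def input_schema(self) -> Type[BaseModel]:
        return create_model("MyInput", question=(Any, None))
```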

Nuno Campos · commit 85eaa4ccee · parent a46eef64a7 · branch pull/12038/head

```diff
@@ -61,17 +61,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
         chains and cannot return as rich of an output as `__call__`.
     """

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def input_schema(self) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "ChainInput", **{k: (Any, None) for k in self.input_keys}
         )

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def output_schema(self) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "ChainOutput", **{k: (Any, None) for k in self.output_keys}
```

```diff
@@ -10,7 +10,6 @@ from langchain.callbacks.manager import (
 from langchain.chains.base import Chain
 from langchain.docstore.document import Document
 from langchain.pydantic_v1 import BaseModel, Field, create_model
-from langchain.schema.runnable.config import RunnableConfig
 from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

@@ -29,17 +28,15 @@ class BaseCombineDocumentsChain(Chain, ABC):
     input_key: str = "input_documents"  #: :meta private:
     output_key: str = "output_text"  #: :meta private:

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def input_schema(self) -> Type[BaseModel]:
         return create_model(
             "CombineDocumentsInput",
             **{self.input_key: (List[Document], None)},  # type: ignore[call-overload]
         )

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def output_schema(self) -> Type[BaseModel]:
         return create_model(
             "CombineDocumentsOutput",
             **{self.output_key: (str, None)},  # type: ignore[call-overload]

@@ -170,18 +167,16 @@ class AnalyzeDocumentChain(Chain):
         """
         return self.combine_docs_chain.output_keys

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def input_schema(self) -> Type[BaseModel]:
         return create_model(
             "AnalyzeDocumentChain",
             **{self.input_key: (str, None)},  # type: ignore[call-overload]
         )

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        return self.combine_docs_chain.get_output_schema(config)
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return self.combine_docs_chain.output_schema

     def _call(
         self,
```

```diff
@@ -2,7 +2,7 @@
 from __future__ import annotations

-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple

 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -10,7 +10,6 @@ from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
-from langchain.schema.runnable.config import RunnableConfig


 class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@@ -99,9 +98,8 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     return_intermediate_steps: bool = False
     """Return the results of the map steps in the output."""

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def output_schema(self) -> type[BaseModel]:
         if self.return_intermediate_steps:
             return create_model(
                 "MapReduceDocumentsOutput",
@@ -111,7 +109,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
                 },  # type: ignore[call-overload]
             )
-        return super().get_output_schema(config)
+        return super().output_schema

     @property
     def output_keys(self) -> List[str]:
```

```diff
@@ -2,7 +2,7 @@
 from __future__ import annotations

-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

 from langchain.callbacks.manager import Callbacks
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@@ -10,7 +10,6 @@ from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document
 from langchain.output_parsers.regex import RegexParser
 from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
-from langchain.schema.runnable.config import RunnableConfig


 class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@@ -78,9 +77,8 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
         extra = Extra.forbid
         arbitrary_types_allowed = True

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def output_schema(self) -> type[BaseModel]:
         schema: Dict[str, Any] = {
             self.output_key: (str, None),
         }
```

```diff
@@ -22,7 +22,6 @@ from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
 from langchain.schema import BasePromptTemplate, BaseRetriever, Document
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.messages import BaseMessage
-from langchain.schema.runnable.config import RunnableConfig
 from langchain.schema.vectorstore import VectorStore

 # Depending on the memory type and configuration, the chat history format may differ.
@@ -96,9 +95,8 @@ class BaseConversationalRetrievalChain(Chain):
         """Input keys."""
         return ["question", "chat_history"]

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def input_schema(self) -> Type[BaseModel]:
         return InputType

     @property
```

```diff
@@ -45,9 +45,8 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
         return Union[StringPromptValue, ChatPromptValueConcrete]

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def input_schema(self) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "PromptInput",
```

```diff
@@ -162,12 +162,6 @@ class Runnable(Generic[Input, Output], ABC):
     @property
     def input_schema(self) -> Type[BaseModel]:
-        """The type of input this runnable accepts specified as a pydantic model."""
-        return self.get_input_schema()
-
-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
         """The type of input this runnable accepts specified as a pydantic model."""
         root_type = self.InputType
@@ -180,12 +174,6 @@ class Runnable(Generic[Input, Output], ABC):
     @property
     def output_schema(self) -> Type[BaseModel]:
-        """The type of output this runnable produces specified as a pydantic model."""
-        return self.get_output_schema()
-
-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
         """The type of output this runnable produces specified as a pydantic model."""
         root_type = self.OutputType
```
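
In the restored base implementation, the property derives its model from `InputType`/`OutputType` (the `root_type` seen in the context lines), wrapping non-model types in a pydantic custom root. A sketch of that derivation; the helper name is mine, and the model-passthrough behavior is an assumption based on the `__custom_root_type__` checks later in this diff:

```python
from typing import Type

from pydantic import BaseModel, create_model


def schema_from_root_type(name: str, root_type: type) -> Type[BaseModel]:
    # Assumption: pydantic models are used as-is; everything else becomes
    # a custom-root-type model (pydantic v1 __root__ field).
    if isinstance(root_type, type) and issubclass(root_type, BaseModel):
        return root_type
    return create_model(name, __root__=(root_type, None))


print(schema_from_root_type("RunnableInput", str).schema())
# {'title': 'RunnableInput', 'type': 'string'}
```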
```diff
@@ -1056,15 +1044,13 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.last.OutputType

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        return self.first.get_input_schema(config)
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return self.first.input_schema

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        return self.last.get_output_schema(config)
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return self.last.output_schema

     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
```
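
Concretely, a composed sequence now reports its first step's input schema and its last step's output schema, with no config in the picture. A hedged usage sketch against the LangChain of this commit's era (import paths as used elsewhere in this PR's test file); the expected dict mirrors the `PromptInput` schema asserted in the tests below:

```python
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser

chain = PromptTemplate.from_template("Hello {name}!") | StrOutputParser()

# RunnableSequence.input_schema == chain.first.input_schema (the prompt's):
print(chain.input_schema.schema())
# {'title': 'PromptInput', 'type': 'object',
#  'properties': {'name': {'title': 'Name', 'type': 'string'}}}
```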
```diff
@@ -1565,11 +1551,10 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
         return Any

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def input_schema(self) -> Type[BaseModel]:
         if all(
-            s.get_input_schema(config).schema().get("type", "object") == "object"
+            s.input_schema.schema().get("type", "object") == "object"
             for s in self.steps.values()
         ):
             # This is correct, but pydantic typings/mypy don't think so.
@@ -1578,16 +1563,15 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
                 **{
                     k: (v.annotation, v.default)
                     for step in self.steps.values()
-                    for k, v in step.get_input_schema(config).__fields__.items()
+                    for k, v in step.input_schema.__fields__.items()
                     if k != "__root__"
                 },
             )
-        return super().get_input_schema(config)
+        return super().input_schema

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def output_schema(self) -> Type[BaseModel]:
         # This is correct, but pydantic typings/mypy don't think so.
         return create_model(  # type: ignore[call-overload]
             "RunnableParallelOutput",
```
```diff
@@ -2056,9 +2040,8 @@ class RunnableLambda(Runnable[Input, Output]):
         except ValueError:
             return Any

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def input_schema(self) -> Type[BaseModel]:
         """The pydantic schema for the input to this runnable."""
         func = getattr(self, "func", None) or getattr(self, "afunc")
@@ -2083,7 +2066,7 @@ class RunnableLambda(Runnable[Input, Output]):
                 **{key: (Any, None) for key in dict_keys},  # type: ignore
             )

-        return super().get_input_schema(config)
+        return super().input_schema

     @property
     def OutputType(self) -> Any:
```
```diff
@@ -2232,13 +2215,12 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
     def InputType(self) -> Any:
         return List[self.bound.InputType]  # type: ignore[name-defined]

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def input_schema(self) -> Type[BaseModel]:
         return create_model(
             "RunnableEachInput",
             __root__=(
-                List[self.bound.get_input_schema(config)],  # type: ignore
+                List[self.bound.input_schema],  # type: ignore[name-defined]
                 None,
             ),
         )
@@ -2247,14 +2229,12 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
     def OutputType(self) -> Type[List[Output]]:
         return List[self.bound.OutputType]  # type: ignore[name-defined]

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        schema = self.bound.get_output_schema(config)
+    @property
+    def output_schema(self) -> Type[BaseModel]:
         return create_model(
             "RunnableEachOutput",
             __root__=(
-                List[schema],  # type: ignore
+                List[self.bound.output_schema],  # type: ignore[name-defined]
                 None,
             ),
         )
```
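
`RunnableEach` lifts the bound runnable's schema into a list through a pydantic custom root type. A minimal sketch with a hypothetical element model (pydantic v1):

```python
from typing import List

from pydantic import create_model

Element = create_model("BoundInput", x=(str, None))  # hypothetical bound schema
RunnableEachInput = create_model("RunnableEachInput", __root__=(List[Element], None))

print(RunnableEachInput.schema()["type"])  # 'array'
```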
```diff
@@ -2352,15 +2332,13 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.bound.OutputType

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        return self.bound.get_input_schema(merge_configs(self.config, config))
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return self.bound.input_schema

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        return self.bound.get_output_schema(merge_configs(self.config, config))
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return self.bound.output_schema

     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
```

```diff
@@ -130,9 +130,8 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
         """The namespace of a RunnableBranch is the namespace of its default branch."""
         return cls.__module__.split(".")[:-1]

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def input_schema(self) -> Type[BaseModel]:
         runnables = (
             [self.default]
             + [r for _, r in self.branches]
@@ -140,10 +139,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
         )

         for runnable in runnables:
-            if runnable.get_input_schema(config).schema().get("type") is not None:
-                return runnable.get_input_schema(config)
+            if runnable.input_schema.schema().get("type") is not None:
+                return runnable.input_schema

-        return super().get_input_schema(config)
+        return super().input_schema

     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
```
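
The branch walks its default-plus-branch runnables and returns the first input schema whose JSON schema carries a concrete `type`, falling back to the base property otherwise. A sketch of just that selection step (hypothetical schemas; `None` stands in for the `super().input_schema` fallback):

```python
from typing import Any, Optional, Type

from pydantic import BaseModel, create_model

Untyped = create_model("UntypedInput", __root__=(Any, None))  # schema has no "type"
Typed = create_model("TypedInput", question=(str, None))      # "type": "object"


def first_typed_schema(schemas) -> Optional[Type[BaseModel]]:
    for schema in schemas:
        if schema.schema().get("type") is not None:
            return schema
    return None  # a real RunnableBranch would return super().input_schema


print(first_typed_schema([Untyped, Typed]).__name__)  # TypedInput
```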

```diff
@@ -60,15 +60,13 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.default.OutputType

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        return self._prepare(config).get_input_schema(config)
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return self.default.input_schema

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        return self._prepare(config).get_output_schema(config)
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return self.default.output_schema

     @abstractmethod
     def _prepare(
```
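
This hunk is the behavioral crux of the revert: a configurable runnable now always reports its default's schema instead of preparing a config-specific variant via `self._prepare(config)`. A toy illustration of the difference, with hypothetical stand-ins for the default runnable:

```python
from pydantic import create_model


class DefaultRunnable:
    # Hypothetical default: takes only "name".
    input_schema = create_model("DefaultInput", name=(str, None))


class DynamicToy:
    def __init__(self, default: DefaultRunnable) -> None:
        self.default = default

    @property
    def input_schema(self):
        # Post-revert: any config that would swap in a different template
        # (adding, say, a "lang" field) is no longer reflected here.
        return self.default.input_schema


print(DynamicToy(DefaultRunnable()).input_schema.__name__)  # DefaultInput
```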

```diff
@@ -53,15 +53,13 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
     def OutputType(self) -> Type[Output]:
         return self.runnable.OutputType

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        return self.runnable.get_input_schema(config)
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        return self.runnable.input_schema

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        return self.runnable.get_output_schema(config)
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        return self.runnable.output_schema

     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
```

```diff
@@ -268,21 +268,19 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
     def get_lc_namespace(cls) -> List[str]:
         return cls.__module__.split(".")[:-1]

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        map_input_schema = self.mapper.get_input_schema(config)
+    @property
+    def input_schema(self) -> Type[BaseModel]:
+        map_input_schema = self.mapper.input_schema
         if not map_input_schema.__custom_root_type__:
             # ie. it's a dict
             return map_input_schema

-        return super().get_input_schema(config)
+        return super().input_schema

-    def get_output_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
-        map_input_schema = self.mapper.get_input_schema(config)
-        map_output_schema = self.mapper.get_output_schema(config)
+    @property
+    def output_schema(self) -> Type[BaseModel]:
+        map_input_schema = self.mapper.input_schema
+        map_output_schema = self.mapper.output_schema
         if (
             not map_input_schema.__custom_root_type__
             and not map_output_schema.__custom_root_type__
@@ -297,7 +295,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
             },
         )

-        return super().get_output_schema(config)
+        return super().output_schema

     @property
     def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
```

```diff
@@ -187,9 +187,8 @@ class ChildTool(BaseTool):
     # --- Runnable ---

-    def get_input_schema(
-        self, config: Optional[RunnableConfig] = None
-    ) -> Type[BaseModel]:
+    @property
+    def input_schema(self) -> Type[BaseModel]:
         """The tool's input schema."""
         if self.args_schema is not None:
             return self.args_schema
```

```diff
@@ -800,17 +800,6 @@ def test_configurable_fields() -> None:
         text="Hello, John! John!"
     )

-    assert prompt_configurable.with_config(
-        configurable={"prompt_template": "Hello {name} in {lang}"}
-    ).input_schema.schema() == {
-        "title": "PromptInput",
-        "type": "object",
-        "properties": {
-            "lang": {"title": "Lang", "type": "string"},
-            "name": {"title": "Name", "type": "string"},
-        },
-    }
-
     chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()

     assert chain_configurable.invoke({"name": "John"}) == "a"
@@ -845,27 +834,13 @@ def test_configurable_fields() -> None:
     assert (
         chain_configurable.with_config(
             configurable={
-                "prompt_template": "A very good morning to you, {name} {lang}!",
+                "prompt_template": "A very good morning to you, {name}!",
                 "llm_responses": ["c"],
             }
-        ).invoke({"name": "John", "lang": "en"})
+        ).invoke({"name": "John"})
         == "c"
     )

-    assert chain_configurable.with_config(
-        configurable={
-            "prompt_template": "A very good morning to you, {name} {lang}!",
-            "llm_responses": ["c"],
-        }
-    ).input_schema.schema() == {
-        "title": "PromptInput",
-        "type": "object",
-        "properties": {
-            "lang": {"title": "Lang", "type": "string"},
-            "name": {"title": "Name", "type": "string"},
-        },
-    }
-
     chain_with_map_configurable: Runnable = prompt_configurable | {
         "llm1": fake_llm_configurable | StrOutputParser(),
         "llm2": fake_llm_configurable | StrOutputParser(),
```
