nc/runnable-dynamic-schemas-from-config

pull/12037/head
Nuno Campos 12 months ago
parent d392e030be
commit a46eef64a7

@ -61,15 +61,17 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
chains and cannot return as rich of an output as `__call__`.
"""
@property
def input_schema(self) -> Type[BaseModel]:
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainInput", **{k: (Any, None) for k in self.input_keys}
)
@property
def output_schema(self) -> Type[BaseModel]:
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainOutput", **{k: (Any, None) for k in self.output_keys}

@ -10,6 +10,7 @@ from langchain.callbacks.manager import (
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain.schema.runnable.config import RunnableConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@ -28,15 +29,17 @@ class BaseCombineDocumentsChain(Chain, ABC):
input_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private:
@property
def input_schema(self) -> Type[BaseModel]:
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return create_model(
"CombineDocumentsInput",
**{self.input_key: (List[Document], None)}, # type: ignore[call-overload]
)
@property
def output_schema(self) -> Type[BaseModel]:
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return create_model(
"CombineDocumentsOutput",
**{self.output_key: (str, None)}, # type: ignore[call-overload]
@ -167,16 +170,18 @@ class AnalyzeDocumentChain(Chain):
"""
return self.combine_docs_chain.output_keys
@property
def input_schema(self) -> Type[BaseModel]:
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return create_model(
"AnalyzeDocumentChain",
**{self.input_key: (str, None)}, # type: ignore[call-overload]
)
@property
def output_schema(self) -> Type[BaseModel]:
return self.combine_docs_chain.output_schema
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.combine_docs_chain.get_output_schema(config)
def _call(
self,

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -10,6 +10,7 @@ from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
from langchain.schema.runnable.config import RunnableConfig
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@ -98,8 +99,9 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
return_intermediate_steps: bool = False
"""Return the results of the map steps in the output."""
@property
def output_schema(self) -> type[BaseModel]:
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
if self.return_intermediate_steps:
return create_model(
"MapReduceDocumentsOutput",
@ -109,7 +111,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
}, # type: ignore[call-overload]
)
return super().output_schema
return super().get_output_schema(config)
@property
def output_keys(self) -> List[str]:

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -10,6 +10,7 @@ from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
from langchain.schema.runnable.config import RunnableConfig
class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@ -77,8 +78,9 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def output_schema(self) -> type[BaseModel]:
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
schema: Dict[str, Any] = {
self.output_key: (str, None),
}

@ -22,6 +22,7 @@ from langchain.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.vectorstore import VectorStore
# Depending on the memory type and configuration, the chat history format may differ.
@ -95,8 +96,9 @@ class BaseConversationalRetrievalChain(Chain):
"""Input keys."""
return ["question", "chat_history"]
@property
def input_schema(self) -> Type[BaseModel]:
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return InputType
@property

@ -45,8 +45,9 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
return Union[StringPromptValue, ChatPromptValueConcrete]
@property
def input_schema(self) -> Type[BaseModel]:
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"PromptInput",

@ -162,6 +162,12 @@ class Runnable(Generic[Input, Output], ABC):
@property
def input_schema(self) -> Type[BaseModel]:
"""The type of input this runnable accepts specified as a pydantic model."""
return self.get_input_schema()
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""The type of input this runnable accepts specified as a pydantic model."""
root_type = self.InputType
@ -174,6 +180,12 @@ class Runnable(Generic[Input, Output], ABC):
@property
def output_schema(self) -> Type[BaseModel]:
"""The type of output this runnable produces specified as a pydantic model."""
return self.get_output_schema()
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""The type of output this runnable produces specified as a pydantic model."""
root_type = self.OutputType
@ -1044,13 +1056,15 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def OutputType(self) -> Type[Output]:
return self.last.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.first.input_schema
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.first.get_input_schema(config)
@property
def output_schema(self) -> Type[BaseModel]:
return self.last.output_schema
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.last.get_output_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
@ -1551,10 +1565,11 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
return Any
@property
def input_schema(self) -> Type[BaseModel]:
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
if all(
s.input_schema.schema().get("type", "object") == "object"
s.get_input_schema(config).schema().get("type", "object") == "object"
for s in self.steps.values()
):
# This is correct, but pydantic typings/mypy don't think so.
@ -1563,15 +1578,16 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
**{
k: (v.annotation, v.default)
for step in self.steps.values()
for k, v in step.input_schema.__fields__.items()
for k, v in step.get_input_schema(config).__fields__.items()
if k != "__root__"
},
)
return super().input_schema
return super().get_input_schema(config)
@property
def output_schema(self) -> Type[BaseModel]:
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableParallelOutput",
@ -2040,8 +2056,9 @@ class RunnableLambda(Runnable[Input, Output]):
except ValueError:
return Any
@property
def input_schema(self) -> Type[BaseModel]:
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""The pydantic schema for the input to this runnable."""
func = getattr(self, "func", None) or getattr(self, "afunc")
@ -2066,7 +2083,7 @@ class RunnableLambda(Runnable[Input, Output]):
**{key: (Any, None) for key in dict_keys}, # type: ignore
)
return super().input_schema
return super().get_input_schema(config)
@property
def OutputType(self) -> Any:
@ -2215,12 +2232,13 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
def InputType(self) -> Any:
return List[self.bound.InputType] # type: ignore[name-defined]
@property
def input_schema(self) -> Type[BaseModel]:
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return create_model(
"RunnableEachInput",
__root__=(
List[self.bound.input_schema], # type: ignore[name-defined]
List[self.bound.get_input_schema(config)], # type: ignore
None,
),
)
@ -2229,12 +2247,14 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
def OutputType(self) -> Type[List[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined]
@property
def output_schema(self) -> Type[BaseModel]:
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
schema = self.bound.get_output_schema(config)
return create_model(
"RunnableEachOutput",
__root__=(
List[self.bound.output_schema], # type: ignore[name-defined]
List[schema], # type: ignore
None,
),
)
@ -2332,13 +2352,15 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
def OutputType(self) -> Type[Output]:
return self.bound.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.bound.input_schema
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.bound.get_input_schema(merge_configs(self.config, config))
@property
def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.bound.get_output_schema(merge_configs(self.config, config))
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:

@ -130,8 +130,9 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> Type[BaseModel]:
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
@ -139,10 +140,10 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
)
for runnable in runnables:
if runnable.input_schema.schema().get("type") is not None:
return runnable.input_schema
if runnable.get_input_schema(config).schema().get("type") is not None:
return runnable.get_input_schema(config)
return super().input_schema
return super().get_input_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:

@ -60,13 +60,15 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
def OutputType(self) -> Type[Output]:
return self.default.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.default.input_schema
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self._prepare(config).get_input_schema(config)
@property
def output_schema(self) -> Type[BaseModel]:
return self.default.output_schema
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self._prepare(config).get_output_schema(config)
@abstractmethod
def _prepare(

@ -53,13 +53,15 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
def OutputType(self) -> Type[Output]:
return self.runnable.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.runnable.input_schema
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.runnable.get_input_schema(config)
@property
def output_schema(self) -> Type[BaseModel]:
return self.runnable.output_schema
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.runnable.get_output_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:

@ -268,19 +268,21 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> Type[BaseModel]:
map_input_schema = self.mapper.input_schema
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
if not map_input_schema.__custom_root_type__:
# ie. it's a dict
return map_input_schema
return super().input_schema
return super().get_input_schema(config)
@property
def output_schema(self) -> Type[BaseModel]:
map_input_schema = self.mapper.input_schema
map_output_schema = self.mapper.output_schema
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
map_input_schema = self.mapper.get_input_schema(config)
map_output_schema = self.mapper.get_output_schema(config)
if (
not map_input_schema.__custom_root_type__
and not map_output_schema.__custom_root_type__
@ -295,7 +297,7 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
},
)
return super().output_schema
return super().get_output_schema(config)
@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:

@ -187,8 +187,9 @@ class ChildTool(BaseTool):
# --- Runnable ---
@property
def input_schema(self) -> Type[BaseModel]:
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""The tool's input schema."""
if self.args_schema is not None:
return self.args_schema

@ -800,6 +800,17 @@ def test_configurable_fields() -> None:
text="Hello, John! John!"
)
assert prompt_configurable.with_config(
configurable={"prompt_template": "Hello {name} in {lang}"}
).input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {
"lang": {"title": "Lang", "type": "string"},
"name": {"title": "Name", "type": "string"},
},
}
chain_configurable = prompt_configurable | fake_llm_configurable | StrOutputParser()
assert chain_configurable.invoke({"name": "John"}) == "a"
@ -834,13 +845,27 @@ def test_configurable_fields() -> None:
assert (
chain_configurable.with_config(
configurable={
"prompt_template": "A very good morning to you, {name}!",
"prompt_template": "A very good morning to you, {name} {lang}!",
"llm_responses": ["c"],
}
).invoke({"name": "John"})
).invoke({"name": "John", "lang": "en"})
== "c"
)
assert chain_configurable.with_config(
configurable={
"prompt_template": "A very good morning to you, {name} {lang}!",
"llm_responses": ["c"],
}
).input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {
"lang": {"title": "Lang", "type": "string"},
"name": {"title": "Name", "type": "string"},
},
}
chain_with_map_configurable: Runnable = prompt_configurable | {
"llm1": fake_llm_configurable | StrOutputParser(),
"llm2": fake_llm_configurable | StrOutputParser(),

Loading…
Cancel
Save