Add input/output schemas to runnables (#11063)

This adds `input_schema` and `output_schema` properties to all
runnables, which are Pydantic models for the input and output types
respectively. These are inferred from the structure of the Runnable
wherever possible; the only manual typing needed is to
- optionally add type hints to lambdas (which get translated into
input/output schemas)
- optionally add a type hint to RunnablePassthrough

These schemas can then be used to create JSON Schema descriptions of
the input and output types; see the tests, and the short sketch below.
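
For example (a minimal sketch; the printed schemas match the `test_schemas` and `test_lambda_schemas` cases added in this PR):

```python
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

prompt = PromptTemplate.from_template("Hello, {name}!")
print(prompt.input_schema.schema())
# {'title': 'PromptInput', 'type': 'object', 'properties': {'name': {'title': 'Name'}}}

def count_chars(x: str) -> int:  # type hints -> typed input/output schemas
    return len(x)

print(RunnableLambda(count_chars).input_schema.schema())
# {'title': 'RunnableLambdaInput', 'type': 'string'}
print(RunnableLambda(count_chars).output_schema.schema())
# {'title': 'RunnableLambdaOutput', 'type': 'integer'}

# RunnablePassthrough takes an optional type hint:
print(RunnablePassthrough(input_type=str).output_schema.schema())
# {'title': 'RunnablePassthroughOutput', 'type': 'string'}
```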

- [x] Ensure no InputType or OutputType in our classes uses an abstract
base class (replaced with a union of subclasses)
- [x] Implement in BaseChain and LLMChain
- [x] Implement in RunnableBranch
- [x] Implement in RunnableBinding, RunnableMap, RunnablePassthrough,
RunnableEach, RouterRunnable
- [x] Implement in LLM, Prompt, Chat Model, Output Parser, Retriever
- [x] Implement in RunnableLambda from function signature
- [x] Implement in Tool

Commit cfa2203c62 by Nuno Campos (parent b05bb9e136)

@ -7,7 +7,7 @@ import warnings
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Type, Union
import yaml
@ -22,7 +22,13 @@ from langchain.callbacks.manager import (
)
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field, root_validator, validator
from langchain.pydantic_v1 import (
BaseModel,
Field,
create_model,
root_validator,
validator,
)
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.schema.runnable import Runnable, RunnableConfig
@ -56,6 +62,20 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
chains and cannot return as rich of an output as `__call__`.
"""
@property
def input_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainInput", **{k: (Any, None) for k in self.input_keys}
)
@property
def output_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
)
def invoke(
self,
input: Dict[str, Any],
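
For context, `create_model` builds a dynamic pydantic model, and a `(Any, None)` field tuple means "any type, optional". A sketch with hypothetical input keys (a real `Chain` supplies `self.input_keys` here):

```python
from typing import Any
from langchain.pydantic_v1 import create_model

ChainInput = create_model("ChainInput", **{k: (Any, None) for k in ("question", "context")})
print(ChainInput.schema())
# {'title': 'ChainInput', 'type': 'object',
#  'properties': {'question': {'title': 'Question'}, 'context': {'title': 'Context'}}}
```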

@ -1,7 +1,7 @@
"""Base interface for chains combining documents."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.pydantic_v1 import Field
from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@ -28,6 +28,20 @@ class BaseCombineDocumentsChain(Chain, ABC):
input_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private:
@property
def input_schema(self) -> Type[BaseModel]:
return create_model(
"CombineDocumentsInput",
**{self.input_key: (List[Document], None)}, # type: ignore[call-overload]
)
@property
def output_schema(self) -> Type[BaseModel]:
return create_model(
"CombineDocumentsOutput",
**{self.output_key: (str, None)}, # type: ignore[call-overload]
)
@property
def input_keys(self) -> List[str]:
"""Expect input key.
@ -153,6 +167,17 @@ class AnalyzeDocumentChain(Chain):
"""
return self.combine_docs_chain.output_keys
@property
def input_schema(self) -> Type[BaseModel]:
return create_model(
"AnalyzeDocumentChain",
**{self.input_key: (str, None)}, # type: ignore[call-overload]
)
@property
def output_schema(self) -> Type[BaseModel]:
return self.combine_docs_chain.output_schema
def _call(
self,
inputs: Dict[str, str],
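
Unlike the untyped `(Any, None)` fields on the base `Chain`, the typed `input_documents` field above pulls the `Document` model into the JSON Schema's `definitions` (as `test_schema_chains` below asserts). A quick sketch of that effect:

```python
from typing import List
from langchain.docstore.document import Document
from langchain.pydantic_v1 import create_model

M = create_model("CombineDocumentsInput", input_documents=(List[Document], None))
schema = M.schema()
assert schema["properties"]["input_documents"]["items"] == {"$ref": "#/definitions/Document"}
assert "Document" in schema["definitions"]
```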

@ -9,7 +9,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.pydantic_v1 import Extra, root_validator
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@ -98,6 +98,19 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
return_intermediate_steps: bool = False
"""Return the results of the map steps in the output."""
@property
def output_schema(self) -> type[BaseModel]:
if self.return_intermediate_steps:
return create_model(
"MapReduceDocumentsOutput",
**{
self.output_key: (str, None),
"intermediate_steps": (List[str], None),
}, # type: ignore[call-overload]
)
return super().output_schema
@property
def output_keys(self) -> List[str]:
"""Expect input key.

@ -9,7 +9,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.pydantic_v1 import Extra, root_validator
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@ -77,6 +77,18 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
extra = Extra.forbid
arbitrary_types_allowed = True
@property
def output_schema(self) -> type[BaseModel]:
schema: Dict[str, Any] = {
self.output_key: (str, None),
}
if self.return_intermediate_steps:
schema["intermediate_steps"] = (List[str], None)
if self.metadata_keys:
schema.update({key: (Any, None) for key in self.metadata_keys})
return create_model("MapRerankOutput", **schema)
@property
def output_keys(self) -> List[str]:
"""Expect input key.

@ -11,6 +11,7 @@ from typing import (
List,
Optional,
Sequence,
Union,
cast,
)
@ -37,9 +38,14 @@ from langchain.schema import (
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
@ -107,6 +113,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
# --- Runnable methods ---
@property
def OutputType(self) -> Any:
"""Get the input type for this runnable."""
return Union[
HumanMessageChunk,
AIMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
SystemMessageChunk,
]
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input

@ -38,6 +38,8 @@ from langchain.schema.messages import (
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
@ -53,39 +55,6 @@ class ChatLiteLLMException(Exception):
"""Error with the `LiteLLM I/O` library"""
def _truncate_at_stop_tokens(
text: str,
stop: Optional[List[str]],
) -> str:
"""Truncates text at the earliest stop token found."""
if stop is None:
return text
for stop_token in stop:
stop_token_idx = text.find(stop_token)
if stop_token_idx != -1:
text = text[:stop_token_idx]
return text
class FunctionMessage(BaseMessage):
"""Message for passing the result of executing a function back to a model."""
name: str
"""The name of the function that was executed."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "function"
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""Message Chunk for passing the result of executing a function back to a model."""
pass
def _create_retry_decorator(
llm: ChatLiteLLM,
run_manager: Optional[

@ -199,6 +199,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
# --- Runnable methods ---
@property
def OutputType(self) -> Type[str]:
"""Get the input type for this runnable."""
return str
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input

@ -28,6 +28,7 @@ from langchain.schema import (
)
from langchain.schema.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ChatMessage,
HumanMessage,
@ -280,7 +281,7 @@ class ChatPromptValue(PromptValue):
A type of a prompt value that is built from messages.
"""
messages: List[BaseMessage]
messages: Sequence[BaseMessage]
"""List of messages."""
def to_string(self) -> str:
@ -289,7 +290,14 @@ class ChatPromptValue(PromptValue):
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of messages."""
return self.messages
return list(self.messages)
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""
messages: Sequence[AnyMessage]
class BaseChatPromptTemplate(BasePromptTemplate, ABC):

@ -13,8 +13,10 @@ from typing import (
Union,
)
from typing_extensions import TypeAlias
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable
@ -70,6 +72,21 @@ class BaseLanguageModel(
Each of these has an equivalent asynchronous method.
"""
@property
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage],
]
@abstractmethod
def generate_prompt(
self,
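
In practice this means a language model's `input_schema` is an `anyOf` over strings, prompt values, and message lists, while `invoke` still accepts all the usual inputs. A sketch using the fake chat model from the tests (the exact schema shape is snapshotted there, so it is only hinted at here):

```python
from langchain.chat_models.fake import FakeListChatModel
from langchain.schema.messages import HumanMessage

model = FakeListChatModel(responses=["hi"])
# expect an 'anyOf' over the union members, plus 'definitions' for the referenced models
print(model.input_schema.schema().keys())

model.invoke("hello")                        # plain string
model.invoke([HumanMessage(content="hi")])   # list of messages
```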

@ -1,10 +1,11 @@
from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
from typing_extensions import Literal
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field
from langchain.pydantic_v1 import Extra, Field
if TYPE_CHECKING:
from langchain.prompts.chat import ChatPromptTemplate
@ -69,10 +70,10 @@ class BaseMessage(Serializable):
additional_kwargs: dict = Field(default_factory=dict)
"""Any additional information."""
@property
@abstractmethod
def type(self) -> str:
"""Type of the Message, used for serialization."""
type: str
class Config:
extra = Extra.allow
@classmethod
def is_lc_serializable(cls) -> bool:
@ -147,10 +148,10 @@ class HumanMessage(BaseMessage):
conversation.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "human"
type: Literal["human"] = "human"
HumanMessage.update_forward_refs()
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
@ -167,10 +168,10 @@ class AIMessage(BaseMessage):
conversation.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "ai"
type: Literal["ai"] = "ai"
AIMessage.update_forward_refs()
class AIMessageChunk(AIMessage, BaseMessageChunk):
@ -199,10 +200,10 @@ class SystemMessage(BaseMessage):
of input messages.
"""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "system"
type: Literal["system"] = "system"
SystemMessage.update_forward_refs()
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
@ -217,10 +218,10 @@ class FunctionMessage(BaseMessage):
name: str
"""The name of the function that was executed."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "function"
type: Literal["function"] = "function"
FunctionMessage.update_forward_refs()
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
@ -250,10 +251,10 @@ class ChatMessage(BaseMessage):
role: str
"""The speaker / role of the Message."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "chat"
type: Literal["chat"] = "chat"
ChatMessage.update_forward_refs()
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
@ -277,6 +278,9 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
return super().__add__(other)
AnyMessage = Union[AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage]
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
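
Turning `type` from a computed `@property` into a `Literal` field is what makes the `AnyMessage` union usable as a tagged union: pydantic can now tell the variants apart when validating, and the tag appears in schemas. A standalone pydantic v1 sketch of the effect (toy classes, not the real ones):

```python
from typing import Literal, Union
from pydantic import BaseModel  # pydantic v1

class Human(BaseModel):
    content: str
    type: Literal["human"] = "human"

class AI(BaseModel):
    content: str
    type: Literal["ai"] = "ai"

class Holder(BaseModel):
    message: Union[Human, AI]

# Validation picks the right variant from the "type" tag:
h = Holder.parse_obj({"message": {"content": "hi", "type": "ai"}})
assert isinstance(h.message, AI)
# The tag also shows up in each variant's JSON schema, roughly
# {'title': 'Type', 'default': 'human', 'enum': ['human'], 'type': 'string'}
# (exact form depends on the pydantic version); a @property never appears there.
```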

@ -14,8 +14,10 @@ from typing import (
Union,
)
from typing_extensions import get_args
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import AnyMessage, BaseMessage
from langchain.schema.output import ChatGeneration, Generation
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
@ -58,6 +60,16 @@ class BaseGenerationOutputParser(
):
"""Base class to parse the output of an LLM call."""
@property
def InputType(self) -> Any:
return Union[str, AnyMessage]
@property
def OutputType(self) -> type[T]:
# even though mypy complains this isn't valid,
# it is good enough for pydantic to build the schema from
return T # type: ignore[misc]
def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:
@ -129,6 +141,22 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
return "boolean_output_parser"
""" # noqa: E501
@property
def InputType(self) -> Any:
return Union[str, AnyMessage]
@property
def OutputType(self) -> type[T]:
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 1:
return type_args[0]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
"Override the OutputType property to specify the output type."
)
def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:

@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import yaml
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field, root_validator
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue
@ -36,6 +36,20 @@ class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
arbitrary_types_allowed = True
@property
def OutputType(self) -> Any:
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete
return Union[StringPromptValue, ChatPromptValueConcrete]
@property
def input_schema(self) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"PromptInput", **{k: (Any, None) for k in self.input_variables}
)
def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
return self._call_with_config(
lambda inner_input: self.format_prompt(

@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, wait
from functools import partial
from itertools import tee
from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
@ -27,6 +28,8 @@ from typing import (
cast,
)
from typing_extensions import get_args
if TYPE_CHECKING:
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@ -37,7 +40,7 @@ if TYPE_CHECKING:
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field
from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain.schema.runnable.config import (
RunnableConfig,
acall_func_with_variable_args,
@ -55,6 +58,7 @@ from langchain.schema.runnable.utils import (
accepts_config,
accepts_run_manager,
gather_with_concurrency,
get_function_first_arg_dict_keys,
)
from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee
@ -66,6 +70,52 @@ class Runnable(Generic[Input, Output], ABC):
"""A Runnable is a unit of work that can be invoked, batched, streamed, or
transformed."""
@property
def InputType(self) -> Type[Input]:
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 2:
return type_args[0]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable InputType. "
"Override the InputType property to specify the input type."
)
@property
def OutputType(self) -> Type[Output]:
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 2:
return type_args[1]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
"Override the OutputType property to specify the output type."
)
@property
def input_schema(self) -> Type[BaseModel]:
root_type = self.InputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.__class__.__name__ + "Input", __root__=(root_type, None)
)
@property
def output_schema(self) -> Type[BaseModel]:
root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.__class__.__name__ + "Output", __root__=(root_type, None)
)
def __or__(
self,
other: Union[
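
The `__orig_bases__` lookup above is how a subclass's concrete generic parameters are recovered at runtime. A standalone sketch of the mechanism (toy classes, not from the PR):

```python
from typing import Generic, TypeVar, get_args

In, Out = TypeVar("In"), TypeVar("Out")

class MyRunnable(Generic[In, Out]):
    pass

class StrToInt(MyRunnable[str, int]):
    pass

base = StrToInt.__orig_bases__[0]  # MyRunnable[str, int]
print(get_args(base))              # (str, int) -> InputType=str, OutputType=int
```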
@ -849,6 +899,20 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)
for runnable in runnables:
if runnable.input_schema.schema().get("type") is not None:
return runnable.input_schema
return super().input_schema
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
"""First evaluates the condition, then delegate to true or false branch."""
config = ensure_config(config)
@ -953,6 +1017,22 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
class Config:
arbitrary_types_allowed = True
@property
def InputType(self) -> Type[Input]:
return self.runnable.InputType
@property
def OutputType(self) -> Type[Output]:
return self.runnable.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.runnable.input_schema
@property
def output_schema(self) -> Type[BaseModel]:
return self.runnable.output_schema
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@ -1202,6 +1282,22 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
class Config:
arbitrary_types_allowed = True
@property
def InputType(self) -> Type[Input]:
return self.first.InputType
@property
def OutputType(self) -> Type[Output]:
return self.last.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.first.input_schema
@property
def output_schema(self) -> Type[BaseModel]:
return self.last.output_schema
def __or__(
self,
other: Union[
@ -1692,6 +1788,37 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
class Config:
arbitrary_types_allowed = True
@property
def InputType(self) -> Any:
for step in self.steps.values():
if step.InputType:
return step.InputType
return Any
@property
def input_schema(self) -> type[BaseModel]:
if all(not s.input_schema.__custom_root_type__ for s in self.steps.values()):
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableMapInput",
**{
k: (v.type_, v.default)
for step in self.steps.values()
for k, v in step.input_schema.__fields__.items()
},
)
return super().input_schema
@property
def output_schema(self) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableMapOutput",
**{k: (v.OutputType, None) for k, v in self.steps.items()},
)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
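
So a map's input schema merges the named fields of all its steps, and its output schema has one field per step keyed by the step name (exercised in `test_schema_complex_seq` below). A small sketch with illustrative steps:

```python
from langchain.schema.runnable import RunnableLambda, RunnableMap

def length(x: dict) -> int:
    return len(x["text"])

def shout(x: dict) -> str:
    return x["text"].upper()

mapper = RunnableMap(steps={"length": RunnableLambda(length), "shout": RunnableLambda(shout)})
print(mapper.input_schema.schema())
# {'title': 'RunnableMapInput', 'type': 'object', 'properties': {'text': {'title': 'Text'}}}
print(mapper.output_schema.schema()["properties"])
# {'length': {'title': 'Length', 'type': 'integer'}, 'shout': {'title': 'Shout', 'type': 'string'}}
```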
@ -1942,6 +2069,59 @@ class RunnableLambda(Runnable[Input, Output]):
f"Instead got an unsupported type: {type(func)}"
)
@property
def InputType(self) -> Any:
func = getattr(self, "func", None) or getattr(self, "afunc")
try:
params = inspect.signature(func).parameters
first_param = next(iter(params.values()), None)
if first_param and first_param.annotation != inspect.Parameter.empty:
return first_param.annotation
else:
return Any
except ValueError:
return Any
@property
def input_schema(self) -> Type[BaseModel]:
func = getattr(self, "func", None) or getattr(self, "afunc")
if isinstance(func, itemgetter):
# This is terrible, but afaict it's not possible to access _items
# on itemgetter objects, so we have to parse the repr
items = str(func).replace("operator.itemgetter(", "")[:-1].split(", ")
if all(
item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
):
# It's a dict, lol
return create_model(
"RunnableLambdaInput",
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
)
else:
return create_model("RunnableLambdaInput", __root__=(List[Any], None))
if dict_keys := get_function_first_arg_dict_keys(func):
return create_model(
"RunnableLambdaInput",
**{key: (Any, None) for key in dict_keys}, # type: ignore
)
return super().input_schema
@property
def OutputType(self) -> Any:
func = getattr(self, "func", None) or getattr(self, "afunc")
try:
sig = inspect.signature(func)
return (
sig.return_annotation
if sig.return_annotation != inspect.Signature.empty
else Any
)
except ValueError:
return Any
def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableLambda):
if hasattr(self, "func") and hasattr(other, "func"):
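
The `itemgetter` branch above works because `itemgetter`'s repr spells out its arguments (there is no public attribute exposing them). What that parsing sees:

```python
from operator import itemgetter

g = itemgetter("hello", "bye")
print(str(g))  # operator.itemgetter('hello', 'bye')

items = str(g).replace("operator.itemgetter(", "")[:-1].split(", ")
print(items)   # ["'hello'", "'bye'"] -> all quoted strings, so treat the input as a dict

print(str(itemgetter(0, 1)))  # operator.itemgetter(0, 1) -> falls back to List[Any]
```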
@ -2068,6 +2248,34 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
class Config:
arbitrary_types_allowed = True
@property
def InputType(self) -> Any:
return List[self.bound.InputType] # type: ignore[name-defined]
@property
def input_schema(self) -> type[BaseModel]:
return create_model(
"RunnableEachInput",
__root__=(
List[self.bound.input_schema], # type: ignore[name-defined]
None,
),
)
@property
def OutputType(self) -> type[List[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined]
@property
def output_schema(self) -> type[BaseModel]:
return create_model(
"RunnableEachOutput",
__root__=(
List[self.bound.output_schema], # type: ignore[name-defined]
None,
),
)
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@ -2124,6 +2332,22 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
class Config:
arbitrary_types_allowed = True
@property
def InputType(self) -> type[Input]:
return self.bound.InputType
@property
def OutputType(self) -> type[Output]:
return self.bound.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.bound.input_schema
@property
def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema
@classmethod
def is_lc_serializable(cls) -> bool:
return True

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, AsyncIterator, Iterator, List, Optional
from typing import Any, AsyncIterator, Iterator, List, Optional, Type
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Runnable
@ -20,6 +20,8 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
A runnable that passes through the input.
"""
input_type: Optional[Type[Input]] = None
@classmethod
def is_lc_serializable(cls) -> bool:
return True
@ -28,6 +30,14 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
@property
def InputType(self) -> Any:
return self.input_type or Any
@property
def OutputType(self) -> Any:
return self.input_type or Any
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(identity, input, config)

@ -4,16 +4,16 @@ from typing import (
Any,
AsyncIterator,
Callable,
Generic,
Iterator,
List,
Mapping,
Optional,
TypedDict,
Union,
cast,
)
from typing_extensions import TypedDict
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import (
Input,
@ -43,21 +43,17 @@ class RouterInput(TypedDict):
input: Any
class RouterRunnable(
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
"""
A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable.
"""
runnables: Mapping[str, Runnable[Input, Output]]
runnables: Mapping[str, Runnable[Any, Output]]
def __init__(
self,
runnables: Mapping[
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
],
runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
) -> None:
super().__init__(
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}

@ -1,8 +1,11 @@
from __future__ import annotations
import ast
import asyncio
import inspect
import textwrap
from inspect import signature
from typing import Any, Callable, Coroutine, TypeVar, Union
from typing import Any, Callable, Coroutine, List, Optional, Set, TypeVar, Union
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
@ -35,3 +38,61 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
return signature(callable).parameters.get("config") is not None
except ValueError:
return False
class IsLocalDict(ast.NodeVisitor):
def __init__(self, name: str, keys: Set[str]) -> None:
self.name = name
self.keys = keys
def visit_Subscript(self, node: ast.Subscript) -> Any:
if (
isinstance(node.ctx, ast.Load)
and isinstance(node.value, ast.Name)
and node.value.id == self.name
and isinstance(node.slice, ast.Constant)
and isinstance(node.slice.value, str)
):
# we've found a subscript access on the name we're looking for
self.keys.add(node.slice.value)
def visit_Call(self, node: ast.Call) -> Any:
if (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == self.name
and node.func.attr == "get"
and len(node.args) in (1, 2)
and isinstance(node.args[0], ast.Constant)
and isinstance(node.args[0].value, str)
):
# we've found a .get() call on the name we're looking for
self.keys.add(node.args[0].value)
class IsFunctionArgDict(ast.NodeVisitor):
def __init__(self) -> None:
self.keys: Set[str] = set()
def visit_Lambda(self, node: ast.Lambda) -> Any:
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node.body)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
try:
code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code))
visitor = IsFunctionArgDict()
visitor.visit(tree)
return list(visitor.keys) if visitor.keys else None
except (TypeError, OSError):
return None
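
Together these visitors recover the keys a function reads off its first (dict) argument, which is what `RunnableLambda.input_schema` uses above. A usage sketch (mirroring `test_lambda_schemas` below):

```python
from langchain.schema.runnable.utils import get_function_first_arg_dict_keys

def get_value(d):
    return d["variable_name"] + str(d.get("another"))

print(get_function_first_arg_dict_keys(get_value))
# ['variable_name', 'another'] (backed by a set, so order may vary)
```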

@ -187,6 +187,14 @@ class ChildTool(BaseTool):
# --- Runnable ---
@property
def input_schema(self) -> Type[BaseModel]:
"""The tool's input schema."""
if self.args_schema is not None:
return self.args_schema
else:
return create_schema_from_function(self.name, self._run)
def invoke(
self,
input: Union[str, Dict],
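
So a tool's input schema is its `args_schema` when one is set, otherwise a schema inferred from `_run`'s signature. The `JsonListKeysTool` case from `test_schemas` below, for instance:

```python
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec

tool = JsonListKeysTool(spec=JsonSpec(dict_={}))
print(tool.input_schema.schema()["properties"])
# {'tool_input': {'title': 'Tool Input', 'type': 'string'}}
```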

File diff suppressed because one or more lines are too long

@ -1,3 +1,4 @@
import sys
from operator import itemgetter
from typing import Any, Dict, List, Optional, Sequence, Union, cast
from uuid import UUID
@ -12,6 +13,8 @@ from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.callbacks.tracers.schemas import Run
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
from langchain.load.dump import dumpd, dumps
@ -43,6 +46,7 @@ from langchain.schema.runnable import (
RunnableSequence,
RunnableWithFallbacks,
)
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
class FakeTracer(BaseTracer):
@ -115,6 +119,412 @@ class FakeRetriever(BaseRetriever):
return [Document(page_content="foo"), Document(page_content="bar")]
def test_schemas(snapshot: SnapshotAssertion) -> None:
fake = FakeRunnable() # str -> int
assert fake.input_schema.schema() == {
"title": "FakeRunnableInput",
"type": "string",
}
assert fake.output_schema.schema() == {
"title": "FakeRunnableOutput",
"type": "integer",
}
fake_bound = FakeRunnable().bind(a="b") # str -> int
assert fake_bound.input_schema.schema() == {
"title": "FakeRunnableInput",
"type": "string",
}
assert fake_bound.output_schema.schema() == {
"title": "FakeRunnableOutput",
"type": "integer",
}
fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,)) # str -> int
assert fake_w_fallbacks.input_schema.schema() == {
"title": "FakeRunnableInput",
"type": "string",
}
assert fake_w_fallbacks.output_schema.schema() == {
"title": "FakeRunnableOutput",
"type": "integer",
}
def typed_lambda_impl(x: str) -> int:
return len(x)
typed_lambda = RunnableLambda(typed_lambda_impl) # str -> int
assert typed_lambda.input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "string",
}
assert typed_lambda.output_schema.schema() == {
"title": "RunnableLambdaOutput",
"type": "integer",
}
async def typed_async_lambda_impl(x: str) -> int:
return len(x)
typed_async_lambda: Runnable = RunnableLambda(typed_async_lambda_impl) # str -> int
assert typed_async_lambda.input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "string",
}
assert typed_async_lambda.output_schema.schema() == {
"title": "RunnableLambdaOutput",
"type": "integer",
}
fake_ret = FakeRetriever() # str -> List[Document]
assert fake_ret.input_schema.schema() == {
"title": "FakeRetrieverInput",
"type": "string",
}
assert fake_ret.output_schema.schema() == {
"title": "FakeRetrieverOutput",
"type": "array",
"items": {"$ref": "#/definitions/Document"},
"definitions": {
"Document": {
"title": "Document",
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
"type": "object",
"properties": {
"page_content": {"title": "Page Content", "type": "string"},
"metadata": {"title": "Metadata", "type": "object"},
},
"required": ["page_content"],
}
},
}
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
assert fake_llm.input_schema.schema() == snapshot
assert fake_llm.output_schema.schema() == {
"title": "FakeListLLMOutput",
"type": "string",
}
fake_chat = FakeListChatModel(responses=["a"]) # str -> List[List[str]]
assert fake_chat.input_schema.schema() == snapshot
assert fake_chat.output_schema.schema() == snapshot
prompt = PromptTemplate.from_template("Hello, {name}!")
assert prompt.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name"}},
}
assert prompt.output_schema.schema() == snapshot
prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()
assert prompt_mapper.input_schema.schema() == {
"definitions": {
"PromptInput": {
"properties": {"name": {"title": "Name"}},
"title": "PromptInput",
"type": "object",
}
},
"items": {"$ref": "#/definitions/PromptInput"},
"type": "array",
"title": "RunnableEachInput",
}
assert prompt_mapper.output_schema.schema() == snapshot
list_parser = CommaSeparatedListOutputParser()
assert list_parser.input_schema.schema() == snapshot
assert list_parser.output_schema.schema() == {
"title": "CommaSeparatedListOutputParserOutput",
"type": "array",
"items": {"type": "string"},
}
seq = prompt | fake_llm | list_parser
assert seq.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name"}},
}
assert seq.output_schema.schema() == {
"type": "array",
"items": {"type": "string"},
"title": "CommaSeparatedListOutputParserOutput",
}
router: Runnable = RouterRunnable({})
assert router.input_schema.schema() == {
"title": "RouterRunnableInput",
"$ref": "#/definitions/RouterInput",
"definitions": {
"RouterInput": {
"title": "RouterInput",
"type": "object",
"properties": {
"key": {"title": "Key", "type": "string"},
"input": {"title": "Input"},
},
"required": ["key", "input"],
}
},
}
assert router.output_schema.schema() == {"title": "RouterRunnableOutput"}
seq_w_map: Runnable = (
prompt
| fake_llm
| {
"original": RunnablePassthrough(input_type=str),
"as_list": list_parser,
"length": typed_lambda_impl,
}
)
assert seq_w_map.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name"}},
}
assert seq_w_map.output_schema.schema() == {
"title": "RunnableMapOutput",
"type": "object",
"properties": {
"original": {"title": "Original", "type": "string"},
"length": {"title": "Length", "type": "integer"},
"as_list": {
"title": "As List",
"type": "array",
"items": {"type": "string"},
},
},
}
json_list_keys_tool = JsonListKeysTool(spec=JsonSpec(dict_={}))
assert json_list_keys_tool.input_schema.schema() == {
"title": "json_spec_list_keysSchema",
"type": "object",
"properties": {"tool_input": {"title": "Tool Input", "type": "string"}},
"required": ["tool_input"],
}
assert json_list_keys_tool.output_schema.schema() == {
"title": "JsonListKeysToolOutput"
}
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_lambda_schemas() -> None:
first_lambda = lambda x: x["hello"] # noqa: E731
assert RunnableLambda(first_lambda).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}},
}
second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
assert RunnableLambda(
second_lambda, # type: ignore[arg-type]
).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
}
def get_value(input): # type: ignore[no-untyped-def]
return input["variable_name"]
assert RunnableLambda(get_value).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"variable_name": {"title": "Variable Name"}},
}
async def aget_value(input): # type: ignore[no-untyped-def]
return (input["variable_name"], input.get("another"))
assert RunnableLambda(aget_value).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {
"another": {"title": "Another"},
"variable_name": {"title": "Variable Name"},
},
}
async def aget_values(input): # type: ignore[no-untyped-def]
return {
"hello": input["variable_name"],
"bye": input["variable_name"],
"byebye": input["yo"],
}
assert RunnableLambda(aget_values).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {
"variable_name": {"title": "Variable Name"},
"yo": {"title": "Yo"},
},
}
def test_schema_complex_seq() -> None:
prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
prompt2 = ChatPromptTemplate.from_template(
"what country is the city {city} in? respond in {language}"
)
model = FakeListChatModel(responses=[""])
chain1 = prompt1 | model | StrOutputParser()
chain2: Runnable = (
{"city": chain1, "language": itemgetter("language")}
| prompt2
| model
| StrOutputParser()
)
assert chain2.input_schema.schema() == {
"title": "RunnableMapInput",
"type": "object",
"properties": {
"person": {"title": "Person"},
"language": {"title": "Language"},
},
}
assert chain2.output_schema.schema() == {
"title": "StrOutputParserOutput",
"type": "string",
}
def test_schema_chains() -> None:
model = FakeListChatModel(responses=[""])
stuff_chain = load_summarize_chain(model)
assert stuff_chain.input_schema.schema() == {
"title": "CombineDocumentsInput",
"type": "object",
"properties": {
"input_documents": {
"title": "Input Documents",
"type": "array",
"items": {"$ref": "#/definitions/Document"},
}
},
"definitions": {
"Document": {
"title": "Document",
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
"type": "object",
"properties": {
"page_content": {"title": "Page Content", "type": "string"},
"metadata": {"title": "Metadata", "type": "object"},
},
"required": ["page_content"],
}
},
}
assert stuff_chain.output_schema.schema() == {
"title": "CombineDocumentsOutput",
"type": "object",
"properties": {"output_text": {"title": "Output Text", "type": "string"}},
}
mapreduce_chain = load_summarize_chain(
model, "map_reduce", return_intermediate_steps=True
)
assert mapreduce_chain.input_schema.schema() == {
"title": "CombineDocumentsInput",
"type": "object",
"properties": {
"input_documents": {
"title": "Input Documents",
"type": "array",
"items": {"$ref": "#/definitions/Document"},
}
},
"definitions": {
"Document": {
"title": "Document",
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
"type": "object",
"properties": {
"page_content": {"title": "Page Content", "type": "string"},
"metadata": {"title": "Metadata", "type": "object"},
},
"required": ["page_content"],
}
},
}
assert mapreduce_chain.output_schema.schema() == {
"title": "MapReduceDocumentsOutput",
"type": "object",
"properties": {
"output_text": {"title": "Output Text", "type": "string"},
"intermediate_steps": {
"title": "Intermediate Steps",
"type": "array",
"items": {"type": "string"},
},
},
}
maprerank_chain = load_qa_chain(model, "map_rerank", metadata_keys=["hello"])
assert maprerank_chain.input_schema.schema() == {
"title": "CombineDocumentsInput",
"type": "object",
"properties": {
"input_documents": {
"title": "Input Documents",
"type": "array",
"items": {"$ref": "#/definitions/Document"},
}
},
"definitions": {
"Document": {
"title": "Document",
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
"type": "object",
"properties": {
"page_content": {"title": "Page Content", "type": "string"},
"metadata": {"title": "Metadata", "type": "object"},
},
"required": ["page_content"],
}
},
}
assert maprerank_chain.output_schema.schema() == {
"title": "MapRerankOutput",
"type": "object",
"properties": {
"output_text": {"title": "Output Text", "type": "string"},
"hello": {"title": "Hello"},
},
}
@pytest.mark.asyncio
async def test_with_config(mocker: MockerFixture) -> None:
fake = FakeRunnable()
@ -2160,6 +2570,7 @@ def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
assert isinstance(body, Runnable)
assert isinstance(runnable.default, Runnable)
assert runnable.input_schema.schema() == {"title": "RunnableBranchInput"}
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
