Add input/output schemas to runnables (#11063)

This adds `input_schema` and `output_schema` properties to all
runnables, which are Pydantic models for the input and output types
respectively. These are inferred from the structure of the Runnable as
much as possible, the only manual typing needed is
- optionally add type hints to lambdas (which get translated to
input/output schemas)
- optionally add type hint to RunnablePassthrough

These schemas can then be used to create JSON Schema descriptions of
input and output types, see the tests

- [x] Ensure no InputType and OutputType in our classes use abstract
base classes (replace with union of subclasses)
- [x] Implement in BaseChain and LLMChain
- [x] Implement in RunnableBranch
- [x] Implement in RunnableBinding, RunnableMap, RunnablePassthrough,
RunnableEach, RunnableRouter
- [x] Implement in LLM, Prompt, Chat Model, Output Parser, Retriever
- [x] Implement in RunnableLambda from function signature
- [x] Implement in Tool

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-09-28 11:05:15 +01:00 committed by GitHub
parent b05bb9e136
commit cfa2203c62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 2211 additions and 86 deletions

View File

@ -7,7 +7,7 @@ import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Type, Union
import yaml import yaml
@ -22,7 +22,13 @@ from langchain.callbacks.manager import (
) )
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field, root_validator, validator from langchain.pydantic_v1 import (
BaseModel,
Field,
create_model,
root_validator,
validator,
)
from langchain.schema import RUN_KEY, BaseMemory, RunInfo from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.schema.runnable import Runnable, RunnableConfig from langchain.schema.runnable import Runnable, RunnableConfig
@ -56,6 +62,20 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
chains and cannot return as rich of an output as `__call__`. chains and cannot return as rich of an output as `__call__`.
""" """
@property
def input_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainInput", **{k: (Any, None) for k in self.input_keys}
)
@property
def output_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
)
def invoke( def invoke(
self, self,
input: Dict[str, Any], input: Dict[str, Any],

View File

@ -1,7 +1,7 @@
"""Base interface for chains combining documents.""" """Base interface for chains combining documents."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Type
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
) )
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
@ -28,6 +28,20 @@ class BaseCombineDocumentsChain(Chain, ABC):
input_key: str = "input_documents" #: :meta private: input_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private: output_key: str = "output_text" #: :meta private:
@property
def input_schema(self) -> Type[BaseModel]:
return create_model(
"CombineDocumentsInput",
**{self.input_key: (List[Document], None)}, # type: ignore[call-overload]
)
@property
def output_schema(self) -> Type[BaseModel]:
return create_model(
"CombineDocumentsOutput",
**{self.output_key: (str, None)}, # type: ignore[call-overload]
)
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
"""Expect input key. """Expect input key.
@ -153,6 +167,17 @@ class AnalyzeDocumentChain(Chain):
""" """
return self.combine_docs_chain.output_keys return self.combine_docs_chain.output_keys
@property
def input_schema(self) -> Type[BaseModel]:
return create_model(
"AnalyzeDocumentChain",
**{self.input_key: (str, None)}, # type: ignore[call-overload]
)
@property
def output_schema(self) -> Type[BaseModel]:
return self.combine_docs_chain.output_schema
def _call( def _call(
self, self,
inputs: Dict[str, str], inputs: Dict[str, str],

View File

@ -9,7 +9,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.pydantic_v1 import Extra, root_validator from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
class MapReduceDocumentsChain(BaseCombineDocumentsChain): class MapReduceDocumentsChain(BaseCombineDocumentsChain):
@ -98,6 +98,19 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
return_intermediate_steps: bool = False return_intermediate_steps: bool = False
"""Return the results of the map steps in the output.""" """Return the results of the map steps in the output."""
@property
def output_schema(self) -> type[BaseModel]:
if self.return_intermediate_steps:
return create_model(
"MapReduceDocumentsOutput",
**{
self.output_key: (str, None),
"intermediate_steps": (List[str], None),
}, # type: ignore[call-overload]
)
return super().output_schema
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> List[str]:
"""Expect input key. """Expect input key.

View File

@ -9,7 +9,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser from langchain.output_parsers.regex import RegexParser
from langchain.pydantic_v1 import Extra, root_validator from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
class MapRerankDocumentsChain(BaseCombineDocumentsChain): class MapRerankDocumentsChain(BaseCombineDocumentsChain):
@ -77,6 +77,18 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
extra = Extra.forbid extra = Extra.forbid
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def output_schema(self) -> type[BaseModel]:
schema: Dict[str, Any] = {
self.output_key: (str, None),
}
if self.return_intermediate_steps:
schema["intermediate_steps"] = (List[str], None)
if self.metadata_keys:
schema.update({key: (Any, None) for key in self.metadata_keys})
return create_model("MapRerankOutput", **schema)
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> List[str]:
"""Expect input key. """Expect input key.

View File

@ -11,6 +11,7 @@ from typing import (
List, List,
Optional, Optional,
Sequence, Sequence,
Union,
cast, cast,
) )
@ -37,9 +38,14 @@ from langchain.schema import (
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
AIMessageChunk,
BaseMessage, BaseMessage,
BaseMessageChunk, BaseMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessage, HumanMessage,
HumanMessageChunk,
SystemMessageChunk,
) )
from langchain.schema.output import ChatGenerationChunk from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig from langchain.schema.runnable import RunnableConfig
@ -107,6 +113,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
# --- Runnable methods --- # --- Runnable methods ---
@property
def OutputType(self) -> Any:
"""Get the input type for this runnable."""
return Union[
HumanMessageChunk,
AIMessageChunk,
ChatMessageChunk,
FunctionMessageChunk,
SystemMessageChunk,
]
def _convert_input(self, input: LanguageModelInput) -> PromptValue: def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue): if isinstance(input, PromptValue):
return input return input

View File

@ -38,6 +38,8 @@ from langchain.schema.messages import (
BaseMessageChunk, BaseMessageChunk,
ChatMessage, ChatMessage,
ChatMessageChunk, ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage, HumanMessage,
HumanMessageChunk, HumanMessageChunk,
SystemMessage, SystemMessage,
@ -53,39 +55,6 @@ class ChatLiteLLMException(Exception):
"""Error with the `LiteLLM I/O` library""" """Error with the `LiteLLM I/O` library"""
def _truncate_at_stop_tokens(
text: str,
stop: Optional[List[str]],
) -> str:
"""Truncates text at the earliest stop token found."""
if stop is None:
return text
for stop_token in stop:
stop_token_idx = text.find(stop_token)
if stop_token_idx != -1:
text = text[:stop_token_idx]
return text
class FunctionMessage(BaseMessage):
"""Message for passing the result of executing a function back to a model."""
name: str
"""The name of the function that was executed."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "function"
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""Message Chunk for passing the result of executing a function back to a model."""
pass
def _create_retry_decorator( def _create_retry_decorator(
llm: ChatLiteLLM, llm: ChatLiteLLM,
run_manager: Optional[ run_manager: Optional[

View File

@ -199,6 +199,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
# --- Runnable methods --- # --- Runnable methods ---
@property
def OutputType(self) -> Type[str]:
"""Get the input type for this runnable."""
return str
def _convert_input(self, input: LanguageModelInput) -> PromptValue: def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue): if isinstance(input, PromptValue):
return input return input

View File

@ -28,6 +28,7 @@ from langchain.schema import (
) )
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
AnyMessage,
BaseMessage, BaseMessage,
ChatMessage, ChatMessage,
HumanMessage, HumanMessage,
@ -280,7 +281,7 @@ class ChatPromptValue(PromptValue):
A type of a prompt value that is built from messages. A type of a prompt value that is built from messages.
""" """
messages: List[BaseMessage] messages: Sequence[BaseMessage]
"""List of messages.""" """List of messages."""
def to_string(self) -> str: def to_string(self) -> str:
@ -289,7 +290,14 @@ class ChatPromptValue(PromptValue):
def to_messages(self) -> List[BaseMessage]: def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of messages.""" """Return prompt as a list of messages."""
return self.messages return list(self.messages)
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""
messages: Sequence[AnyMessage]
class BaseChatPromptTemplate(BasePromptTemplate, ABC): class BaseChatPromptTemplate(BasePromptTemplate, ABC):

View File

@ -13,8 +13,10 @@ from typing import (
Union, Union,
) )
from typing_extensions import TypeAlias
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage, get_buffer_string from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable from langchain.schema.runnable import Runnable
@ -70,6 +72,21 @@ class BaseLanguageModel(
Each of these has an equivalent asynchronous method. Each of these has an equivalent asynchronous method.
""" """
@property
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return Union[
str,
Union[StringPromptValue, ChatPromptValueConcrete],
List[AnyMessage],
]
@abstractmethod @abstractmethod
def generate_prompt( def generate_prompt(
self, self,

View File

@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
from abc import abstractmethod from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
from typing_extensions import Literal
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Extra, Field
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts.chat import ChatPromptTemplate
@ -69,10 +70,10 @@ class BaseMessage(Serializable):
additional_kwargs: dict = Field(default_factory=dict) additional_kwargs: dict = Field(default_factory=dict)
"""Any additional information.""" """Any additional information."""
@property type: str
@abstractmethod
def type(self) -> str: class Config:
"""Type of the Message, used for serialization.""" extra = Extra.allow
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
@ -147,10 +148,10 @@ class HumanMessage(BaseMessage):
conversation. conversation.
""" """
@property type: Literal["human"] = "human"
def type(self) -> str:
"""Type of the message, used for serialization."""
return "human" HumanMessage.update_forward_refs()
class HumanMessageChunk(HumanMessage, BaseMessageChunk): class HumanMessageChunk(HumanMessage, BaseMessageChunk):
@ -167,10 +168,10 @@ class AIMessage(BaseMessage):
conversation. conversation.
""" """
@property type: Literal["ai"] = "ai"
def type(self) -> str:
"""Type of the message, used for serialization."""
return "ai" AIMessage.update_forward_refs()
class AIMessageChunk(AIMessage, BaseMessageChunk): class AIMessageChunk(AIMessage, BaseMessageChunk):
@ -199,10 +200,10 @@ class SystemMessage(BaseMessage):
of input messages. of input messages.
""" """
@property type: Literal["system"] = "system"
def type(self) -> str:
"""Type of the message, used for serialization."""
return "system" SystemMessage.update_forward_refs()
class SystemMessageChunk(SystemMessage, BaseMessageChunk): class SystemMessageChunk(SystemMessage, BaseMessageChunk):
@ -217,10 +218,10 @@ class FunctionMessage(BaseMessage):
name: str name: str
"""The name of the function that was executed.""" """The name of the function that was executed."""
@property type: Literal["function"] = "function"
def type(self) -> str:
"""Type of the message, used for serialization."""
return "function" FunctionMessage.update_forward_refs()
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk): class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
@ -250,10 +251,10 @@ class ChatMessage(BaseMessage):
role: str role: str
"""The speaker / role of the Message.""" """The speaker / role of the Message."""
@property type: Literal["chat"] = "chat"
def type(self) -> str:
"""Type of the message, used for serialization."""
return "chat" ChatMessage.update_forward_refs()
class ChatMessageChunk(ChatMessage, BaseMessageChunk): class ChatMessageChunk(ChatMessage, BaseMessageChunk):
@ -277,6 +278,9 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
return super().__add__(other) return super().__add__(other)
AnyMessage = Union[AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage]
def _message_to_dict(message: BaseMessage) -> dict: def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()} return {"type": message.type, "data": message.dict()}

View File

@ -14,8 +14,10 @@ from typing import (
Union, Union,
) )
from typing_extensions import get_args
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage from langchain.schema.messages import AnyMessage, BaseMessage
from langchain.schema.output import ChatGeneration, Generation from langchain.schema.output import ChatGeneration, Generation
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig from langchain.schema.runnable import Runnable, RunnableConfig
@ -58,6 +60,16 @@ class BaseGenerationOutputParser(
): ):
"""Base class to parse the output of an LLM call.""" """Base class to parse the output of an LLM call."""
@property
def InputType(self) -> Any:
return Union[str, AnyMessage]
@property
def OutputType(self) -> type[T]:
# even though mypy complains this isn't valid,
# it is good enough for pydantic to build the schema from
return T # type: ignore[misc]
def invoke( def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T: ) -> T:
@ -129,6 +141,22 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
return "boolean_output_parser" return "boolean_output_parser"
""" # noqa: E501 """ # noqa: E501
@property
def InputType(self) -> Any:
return Union[str, AnyMessage]
@property
def OutputType(self) -> type[T]:
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 1:
return type_args[0]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
"Override the OutputType property to specify the output type."
)
def invoke( def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T: ) -> T:

View File

@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import yaml import yaml
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field, root_validator from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain.schema.document import Document from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue from langchain.schema.prompt import PromptValue
@ -36,6 +36,20 @@ class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def OutputType(self) -> Any:
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValueConcrete
return Union[StringPromptValue, ChatPromptValueConcrete]
@property
def input_schema(self) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"PromptInput", **{k: (Any, None) for k in self.input_variables}
)
def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue: def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
return self._call_with_config( return self._call_with_config(
lambda inner_input: self.format_prompt( lambda inner_input: self.format_prompt(

View File

@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
from concurrent.futures import FIRST_COMPLETED, wait from concurrent.futures import FIRST_COMPLETED, wait
from functools import partial from functools import partial
from itertools import tee from itertools import tee
from operator import itemgetter
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -27,6 +28,8 @@ from typing import (
cast, cast,
) )
from typing_extensions import get_args
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
@ -37,7 +40,7 @@ if TYPE_CHECKING:
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain.schema.runnable.config import ( from langchain.schema.runnable.config import (
RunnableConfig, RunnableConfig,
acall_func_with_variable_args, acall_func_with_variable_args,
@ -55,6 +58,7 @@ from langchain.schema.runnable.utils import (
accepts_config, accepts_config,
accepts_run_manager, accepts_run_manager,
gather_with_concurrency, gather_with_concurrency,
get_function_first_arg_dict_keys,
) )
from langchain.utils.aiter import atee, py_anext from langchain.utils.aiter import atee, py_anext
from langchain.utils.iter import safetee from langchain.utils.iter import safetee
@ -66,6 +70,52 @@ class Runnable(Generic[Input, Output], ABC):
"""A Runnable is a unit of work that can be invoked, batched, streamed, or """A Runnable is a unit of work that can be invoked, batched, streamed, or
transformed.""" transformed."""
@property
def InputType(self) -> Type[Input]:
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 2:
return type_args[0]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable InputType. "
"Override the InputType property to specify the input type."
)
@property
def OutputType(self) -> Type[Output]:
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
type_args = get_args(cls)
if type_args and len(type_args) == 2:
return type_args[1]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
"Override the OutputType property to specify the output type."
)
@property
def input_schema(self) -> Type[BaseModel]:
root_type = self.InputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.__class__.__name__ + "Input", __root__=(root_type, None)
)
@property
def output_schema(self) -> Type[BaseModel]:
root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.__class__.__name__ + "Output", __root__=(root_type, None)
)
def __or__( def __or__(
self, self,
other: Union[ other: Union[
@ -849,6 +899,20 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
"""The namespace of a RunnableBranch is the namespace of its default branch.""" """The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1] return cls.__module__.split(".")[:-1]
@property
def input_schema(self) -> type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)
for runnable in runnables:
if runnable.input_schema.schema().get("type") is not None:
return runnable.input_schema
return super().input_schema
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
"""First evaluates the condition, then delegate to true or false branch.""" """First evaluates the condition, then delegate to true or false branch."""
config = ensure_config(config) config = ensure_config(config)
@ -953,6 +1017,22 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def InputType(self) -> Type[Input]:
return self.runnable.InputType
@property
def OutputType(self) -> Type[Output]:
return self.runnable.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.runnable.input_schema
@property
def output_schema(self) -> Type[BaseModel]:
return self.runnable.output_schema
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
@ -1202,6 +1282,22 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def InputType(self) -> Type[Input]:
return self.first.InputType
@property
def OutputType(self) -> Type[Output]:
return self.last.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.first.input_schema
@property
def output_schema(self) -> Type[BaseModel]:
return self.last.output_schema
def __or__( def __or__(
self, self,
other: Union[ other: Union[
@ -1692,6 +1788,37 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def InputType(self) -> Any:
for step in self.steps.values():
if step.InputType:
return step.InputType
return Any
@property
def input_schema(self) -> type[BaseModel]:
if all(not s.input_schema.__custom_root_type__ for s in self.steps.values()):
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableMapInput",
**{
k: (v.type_, v.default)
for step in self.steps.values()
for k, v in step.input_schema.__fields__.items()
},
)
return super().input_schema
@property
def output_schema(self) -> type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableMapOutput",
**{k: (v.OutputType, None) for k, v in self.steps.items()},
)
def invoke( def invoke(
self, input: Input, config: Optional[RunnableConfig] = None self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
@ -1942,6 +2069,59 @@ class RunnableLambda(Runnable[Input, Output]):
f"Instead got an unsupported type: {type(func)}" f"Instead got an unsupported type: {type(func)}"
) )
@property
def InputType(self) -> Any:
func = getattr(self, "func", None) or getattr(self, "afunc")
try:
params = inspect.signature(func).parameters
first_param = next(iter(params.values()), None)
if first_param and first_param.annotation != inspect.Parameter.empty:
return first_param.annotation
else:
return Any
except ValueError:
return Any
@property
def input_schema(self) -> Type[BaseModel]:
func = getattr(self, "func", None) or getattr(self, "afunc")
if isinstance(func, itemgetter):
# This is terrible, but afaict it's not possible to access _items
# on itemgetter objects, so we have to parse the repr
items = str(func).replace("operator.itemgetter(", "")[:-1].split(", ")
if all(
item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
):
# It's a dict, lol
return create_model(
"RunnableLambdaInput",
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
)
else:
return create_model("RunnableLambdaInput", __root__=(List[Any], None))
if dict_keys := get_function_first_arg_dict_keys(func):
return create_model(
"RunnableLambdaInput",
**{key: (Any, None) for key in dict_keys}, # type: ignore
)
return super().input_schema
@property
def OutputType(self) -> Any:
func = getattr(self, "func", None) or getattr(self, "afunc")
try:
sig = inspect.signature(func)
return (
sig.return_annotation
if sig.return_annotation != inspect.Signature.empty
else Any
)
except ValueError:
return Any
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableLambda): if isinstance(other, RunnableLambda):
if hasattr(self, "func") and hasattr(other, "func"): if hasattr(self, "func") and hasattr(other, "func"):
@ -2068,6 +2248,34 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def InputType(self) -> Any:
return List[self.bound.InputType] # type: ignore[name-defined]
@property
def input_schema(self) -> type[BaseModel]:
return create_model(
"RunnableEachInput",
__root__=(
List[self.bound.input_schema], # type: ignore[name-defined]
None,
),
)
@property
def OutputType(self) -> type[List[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined]
@property
def output_schema(self) -> type[BaseModel]:
return create_model(
"RunnableEachOutput",
__root__=(
List[self.bound.output_schema], # type: ignore[name-defined]
None,
),
)
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
@ -2124,6 +2332,22 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@property
def InputType(self) -> type[Input]:
return self.bound.InputType
@property
def OutputType(self) -> type[Output]:
return self.bound.OutputType
@property
def input_schema(self) -> Type[BaseModel]:
return self.bound.input_schema
@property
def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, AsyncIterator, Iterator, List, Optional from typing import Any, AsyncIterator, Iterator, List, Optional, Type
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Runnable from langchain.schema.runnable.base import Input, Runnable
@ -20,6 +20,8 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
A runnable that passes through the input. A runnable that passes through the input.
""" """
input_type: Optional[Type[Input]] = None
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
@ -28,6 +30,14 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] return cls.__module__.split(".")[:-1]
@property
def InputType(self) -> Any:
return self.input_type or Any
@property
def OutputType(self) -> Any:
return self.input_type or Any
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input: def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(identity, input, config) return self._call_with_config(identity, input, config)

View File

@ -4,16 +4,16 @@ from typing import (
Any, Any,
AsyncIterator, AsyncIterator,
Callable, Callable,
Generic,
Iterator, Iterator,
List, List,
Mapping, Mapping,
Optional, Optional,
TypedDict,
Union, Union,
cast, cast,
) )
from typing_extensions import TypedDict
from langchain.load.serializable import Serializable from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import ( from langchain.schema.runnable.base import (
Input, Input,
@ -43,21 +43,17 @@ class RouterInput(TypedDict):
input: Any input: Any
class RouterRunnable( class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
):
""" """
A runnable that routes to a set of runnables based on Input['key']. A runnable that routes to a set of runnables based on Input['key'].
Returns the output of the selected runnable. Returns the output of the selected runnable.
""" """
runnables: Mapping[str, Runnable[Input, Output]] runnables: Mapping[str, Runnable[Any, Output]]
def __init__( def __init__(
self, self,
runnables: Mapping[ runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
],
) -> None: ) -> None:
super().__init__( super().__init__(
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()} runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}

View File

@ -1,8 +1,11 @@
from __future__ import annotations from __future__ import annotations
import ast
import asyncio import asyncio
import inspect
import textwrap
from inspect import signature from inspect import signature
from typing import Any, Callable, Coroutine, TypeVar, Union from typing import Any, Callable, Coroutine, List, Optional, Set, TypeVar, Union
Input = TypeVar("Input") Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do # Output type should implement __concat__, as eg str, list, dict do
@ -35,3 +38,61 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
return signature(callable).parameters.get("config") is not None return signature(callable).parameters.get("config") is not None
except ValueError: except ValueError:
return False return False
class IsLocalDict(ast.NodeVisitor):
def __init__(self, name: str, keys: Set[str]) -> None:
self.name = name
self.keys = keys
def visit_Subscript(self, node: ast.Subscript) -> Any:
if (
isinstance(node.ctx, ast.Load)
and isinstance(node.value, ast.Name)
and node.value.id == self.name
and isinstance(node.slice, ast.Constant)
and isinstance(node.slice.value, str)
):
# we've found a subscript access on the name we're looking for
self.keys.add(node.slice.value)
def visit_Call(self, node: ast.Call) -> Any:
if (
isinstance(node.func, ast.Attribute)
and isinstance(node.func.value, ast.Name)
and node.func.value.id == self.name
and node.func.attr == "get"
and len(node.args) in (1, 2)
and isinstance(node.args[0], ast.Constant)
and isinstance(node.args[0].value, str)
):
# we've found a .get() call on the name we're looking for
self.keys.add(node.args[0].value)
class IsFunctionArgDict(ast.NodeVisitor):
def __init__(self) -> None:
self.keys: Set[str] = set()
def visit_Lambda(self, node: ast.Lambda) -> Any:
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node.body)
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node)
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
input_arg_name = node.args.args[0].arg
IsLocalDict(input_arg_name, self.keys).visit(node)
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
try:
code = inspect.getsource(func)
tree = ast.parse(textwrap.dedent(code))
visitor = IsFunctionArgDict()
visitor.visit(tree)
return list(visitor.keys) if visitor.keys else None
except (TypeError, OSError):
return None

View File

@ -187,6 +187,14 @@ class ChildTool(BaseTool):
# --- Runnable --- # --- Runnable ---
@property
def input_schema(self) -> Type[BaseModel]:
"""The tool's input schema."""
if self.args_schema is not None:
return self.args_schema
else:
return create_schema_from_function(self.name, self._run)
def invoke( def invoke(
self, self,
input: Union[str, Dict], input: Union[str, Dict],

File diff suppressed because one or more lines are too long

View File

@ -1,3 +1,4 @@
import sys
from operator import itemgetter from operator import itemgetter
from typing import Any, Dict, List, Optional, Sequence, Union, cast from typing import Any, Dict, List, Optional, Sequence, Union, cast
from uuid import UUID from uuid import UUID
@ -12,6 +13,8 @@ from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.callbacks.tracers.schemas import Run from langchain.callbacks.tracers.schemas import Run
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models.fake import FakeListChatModel from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
from langchain.load.dump import dumpd, dumps from langchain.load.dump import dumpd, dumps
@ -43,6 +46,7 @@ from langchain.schema.runnable import (
RunnableSequence, RunnableSequence,
RunnableWithFallbacks, RunnableWithFallbacks,
) )
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
class FakeTracer(BaseTracer): class FakeTracer(BaseTracer):
@ -115,6 +119,412 @@ class FakeRetriever(BaseRetriever):
return [Document(page_content="foo"), Document(page_content="bar")] return [Document(page_content="foo"), Document(page_content="bar")]
def test_schemas(snapshot: SnapshotAssertion) -> None:
fake = FakeRunnable() # str -> int
assert fake.input_schema.schema() == {
"title": "FakeRunnableInput",
"type": "string",
}
assert fake.output_schema.schema() == {
"title": "FakeRunnableOutput",
"type": "integer",
}
fake_bound = FakeRunnable().bind(a="b") # str -> int
assert fake_bound.input_schema.schema() == {
"title": "FakeRunnableInput",
"type": "string",
}
assert fake_bound.output_schema.schema() == {
"title": "FakeRunnableOutput",
"type": "integer",
}
fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,)) # str -> int
assert fake_w_fallbacks.input_schema.schema() == {
"title": "FakeRunnableInput",
"type": "string",
}
assert fake_w_fallbacks.output_schema.schema() == {
"title": "FakeRunnableOutput",
"type": "integer",
}
def typed_lambda_impl(x: str) -> int:
return len(x)
typed_lambda = RunnableLambda(typed_lambda_impl) # str -> int
assert typed_lambda.input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "string",
}
assert typed_lambda.output_schema.schema() == {
"title": "RunnableLambdaOutput",
"type": "integer",
}
async def typed_async_lambda_impl(x: str) -> int:
return len(x)
typed_async_lambda: Runnable = RunnableLambda(typed_async_lambda_impl) # str -> int
assert typed_async_lambda.input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "string",
}
assert typed_async_lambda.output_schema.schema() == {
"title": "RunnableLambdaOutput",
"type": "integer",
}
fake_ret = FakeRetriever() # str -> List[Document]
assert fake_ret.input_schema.schema() == {
"title": "FakeRetrieverInput",
"type": "string",
}
assert fake_ret.output_schema.schema() == {
"title": "FakeRetrieverOutput",
"type": "array",
"items": {"$ref": "#/definitions/Document"},
"definitions": {
"Document": {
"title": "Document",
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
"type": "object",
"properties": {
"page_content": {"title": "Page Content", "type": "string"},
"metadata": {"title": "Metadata", "type": "object"},
},
"required": ["page_content"],
}
},
}
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
assert fake_llm.input_schema.schema() == snapshot
assert fake_llm.output_schema.schema() == {
"title": "FakeListLLMOutput",
"type": "string",
}
fake_chat = FakeListChatModel(responses=["a"]) # str -> List[List[str]]
assert fake_chat.input_schema.schema() == snapshot
assert fake_chat.output_schema.schema() == snapshot
prompt = PromptTemplate.from_template("Hello, {name}!")
assert prompt.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name"}},
}
assert prompt.output_schema.schema() == snapshot
prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()
assert prompt_mapper.input_schema.schema() == {
"definitions": {
"PromptInput": {
"properties": {"name": {"title": "Name"}},
"title": "PromptInput",
"type": "object",
}
},
"items": {"$ref": "#/definitions/PromptInput"},
"type": "array",
"title": "RunnableEachInput",
}
assert prompt_mapper.output_schema.schema() == snapshot
list_parser = CommaSeparatedListOutputParser()
assert list_parser.input_schema.schema() == snapshot
assert list_parser.output_schema.schema() == {
"title": "CommaSeparatedListOutputParserOutput",
"type": "array",
"items": {"type": "string"},
}
seq = prompt | fake_llm | list_parser
assert seq.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name"}},
}
assert seq.output_schema.schema() == {
"type": "array",
"items": {"type": "string"},
"title": "CommaSeparatedListOutputParserOutput",
}
router: Runnable = RouterRunnable({})
assert router.input_schema.schema() == {
"title": "RouterRunnableInput",
"$ref": "#/definitions/RouterInput",
"definitions": {
"RouterInput": {
"title": "RouterInput",
"type": "object",
"properties": {
"key": {"title": "Key", "type": "string"},
"input": {"title": "Input"},
},
"required": ["key", "input"],
}
},
}
assert router.output_schema.schema() == {"title": "RouterRunnableOutput"}
seq_w_map: Runnable = (
prompt
| fake_llm
| {
"original": RunnablePassthrough(input_type=str),
"as_list": list_parser,
"length": typed_lambda_impl,
}
)
assert seq_w_map.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"name": {"title": "Name"}},
}
assert seq_w_map.output_schema.schema() == {
"title": "RunnableMapOutput",
"type": "object",
"properties": {
"original": {"title": "Original", "type": "string"},
"length": {"title": "Length", "type": "integer"},
"as_list": {
"title": "As List",
"type": "array",
"items": {"type": "string"},
},
},
}
json_list_keys_tool = JsonListKeysTool(spec=JsonSpec(dict_={}))
assert json_list_keys_tool.input_schema.schema() == {
"title": "json_spec_list_keysSchema",
"type": "object",
"properties": {"tool_input": {"title": "Tool Input", "type": "string"}},
"required": ["tool_input"],
}
assert json_list_keys_tool.output_schema.schema() == {
"title": "JsonListKeysToolOutput"
}
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_lambda_schemas() -> None:
first_lambda = lambda x: x["hello"] # noqa: E731
assert RunnableLambda(first_lambda).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}},
}
second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
assert RunnableLambda(
second_lambda, # type: ignore[arg-type]
).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
}
def get_value(input): # type: ignore[no-untyped-def]
return input["variable_name"]
assert RunnableLambda(get_value).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {"variable_name": {"title": "Variable Name"}},
}
async def aget_value(input): # type: ignore[no-untyped-def]
return (input["variable_name"], input.get("another"))
assert RunnableLambda(aget_value).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {
"another": {"title": "Another"},
"variable_name": {"title": "Variable Name"},
},
}
async def aget_values(input): # type: ignore[no-untyped-def]
return {
"hello": input["variable_name"],
"bye": input["variable_name"],
"byebye": input["yo"],
}
assert RunnableLambda(aget_values).input_schema.schema() == {
"title": "RunnableLambdaInput",
"type": "object",
"properties": {
"variable_name": {"title": "Variable Name"},
"yo": {"title": "Yo"},
},
}
def test_schema_complex_seq() -> None:
prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
prompt2 = ChatPromptTemplate.from_template(
"what country is the city {city} in? respond in {language}"
)
model = FakeListChatModel(responses=[""])
chain1 = prompt1 | model | StrOutputParser()
chain2: Runnable = (
{"city": chain1, "language": itemgetter("language")}
| prompt2
| model
| StrOutputParser()
)
assert chain2.input_schema.schema() == {
"title": "RunnableMapInput",
"type": "object",
"properties": {
"person": {"title": "Person"},
"language": {"title": "Language"},
},
}
assert chain2.output_schema.schema() == {
"title": "StrOutputParserOutput",
"type": "string",
}
def test_schema_chains() -> None:
model = FakeListChatModel(responses=[""])
stuff_chain = load_summarize_chain(model)
assert stuff_chain.input_schema.schema() == {
"title": "CombineDocumentsInput",
"type": "object",
"properties": {
"input_documents": {
"title": "Input Documents",
"type": "array",
"items": {"$ref": "#/definitions/Document"},
}
},
"definitions": {
"Document": {
"title": "Document",
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
"type": "object",
"properties": {
"page_content": {"title": "Page Content", "type": "string"},
"metadata": {"title": "Metadata", "type": "object"},
},
"required": ["page_content"],
}
},
}
assert stuff_chain.output_schema.schema() == {
"title": "CombineDocumentsOutput",
"type": "object",
"properties": {"output_text": {"title": "Output Text", "type": "string"}},
}
mapreduce_chain = load_summarize_chain(
model, "map_reduce", return_intermediate_steps=True
)
assert mapreduce_chain.input_schema.schema() == {
"title": "CombineDocumentsInput",
"type": "object",
"properties": {
"input_documents": {
"title": "Input Documents",
"type": "array",
"items": {"$ref": "#/definitions/Document"},
}
},
"definitions": {
"Document": {
"title": "Document",
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
"type": "object",
"properties": {
"page_content": {"title": "Page Content", "type": "string"},
"metadata": {"title": "Metadata", "type": "object"},
},
"required": ["page_content"],
}
},
}
assert mapreduce_chain.output_schema.schema() == {
"title": "MapReduceDocumentsOutput",
"type": "object",
"properties": {
"output_text": {"title": "Output Text", "type": "string"},
"intermediate_steps": {
"title": "Intermediate Steps",
"type": "array",
"items": {"type": "string"},
},
},
}
maprerank_chain = load_qa_chain(model, "map_rerank", metadata_keys=["hello"])
assert maprerank_chain.input_schema.schema() == {
"title": "CombineDocumentsInput",
"type": "object",
"properties": {
"input_documents": {
"title": "Input Documents",
"type": "array",
"items": {"$ref": "#/definitions/Document"},
}
},
"definitions": {
"Document": {
"title": "Document",
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
"type": "object",
"properties": {
"page_content": {"title": "Page Content", "type": "string"},
"metadata": {"title": "Metadata", "type": "object"},
},
"required": ["page_content"],
}
},
}
assert maprerank_chain.output_schema.schema() == {
"title": "MapRerankOutput",
"type": "object",
"properties": {
"output_text": {"title": "Output Text", "type": "string"},
"hello": {"title": "Hello"},
},
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_with_config(mocker: MockerFixture) -> None: async def test_with_config(mocker: MockerFixture) -> None:
fake = FakeRunnable() fake = FakeRunnable()
@ -2160,6 +2570,7 @@ def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
assert isinstance(body, Runnable) assert isinstance(body, Runnable)
assert isinstance(runnable.default, Runnable) assert isinstance(runnable.default, Runnable)
assert runnable.input_schema.schema() == {"title": "RunnableBranchInput"}
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None: def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None: