mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Add input/output schemas to runnables (#11063)
This adds `input_schema` and `output_schema` properties to all runnables, which are Pydantic models for the input and output types respectively. These are inferred from the structure of the Runnable as much as possible, the only manual typing needed is - optionally add type hints to lambdas (which get translated to input/output schemas) - optionally add type hint to RunnablePassthrough These schemas can then be used to create JSON Schema descriptions of input and output types, see the tests - [x] Ensure no InputType and OutputType in our classes use abstract base classes (replace with union of subclasses) - [x] Implement in BaseChain and LLMChain - [x] Implement in RunnableBranch - [x] Implement in RunnableBinding, RunnableMap, RunnablePassthrough, RunnableEach, RunnableRouter - [x] Implement in LLM, Prompt, Chat Model, Output Parser, Retriever - [x] Implement in RunnableLambda from function signature - [x] Implement in Tool <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
b05bb9e136
commit
cfa2203c62
@ -7,7 +7,7 @@ import warnings
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -22,7 +22,13 @@ from langchain.callbacks.manager import (
|
|||||||
)
|
)
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.pydantic_v1 import Field, root_validator, validator
|
from langchain.pydantic_v1 import (
|
||||||
|
BaseModel,
|
||||||
|
Field,
|
||||||
|
create_model,
|
||||||
|
root_validator,
|
||||||
|
validator,
|
||||||
|
)
|
||||||
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import Runnable, RunnableConfig
|
||||||
|
|
||||||
@ -56,6 +62,20 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
chains and cannot return as rich of an output as `__call__`.
|
chains and cannot return as rich of an output as `__call__`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
|
return create_model( # type: ignore[call-overload]
|
||||||
|
"ChainInput", **{k: (Any, None) for k in self.input_keys}
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
|
return create_model( # type: ignore[call-overload]
|
||||||
|
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Dict[str, Any],
|
input: Dict[str, Any],
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Base interface for chains combining documents."""
|
"""Base interface for chains combining documents."""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForChainRun,
|
AsyncCallbackManagerForChainRun,
|
||||||
@ -9,7 +9,7 @@ from langchain.callbacks.manager import (
|
|||||||
)
|
)
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import BaseModel, Field, create_model
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
|
||||||
|
|
||||||
|
|
||||||
@ -28,6 +28,20 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
|||||||
input_key: str = "input_documents" #: :meta private:
|
input_key: str = "input_documents" #: :meta private:
|
||||||
output_key: str = "output_text" #: :meta private:
|
output_key: str = "output_text" #: :meta private:
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
return create_model(
|
||||||
|
"CombineDocumentsInput",
|
||||||
|
**{self.input_key: (List[Document], None)}, # type: ignore[call-overload]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
|
return create_model(
|
||||||
|
"CombineDocumentsOutput",
|
||||||
|
**{self.output_key: (str, None)}, # type: ignore[call-overload]
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
"""Expect input key.
|
"""Expect input key.
|
||||||
@ -153,6 +167,17 @@ class AnalyzeDocumentChain(Chain):
|
|||||||
"""
|
"""
|
||||||
return self.combine_docs_chain.output_keys
|
return self.combine_docs_chain.output_keys
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
return create_model(
|
||||||
|
"AnalyzeDocumentChain",
|
||||||
|
**{self.input_key: (str, None)}, # type: ignore[call-overload]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.combine_docs_chain.output_schema
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: Dict[str, str],
|
inputs: Dict[str, str],
|
||||||
|
@ -9,7 +9,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|||||||
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
|
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.pydantic_v1 import Extra, root_validator
|
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
|
||||||
|
|
||||||
|
|
||||||
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
||||||
@ -98,6 +98,19 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
return_intermediate_steps: bool = False
|
return_intermediate_steps: bool = False
|
||||||
"""Return the results of the map steps in the output."""
|
"""Return the results of the map steps in the output."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> type[BaseModel]:
|
||||||
|
if self.return_intermediate_steps:
|
||||||
|
return create_model(
|
||||||
|
"MapReduceDocumentsOutput",
|
||||||
|
**{
|
||||||
|
self.output_key: (str, None),
|
||||||
|
"intermediate_steps": (List[str], None),
|
||||||
|
}, # type: ignore[call-overload]
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().output_schema
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
"""Expect input key.
|
"""Expect input key.
|
||||||
|
@ -9,7 +9,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.output_parsers.regex import RegexParser
|
from langchain.output_parsers.regex import RegexParser
|
||||||
from langchain.pydantic_v1 import Extra, root_validator
|
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
|
||||||
|
|
||||||
|
|
||||||
class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
||||||
@ -77,6 +77,18 @@ class MapRerankDocumentsChain(BaseCombineDocumentsChain):
|
|||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> type[BaseModel]:
|
||||||
|
schema: Dict[str, Any] = {
|
||||||
|
self.output_key: (str, None),
|
||||||
|
}
|
||||||
|
if self.return_intermediate_steps:
|
||||||
|
schema["intermediate_steps"] = (List[str], None)
|
||||||
|
if self.metadata_keys:
|
||||||
|
schema.update({key: (Any, None) for key in self.metadata_keys})
|
||||||
|
|
||||||
|
return create_model("MapRerankOutput", **schema)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
"""Expect input key.
|
"""Expect input key.
|
||||||
|
@ -11,6 +11,7 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -37,9 +38,14 @@ from langchain.schema import (
|
|||||||
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
|
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
|
AIMessageChunk,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
BaseMessageChunk,
|
BaseMessageChunk,
|
||||||
|
ChatMessageChunk,
|
||||||
|
FunctionMessageChunk,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
|
HumanMessageChunk,
|
||||||
|
SystemMessageChunk,
|
||||||
)
|
)
|
||||||
from langchain.schema.output import ChatGenerationChunk
|
from langchain.schema.output import ChatGenerationChunk
|
||||||
from langchain.schema.runnable import RunnableConfig
|
from langchain.schema.runnable import RunnableConfig
|
||||||
@ -107,6 +113,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
|
|
||||||
# --- Runnable methods ---
|
# --- Runnable methods ---
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Any:
|
||||||
|
"""Get the input type for this runnable."""
|
||||||
|
return Union[
|
||||||
|
HumanMessageChunk,
|
||||||
|
AIMessageChunk,
|
||||||
|
ChatMessageChunk,
|
||||||
|
FunctionMessageChunk,
|
||||||
|
SystemMessageChunk,
|
||||||
|
]
|
||||||
|
|
||||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
||||||
if isinstance(input, PromptValue):
|
if isinstance(input, PromptValue):
|
||||||
return input
|
return input
|
||||||
|
@ -38,6 +38,8 @@ from langchain.schema.messages import (
|
|||||||
BaseMessageChunk,
|
BaseMessageChunk,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
ChatMessageChunk,
|
ChatMessageChunk,
|
||||||
|
FunctionMessage,
|
||||||
|
FunctionMessageChunk,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
HumanMessageChunk,
|
HumanMessageChunk,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
@ -53,39 +55,6 @@ class ChatLiteLLMException(Exception):
|
|||||||
"""Error with the `LiteLLM I/O` library"""
|
"""Error with the `LiteLLM I/O` library"""
|
||||||
|
|
||||||
|
|
||||||
def _truncate_at_stop_tokens(
|
|
||||||
text: str,
|
|
||||||
stop: Optional[List[str]],
|
|
||||||
) -> str:
|
|
||||||
"""Truncates text at the earliest stop token found."""
|
|
||||||
if stop is None:
|
|
||||||
return text
|
|
||||||
|
|
||||||
for stop_token in stop:
|
|
||||||
stop_token_idx = text.find(stop_token)
|
|
||||||
if stop_token_idx != -1:
|
|
||||||
text = text[:stop_token_idx]
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionMessage(BaseMessage):
|
|
||||||
"""Message for passing the result of executing a function back to a model."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
"""The name of the function that was executed."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "function"
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
|
||||||
"""Message Chunk for passing the result of executing a function back to a model."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _create_retry_decorator(
|
def _create_retry_decorator(
|
||||||
llm: ChatLiteLLM,
|
llm: ChatLiteLLM,
|
||||||
run_manager: Optional[
|
run_manager: Optional[
|
||||||
|
@ -199,6 +199,11 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
|
|
||||||
# --- Runnable methods ---
|
# --- Runnable methods ---
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Type[str]:
|
||||||
|
"""Get the input type for this runnable."""
|
||||||
|
return str
|
||||||
|
|
||||||
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
|
||||||
if isinstance(input, PromptValue):
|
if isinstance(input, PromptValue):
|
||||||
return input
|
return input
|
||||||
|
@ -28,6 +28,7 @@ from langchain.schema import (
|
|||||||
)
|
)
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
|
AnyMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
@ -280,7 +281,7 @@ class ChatPromptValue(PromptValue):
|
|||||||
A type of a prompt value that is built from messages.
|
A type of a prompt value that is built from messages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages: List[BaseMessage]
|
messages: Sequence[BaseMessage]
|
||||||
"""List of messages."""
|
"""List of messages."""
|
||||||
|
|
||||||
def to_string(self) -> str:
|
def to_string(self) -> str:
|
||||||
@ -289,7 +290,14 @@ class ChatPromptValue(PromptValue):
|
|||||||
|
|
||||||
def to_messages(self) -> List[BaseMessage]:
|
def to_messages(self) -> List[BaseMessage]:
|
||||||
"""Return prompt as a list of messages."""
|
"""Return prompt as a list of messages."""
|
||||||
return self.messages
|
return list(self.messages)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatPromptValueConcrete(ChatPromptValue):
|
||||||
|
"""Chat prompt value which explicitly lists out the message types it accepts.
|
||||||
|
For use in external schemas."""
|
||||||
|
|
||||||
|
messages: Sequence[AnyMessage]
|
||||||
|
|
||||||
|
|
||||||
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||||
|
@ -13,8 +13,10 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
from langchain.schema.messages import AnyMessage, BaseMessage, get_buffer_string
|
||||||
from langchain.schema.output import LLMResult
|
from langchain.schema.output import LLMResult
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
from langchain.schema.runnable import Runnable
|
from langchain.schema.runnable import Runnable
|
||||||
@ -70,6 +72,21 @@ class BaseLanguageModel(
|
|||||||
Each of these has an equivalent asynchronous method.
|
Each of these has an equivalent asynchronous method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> TypeAlias:
|
||||||
|
"""Get the input type for this runnable."""
|
||||||
|
from langchain.prompts.base import StringPromptValue
|
||||||
|
from langchain.prompts.chat import ChatPromptValueConcrete
|
||||||
|
|
||||||
|
# This is a version of LanguageModelInput which replaces the abstract
|
||||||
|
# base class BaseMessage with a union of its subclasses, which makes
|
||||||
|
# for a much better schema.
|
||||||
|
return Union[
|
||||||
|
str,
|
||||||
|
Union[StringPromptValue, ChatPromptValueConcrete],
|
||||||
|
List[AnyMessage],
|
||||||
|
]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate_prompt(
|
def generate_prompt(
|
||||||
self,
|
self,
|
||||||
|
@ -1,10 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from abc import abstractmethod
|
from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Union
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Sequence
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import Extra, Field
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.prompts.chat import ChatPromptTemplate
|
from langchain.prompts.chat import ChatPromptTemplate
|
||||||
@ -69,10 +70,10 @@ class BaseMessage(Serializable):
|
|||||||
additional_kwargs: dict = Field(default_factory=dict)
|
additional_kwargs: dict = Field(default_factory=dict)
|
||||||
"""Any additional information."""
|
"""Any additional information."""
|
||||||
|
|
||||||
@property
|
type: str
|
||||||
@abstractmethod
|
|
||||||
def type(self) -> str:
|
class Config:
|
||||||
"""Type of the Message, used for serialization."""
|
extra = Extra.allow
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
@ -147,10 +148,10 @@ class HumanMessage(BaseMessage):
|
|||||||
conversation.
|
conversation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
type: Literal["human"] = "human"
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "human"
|
HumanMessage.update_forward_refs()
|
||||||
|
|
||||||
|
|
||||||
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
|
||||||
@ -167,10 +168,10 @@ class AIMessage(BaseMessage):
|
|||||||
conversation.
|
conversation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
type: Literal["ai"] = "ai"
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "ai"
|
AIMessage.update_forward_refs()
|
||||||
|
|
||||||
|
|
||||||
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||||
@ -199,10 +200,10 @@ class SystemMessage(BaseMessage):
|
|||||||
of input messages.
|
of input messages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
type: Literal["system"] = "system"
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "system"
|
SystemMessage.update_forward_refs()
|
||||||
|
|
||||||
|
|
||||||
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
|
||||||
@ -217,10 +218,10 @@ class FunctionMessage(BaseMessage):
|
|||||||
name: str
|
name: str
|
||||||
"""The name of the function that was executed."""
|
"""The name of the function that was executed."""
|
||||||
|
|
||||||
@property
|
type: Literal["function"] = "function"
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "function"
|
FunctionMessage.update_forward_refs()
|
||||||
|
|
||||||
|
|
||||||
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
|
||||||
@ -250,10 +251,10 @@ class ChatMessage(BaseMessage):
|
|||||||
role: str
|
role: str
|
||||||
"""The speaker / role of the Message."""
|
"""The speaker / role of the Message."""
|
||||||
|
|
||||||
@property
|
type: Literal["chat"] = "chat"
|
||||||
def type(self) -> str:
|
|
||||||
"""Type of the message, used for serialization."""
|
|
||||||
return "chat"
|
ChatMessage.update_forward_refs()
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
||||||
@ -277,6 +278,9 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
|
|||||||
return super().__add__(other)
|
return super().__add__(other)
|
||||||
|
|
||||||
|
|
||||||
|
AnyMessage = Union[AIMessage, HumanMessage, ChatMessage, SystemMessage, FunctionMessage]
|
||||||
|
|
||||||
|
|
||||||
def _message_to_dict(message: BaseMessage) -> dict:
|
def _message_to_dict(message: BaseMessage) -> dict:
|
||||||
return {"type": message.type, "data": message.dict()}
|
return {"type": message.type, "data": message.dict()}
|
||||||
|
|
||||||
|
@ -14,8 +14,10 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import get_args
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema.messages import BaseMessage
|
from langchain.schema.messages import AnyMessage, BaseMessage
|
||||||
from langchain.schema.output import ChatGeneration, Generation
|
from langchain.schema.output import ChatGeneration, Generation
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
from langchain.schema.runnable import Runnable, RunnableConfig
|
from langchain.schema.runnable import Runnable, RunnableConfig
|
||||||
@ -58,6 +60,16 @@ class BaseGenerationOutputParser(
|
|||||||
):
|
):
|
||||||
"""Base class to parse the output of an LLM call."""
|
"""Base class to parse the output of an LLM call."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Any:
|
||||||
|
return Union[str, AnyMessage]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> type[T]:
|
||||||
|
# even though mypy complains this isn't valid,
|
||||||
|
# it is good enough for pydantic to build the schema from
|
||||||
|
return T # type: ignore[misc]
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
|
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
|
||||||
) -> T:
|
) -> T:
|
||||||
@ -129,6 +141,22 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
|||||||
return "boolean_output_parser"
|
return "boolean_output_parser"
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Any:
|
||||||
|
return Union[str, AnyMessage]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> type[T]:
|
||||||
|
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||||
|
type_args = get_args(cls)
|
||||||
|
if type_args and len(type_args) == 1:
|
||||||
|
return type_args[0]
|
||||||
|
|
||||||
|
raise TypeError(
|
||||||
|
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
|
||||||
|
"Override the OutputType property to specify the output type."
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
|
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
|
||||||
) -> T:
|
) -> T:
|
||||||
|
@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.pydantic_v1 import Field, root_validator
|
from langchain.pydantic_v1 import BaseModel, Field, create_model, root_validator
|
||||||
from langchain.schema.document import Document
|
from langchain.schema.document import Document
|
||||||
from langchain.schema.output_parser import BaseOutputParser
|
from langchain.schema.output_parser import BaseOutputParser
|
||||||
from langchain.schema.prompt import PromptValue
|
from langchain.schema.prompt import PromptValue
|
||||||
@ -36,6 +36,20 @@ class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
|
|||||||
|
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Any:
|
||||||
|
from langchain.prompts.base import StringPromptValue
|
||||||
|
from langchain.prompts.chat import ChatPromptValueConcrete
|
||||||
|
|
||||||
|
return Union[StringPromptValue, ChatPromptValueConcrete]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> type[BaseModel]:
|
||||||
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
|
return create_model( # type: ignore[call-overload]
|
||||||
|
"PromptInput", **{k: (Any, None) for k in self.input_variables}
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
|
def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
|
||||||
return self._call_with_config(
|
return self._call_with_config(
|
||||||
lambda inner_input: self.format_prompt(
|
lambda inner_input: self.format_prompt(
|
||||||
|
@ -7,6 +7,7 @@ from abc import ABC, abstractmethod
|
|||||||
from concurrent.futures import FIRST_COMPLETED, wait
|
from concurrent.futures import FIRST_COMPLETED, wait
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import tee
|
from itertools import tee
|
||||||
|
from operator import itemgetter
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -27,6 +28,8 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import get_args
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
AsyncCallbackManagerForChainRun,
|
AsyncCallbackManagerForChainRun,
|
||||||
@ -37,7 +40,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import BaseModel, Field, create_model
|
||||||
from langchain.schema.runnable.config import (
|
from langchain.schema.runnable.config import (
|
||||||
RunnableConfig,
|
RunnableConfig,
|
||||||
acall_func_with_variable_args,
|
acall_func_with_variable_args,
|
||||||
@ -55,6 +58,7 @@ from langchain.schema.runnable.utils import (
|
|||||||
accepts_config,
|
accepts_config,
|
||||||
accepts_run_manager,
|
accepts_run_manager,
|
||||||
gather_with_concurrency,
|
gather_with_concurrency,
|
||||||
|
get_function_first_arg_dict_keys,
|
||||||
)
|
)
|
||||||
from langchain.utils.aiter import atee, py_anext
|
from langchain.utils.aiter import atee, py_anext
|
||||||
from langchain.utils.iter import safetee
|
from langchain.utils.iter import safetee
|
||||||
@ -66,6 +70,52 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""A Runnable is a unit of work that can be invoked, batched, streamed, or
|
"""A Runnable is a unit of work that can be invoked, batched, streamed, or
|
||||||
transformed."""
|
transformed."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Type[Input]:
|
||||||
|
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||||
|
type_args = get_args(cls)
|
||||||
|
if type_args and len(type_args) == 2:
|
||||||
|
return type_args[0]
|
||||||
|
|
||||||
|
raise TypeError(
|
||||||
|
f"Runnable {self.__class__.__name__} doesn't have an inferable InputType. "
|
||||||
|
"Override the InputType property to specify the input type."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Type[Output]:
|
||||||
|
for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined]
|
||||||
|
type_args = get_args(cls)
|
||||||
|
if type_args and len(type_args) == 2:
|
||||||
|
return type_args[1]
|
||||||
|
|
||||||
|
raise TypeError(
|
||||||
|
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
|
||||||
|
"Override the OutputType property to specify the output type."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
root_type = self.InputType
|
||||||
|
|
||||||
|
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||||
|
return root_type
|
||||||
|
|
||||||
|
return create_model(
|
||||||
|
self.__class__.__name__ + "Input", __root__=(root_type, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
|
root_type = self.OutputType
|
||||||
|
|
||||||
|
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
|
||||||
|
return root_type
|
||||||
|
|
||||||
|
return create_model(
|
||||||
|
self.__class__.__name__ + "Output", __root__=(root_type, None)
|
||||||
|
)
|
||||||
|
|
||||||
def __or__(
|
def __or__(
|
||||||
self,
|
self,
|
||||||
other: Union[
|
other: Union[
|
||||||
@ -849,6 +899,20 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
|
|||||||
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
"""The namespace of a RunnableBranch is the namespace of its default branch."""
|
||||||
return cls.__module__.split(".")[:-1]
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> type[BaseModel]:
|
||||||
|
runnables = (
|
||||||
|
[self.default]
|
||||||
|
+ [r for _, r in self.branches]
|
||||||
|
+ [r for r, _ in self.branches]
|
||||||
|
)
|
||||||
|
|
||||||
|
for runnable in runnables:
|
||||||
|
if runnable.input_schema.schema().get("type") is not None:
|
||||||
|
return runnable.input_schema
|
||||||
|
|
||||||
|
return super().input_schema
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||||
"""First evaluates the condition, then delegate to true or false branch."""
|
"""First evaluates the condition, then delegate to true or false branch."""
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
@ -953,6 +1017,22 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Type[Input]:
|
||||||
|
return self.runnable.InputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Type[Output]:
|
||||||
|
return self.runnable.OutputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.runnable.input_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.runnable.output_schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@ -1202,6 +1282,22 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Type[Input]:
|
||||||
|
return self.first.InputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Type[Output]:
|
||||||
|
return self.last.OutputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.first.input_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.last.output_schema
|
||||||
|
|
||||||
def __or__(
|
def __or__(
|
||||||
self,
|
self,
|
||||||
other: Union[
|
other: Union[
|
||||||
@ -1692,6 +1788,37 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Any:
|
||||||
|
for step in self.steps.values():
|
||||||
|
if step.InputType:
|
||||||
|
return step.InputType
|
||||||
|
|
||||||
|
return Any
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> type[BaseModel]:
|
||||||
|
if all(not s.input_schema.__custom_root_type__ for s in self.steps.values()):
|
||||||
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
|
return create_model( # type: ignore[call-overload]
|
||||||
|
"RunnableMapInput",
|
||||||
|
**{
|
||||||
|
k: (v.type_, v.default)
|
||||||
|
for step in self.steps.values()
|
||||||
|
for k, v in step.input_schema.__fields__.items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().input_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> type[BaseModel]:
|
||||||
|
# This is correct, but pydantic typings/mypy don't think so.
|
||||||
|
return create_model( # type: ignore[call-overload]
|
||||||
|
"RunnableMapOutput",
|
||||||
|
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, input: Input, config: Optional[RunnableConfig] = None
|
self, input: Input, config: Optional[RunnableConfig] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@ -1942,6 +2069,59 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
f"Instead got an unsupported type: {type(func)}"
|
f"Instead got an unsupported type: {type(func)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Any:
|
||||||
|
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||||
|
try:
|
||||||
|
params = inspect.signature(func).parameters
|
||||||
|
first_param = next(iter(params.values()), None)
|
||||||
|
if first_param and first_param.annotation != inspect.Parameter.empty:
|
||||||
|
return first_param.annotation
|
||||||
|
else:
|
||||||
|
return Any
|
||||||
|
except ValueError:
|
||||||
|
return Any
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||||
|
|
||||||
|
if isinstance(func, itemgetter):
|
||||||
|
# This is terrible, but afaict it's not possible to access _items
|
||||||
|
# on itemgetter objects, so we have to parse the repr
|
||||||
|
items = str(func).replace("operator.itemgetter(", "")[:-1].split(", ")
|
||||||
|
if all(
|
||||||
|
item[0] == "'" and item[-1] == "'" and len(item) > 2 for item in items
|
||||||
|
):
|
||||||
|
# It's a dict, lol
|
||||||
|
return create_model(
|
||||||
|
"RunnableLambdaInput",
|
||||||
|
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return create_model("RunnableLambdaInput", __root__=(List[Any], None))
|
||||||
|
|
||||||
|
if dict_keys := get_function_first_arg_dict_keys(func):
|
||||||
|
return create_model(
|
||||||
|
"RunnableLambdaInput",
|
||||||
|
**{key: (Any, None) for key in dict_keys}, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().input_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Any:
|
||||||
|
func = getattr(self, "func", None) or getattr(self, "afunc")
|
||||||
|
try:
|
||||||
|
sig = inspect.signature(func)
|
||||||
|
return (
|
||||||
|
sig.return_annotation
|
||||||
|
if sig.return_annotation != inspect.Signature.empty
|
||||||
|
else Any
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
return Any
|
||||||
|
|
||||||
def __eq__(self, other: Any) -> bool:
|
def __eq__(self, other: Any) -> bool:
|
||||||
if isinstance(other, RunnableLambda):
|
if isinstance(other, RunnableLambda):
|
||||||
if hasattr(self, "func") and hasattr(other, "func"):
|
if hasattr(self, "func") and hasattr(other, "func"):
|
||||||
@ -2068,6 +2248,34 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
|||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Any:
|
||||||
|
return List[self.bound.InputType] # type: ignore[name-defined]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> type[BaseModel]:
|
||||||
|
return create_model(
|
||||||
|
"RunnableEachInput",
|
||||||
|
__root__=(
|
||||||
|
List[self.bound.input_schema], # type: ignore[name-defined]
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> type[List[Output]]:
|
||||||
|
return List[self.bound.OutputType] # type: ignore[name-defined]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> type[BaseModel]:
|
||||||
|
return create_model(
|
||||||
|
"RunnableEachOutput",
|
||||||
|
__root__=(
|
||||||
|
List[self.bound.output_schema], # type: ignore[name-defined]
|
||||||
|
None,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@ -2124,6 +2332,22 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|||||||
class Config:
|
class Config:
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> type[Input]:
|
||||||
|
return self.bound.InputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> type[Output]:
|
||||||
|
return self.bound.OutputType
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.bound.input_schema
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_schema(self) -> Type[BaseModel]:
|
||||||
|
return self.bound.output_schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, AsyncIterator, Iterator, List, Optional
|
from typing import Any, AsyncIterator, Iterator, List, Optional, Type
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema.runnable.base import Input, Runnable
|
from langchain.schema.runnable.base import Input, Runnable
|
||||||
@ -20,6 +20,8 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
|||||||
A runnable that passes through the input.
|
A runnable that passes through the input.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
input_type: Optional[Type[Input]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@ -28,6 +30,14 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
|||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
return cls.__module__.split(".")[:-1]
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def InputType(self) -> Any:
|
||||||
|
return self.input_type or Any
|
||||||
|
|
||||||
|
@property
|
||||||
|
def OutputType(self) -> Any:
|
||||||
|
return self.input_type or Any
|
||||||
|
|
||||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||||
return self._call_with_config(identity, input, config)
|
return self._call_with_config(identity, input, config)
|
||||||
|
|
||||||
|
@ -4,16 +4,16 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Callable,
|
Callable,
|
||||||
Generic,
|
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
TypedDict,
|
|
||||||
Union,
|
Union,
|
||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema.runnable.base import (
|
from langchain.schema.runnable.base import (
|
||||||
Input,
|
Input,
|
||||||
@ -43,21 +43,17 @@ class RouterInput(TypedDict):
|
|||||||
input: Any
|
input: Any
|
||||||
|
|
||||||
|
|
||||||
class RouterRunnable(
|
class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
|
||||||
Serializable, Generic[Input, Output], Runnable[RouterInput, Output]
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
A runnable that routes to a set of runnables based on Input['key'].
|
A runnable that routes to a set of runnables based on Input['key'].
|
||||||
Returns the output of the selected runnable.
|
Returns the output of the selected runnable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
runnables: Mapping[str, Runnable[Input, Output]]
|
runnables: Mapping[str, Runnable[Any, Output]]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
runnables: Mapping[
|
runnables: Mapping[str, Union[Runnable[Any, Output], Callable[[Any], Output]]],
|
||||||
str, Union[Runnable[Input, Output], Callable[[Input], Output]]
|
|
||||||
],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
|
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import textwrap
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, Callable, Coroutine, TypeVar, Union
|
from typing import Any, Callable, Coroutine, List, Optional, Set, TypeVar, Union
|
||||||
|
|
||||||
Input = TypeVar("Input")
|
Input = TypeVar("Input")
|
||||||
# Output type should implement __concat__, as eg str, list, dict do
|
# Output type should implement __concat__, as eg str, list, dict do
|
||||||
@ -35,3 +38,61 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
|
|||||||
return signature(callable).parameters.get("config") is not None
|
return signature(callable).parameters.get("config") is not None
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class IsLocalDict(ast.NodeVisitor):
|
||||||
|
def __init__(self, name: str, keys: Set[str]) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.keys = keys
|
||||||
|
|
||||||
|
def visit_Subscript(self, node: ast.Subscript) -> Any:
|
||||||
|
if (
|
||||||
|
isinstance(node.ctx, ast.Load)
|
||||||
|
and isinstance(node.value, ast.Name)
|
||||||
|
and node.value.id == self.name
|
||||||
|
and isinstance(node.slice, ast.Constant)
|
||||||
|
and isinstance(node.slice.value, str)
|
||||||
|
):
|
||||||
|
# we've found a subscript access on the name we're looking for
|
||||||
|
self.keys.add(node.slice.value)
|
||||||
|
|
||||||
|
def visit_Call(self, node: ast.Call) -> Any:
|
||||||
|
if (
|
||||||
|
isinstance(node.func, ast.Attribute)
|
||||||
|
and isinstance(node.func.value, ast.Name)
|
||||||
|
and node.func.value.id == self.name
|
||||||
|
and node.func.attr == "get"
|
||||||
|
and len(node.args) in (1, 2)
|
||||||
|
and isinstance(node.args[0], ast.Constant)
|
||||||
|
and isinstance(node.args[0].value, str)
|
||||||
|
):
|
||||||
|
# we've found a .get() call on the name we're looking for
|
||||||
|
self.keys.add(node.args[0].value)
|
||||||
|
|
||||||
|
|
||||||
|
class IsFunctionArgDict(ast.NodeVisitor):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.keys: Set[str] = set()
|
||||||
|
|
||||||
|
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
||||||
|
input_arg_name = node.args.args[0].arg
|
||||||
|
IsLocalDict(input_arg_name, self.keys).visit(node.body)
|
||||||
|
|
||||||
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
||||||
|
input_arg_name = node.args.args[0].arg
|
||||||
|
IsLocalDict(input_arg_name, self.keys).visit(node)
|
||||||
|
|
||||||
|
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
||||||
|
input_arg_name = node.args.args[0].arg
|
||||||
|
IsLocalDict(input_arg_name, self.keys).visit(node)
|
||||||
|
|
||||||
|
|
||||||
|
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
||||||
|
try:
|
||||||
|
code = inspect.getsource(func)
|
||||||
|
tree = ast.parse(textwrap.dedent(code))
|
||||||
|
visitor = IsFunctionArgDict()
|
||||||
|
visitor.visit(tree)
|
||||||
|
return list(visitor.keys) if visitor.keys else None
|
||||||
|
except (TypeError, OSError):
|
||||||
|
return None
|
||||||
|
@ -187,6 +187,14 @@ class ChildTool(BaseTool):
|
|||||||
|
|
||||||
# --- Runnable ---
|
# --- Runnable ---
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_schema(self) -> Type[BaseModel]:
|
||||||
|
"""The tool's input schema."""
|
||||||
|
if self.args_schema is not None:
|
||||||
|
return self.args_schema
|
||||||
|
else:
|
||||||
|
return create_schema_from_function(self.name, self._run)
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self,
|
self,
|
||||||
input: Union[str, Dict],
|
input: Union[str, Dict],
|
||||||
|
File diff suppressed because one or more lines are too long
@ -1,3 +1,4 @@
|
|||||||
|
import sys
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Union, cast
|
from typing import Any, Dict, List, Optional, Sequence, Union, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@ -12,6 +13,8 @@ from langchain.callbacks.tracers.base import BaseTracer
|
|||||||
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||||
|
from langchain.chains.question_answering import load_qa_chain
|
||||||
|
from langchain.chains.summarize import load_summarize_chain
|
||||||
from langchain.chat_models.fake import FakeListChatModel
|
from langchain.chat_models.fake import FakeListChatModel
|
||||||
from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
|
from langchain.llms.fake import FakeListLLM, FakeStreamingListLLM
|
||||||
from langchain.load.dump import dumpd, dumps
|
from langchain.load.dump import dumpd, dumps
|
||||||
@ -43,6 +46,7 @@ from langchain.schema.runnable import (
|
|||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
RunnableWithFallbacks,
|
RunnableWithFallbacks,
|
||||||
)
|
)
|
||||||
|
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
|
||||||
|
|
||||||
|
|
||||||
class FakeTracer(BaseTracer):
|
class FakeTracer(BaseTracer):
|
||||||
@ -115,6 +119,412 @@ class FakeRetriever(BaseRetriever):
|
|||||||
return [Document(page_content="foo"), Document(page_content="bar")]
|
return [Document(page_content="foo"), Document(page_content="bar")]
|
||||||
|
|
||||||
|
|
||||||
|
def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||||
|
fake = FakeRunnable() # str -> int
|
||||||
|
|
||||||
|
assert fake.input_schema.schema() == {
|
||||||
|
"title": "FakeRunnableInput",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
assert fake.output_schema.schema() == {
|
||||||
|
"title": "FakeRunnableOutput",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
fake_bound = FakeRunnable().bind(a="b") # str -> int
|
||||||
|
|
||||||
|
assert fake_bound.input_schema.schema() == {
|
||||||
|
"title": "FakeRunnableInput",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
assert fake_bound.output_schema.schema() == {
|
||||||
|
"title": "FakeRunnableOutput",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
fake_w_fallbacks = FakeRunnable().with_fallbacks((fake,)) # str -> int
|
||||||
|
|
||||||
|
assert fake_w_fallbacks.input_schema.schema() == {
|
||||||
|
"title": "FakeRunnableInput",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
assert fake_w_fallbacks.output_schema.schema() == {
|
||||||
|
"title": "FakeRunnableOutput",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
def typed_lambda_impl(x: str) -> int:
|
||||||
|
return len(x)
|
||||||
|
|
||||||
|
typed_lambda = RunnableLambda(typed_lambda_impl) # str -> int
|
||||||
|
|
||||||
|
assert typed_lambda.input_schema.schema() == {
|
||||||
|
"title": "RunnableLambdaInput",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
assert typed_lambda.output_schema.schema() == {
|
||||||
|
"title": "RunnableLambdaOutput",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def typed_async_lambda_impl(x: str) -> int:
|
||||||
|
return len(x)
|
||||||
|
|
||||||
|
typed_async_lambda: Runnable = RunnableLambda(typed_async_lambda_impl) # str -> int
|
||||||
|
|
||||||
|
assert typed_async_lambda.input_schema.schema() == {
|
||||||
|
"title": "RunnableLambdaInput",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
assert typed_async_lambda.output_schema.schema() == {
|
||||||
|
"title": "RunnableLambdaOutput",
|
||||||
|
"type": "integer",
|
||||||
|
}
|
||||||
|
|
||||||
|
fake_ret = FakeRetriever() # str -> List[Document]
|
||||||
|
|
||||||
|
assert fake_ret.input_schema.schema() == {
|
||||||
|
"title": "FakeRetrieverInput",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
assert fake_ret.output_schema.schema() == {
|
||||||
|
"title": "FakeRetrieverOutput",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": "#/definitions/Document"},
|
||||||
|
"definitions": {
|
||||||
|
"Document": {
|
||||||
|
"title": "Document",
|
||||||
|
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"page_content": {"title": "Page Content", "type": "string"},
|
||||||
|
"metadata": {"title": "Metadata", "type": "object"},
|
||||||
|
},
|
||||||
|
"required": ["page_content"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
|
||||||
|
|
||||||
|
assert fake_llm.input_schema.schema() == snapshot
|
||||||
|
assert fake_llm.output_schema.schema() == {
|
||||||
|
"title": "FakeListLLMOutput",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
|
||||||
|
fake_chat = FakeListChatModel(responses=["a"]) # str -> List[List[str]]
|
||||||
|
|
||||||
|
assert fake_chat.input_schema.schema() == snapshot
|
||||||
|
assert fake_chat.output_schema.schema() == snapshot
|
||||||
|
|
||||||
|
prompt = PromptTemplate.from_template("Hello, {name}!")
|
||||||
|
|
||||||
|
assert prompt.input_schema.schema() == {
|
||||||
|
"title": "PromptInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"name": {"title": "Name"}},
|
||||||
|
}
|
||||||
|
assert prompt.output_schema.schema() == snapshot
|
||||||
|
|
||||||
|
prompt_mapper = PromptTemplate.from_template("Hello, {name}!").map()
|
||||||
|
|
||||||
|
assert prompt_mapper.input_schema.schema() == {
|
||||||
|
"definitions": {
|
||||||
|
"PromptInput": {
|
||||||
|
"properties": {"name": {"title": "Name"}},
|
||||||
|
"title": "PromptInput",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"items": {"$ref": "#/definitions/PromptInput"},
|
||||||
|
"type": "array",
|
||||||
|
"title": "RunnableEachInput",
|
||||||
|
}
|
||||||
|
assert prompt_mapper.output_schema.schema() == snapshot
|
||||||
|
|
||||||
|
list_parser = CommaSeparatedListOutputParser()
|
||||||
|
|
||||||
|
assert list_parser.input_schema.schema() == snapshot
|
||||||
|
assert list_parser.output_schema.schema() == {
|
||||||
|
"title": "CommaSeparatedListOutputParserOutput",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
}
|
||||||
|
|
||||||
|
seq = prompt | fake_llm | list_parser
|
||||||
|
|
||||||
|
assert seq.input_schema.schema() == {
|
||||||
|
"title": "PromptInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"name": {"title": "Name"}},
|
||||||
|
}
|
||||||
|
assert seq.output_schema.schema() == {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"title": "CommaSeparatedListOutputParserOutput",
|
||||||
|
}
|
||||||
|
|
||||||
|
router: Runnable = RouterRunnable({})
|
||||||
|
|
||||||
|
assert router.input_schema.schema() == {
|
||||||
|
"title": "RouterRunnableInput",
|
||||||
|
"$ref": "#/definitions/RouterInput",
|
||||||
|
"definitions": {
|
||||||
|
"RouterInput": {
|
||||||
|
"title": "RouterInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"key": {"title": "Key", "type": "string"},
|
||||||
|
"input": {"title": "Input"},
|
||||||
|
},
|
||||||
|
"required": ["key", "input"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert router.output_schema.schema() == {"title": "RouterRunnableOutput"}
|
||||||
|
|
||||||
|
seq_w_map: Runnable = (
|
||||||
|
prompt
|
||||||
|
| fake_llm
|
||||||
|
| {
|
||||||
|
"original": RunnablePassthrough(input_type=str),
|
||||||
|
"as_list": list_parser,
|
||||||
|
"length": typed_lambda_impl,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert seq_w_map.input_schema.schema() == {
|
||||||
|
"title": "PromptInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"name": {"title": "Name"}},
|
||||||
|
}
|
||||||
|
assert seq_w_map.output_schema.schema() == {
|
||||||
|
"title": "RunnableMapOutput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"original": {"title": "Original", "type": "string"},
|
||||||
|
"length": {"title": "Length", "type": "integer"},
|
||||||
|
"as_list": {
|
||||||
|
"title": "As List",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
json_list_keys_tool = JsonListKeysTool(spec=JsonSpec(dict_={}))
|
||||||
|
|
||||||
|
assert json_list_keys_tool.input_schema.schema() == {
|
||||||
|
"title": "json_spec_list_keysSchema",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"tool_input": {"title": "Tool Input", "type": "string"}},
|
||||||
|
"required": ["tool_input"],
|
||||||
|
}
|
||||||
|
assert json_list_keys_tool.output_schema.schema() == {
|
||||||
|
"title": "JsonListKeysToolOutput"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
|
)
|
||||||
|
def test_lambda_schemas() -> None:
|
||||||
|
first_lambda = lambda x: x["hello"] # noqa: E731
|
||||||
|
assert RunnableLambda(first_lambda).input_schema.schema() == {
|
||||||
|
"title": "RunnableLambdaInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"hello": {"title": "Hello"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
second_lambda = lambda x, y: (x["hello"], x["bye"], y["bah"]) # noqa: E731
|
||||||
|
assert RunnableLambda(
|
||||||
|
second_lambda, # type: ignore[arg-type]
|
||||||
|
).input_schema.schema() == {
|
||||||
|
"title": "RunnableLambdaInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"hello": {"title": "Hello"}, "bye": {"title": "Bye"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_value(input): # type: ignore[no-untyped-def]
|
||||||
|
return input["variable_name"]
|
||||||
|
|
||||||
|
assert RunnableLambda(get_value).input_schema.schema() == {
|
||||||
|
"title": "RunnableLambdaInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"variable_name": {"title": "Variable Name"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def aget_value(input): # type: ignore[no-untyped-def]
|
||||||
|
return (input["variable_name"], input.get("another"))
|
||||||
|
|
||||||
|
assert RunnableLambda(aget_value).input_schema.schema() == {
|
||||||
|
"title": "RunnableLambdaInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"another": {"title": "Another"},
|
||||||
|
"variable_name": {"title": "Variable Name"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def aget_values(input): # type: ignore[no-untyped-def]
|
||||||
|
return {
|
||||||
|
"hello": input["variable_name"],
|
||||||
|
"bye": input["variable_name"],
|
||||||
|
"byebye": input["yo"],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert RunnableLambda(aget_values).input_schema.schema() == {
|
||||||
|
"title": "RunnableLambdaInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"variable_name": {"title": "Variable Name"},
|
||||||
|
"yo": {"title": "Yo"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_schema_complex_seq() -> None:
|
||||||
|
prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
|
||||||
|
prompt2 = ChatPromptTemplate.from_template(
|
||||||
|
"what country is the city {city} in? respond in {language}"
|
||||||
|
)
|
||||||
|
|
||||||
|
model = FakeListChatModel(responses=[""])
|
||||||
|
|
||||||
|
chain1 = prompt1 | model | StrOutputParser()
|
||||||
|
|
||||||
|
chain2: Runnable = (
|
||||||
|
{"city": chain1, "language": itemgetter("language")}
|
||||||
|
| prompt2
|
||||||
|
| model
|
||||||
|
| StrOutputParser()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chain2.input_schema.schema() == {
|
||||||
|
"title": "RunnableMapInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"person": {"title": "Person"},
|
||||||
|
"language": {"title": "Language"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert chain2.output_schema.schema() == {
|
||||||
|
"title": "StrOutputParserOutput",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_schema_chains() -> None:
|
||||||
|
model = FakeListChatModel(responses=[""])
|
||||||
|
|
||||||
|
stuff_chain = load_summarize_chain(model)
|
||||||
|
|
||||||
|
assert stuff_chain.input_schema.schema() == {
|
||||||
|
"title": "CombineDocumentsInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"input_documents": {
|
||||||
|
"title": "Input Documents",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": "#/definitions/Document"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"definitions": {
|
||||||
|
"Document": {
|
||||||
|
"title": "Document",
|
||||||
|
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"page_content": {"title": "Page Content", "type": "string"},
|
||||||
|
"metadata": {"title": "Metadata", "type": "object"},
|
||||||
|
},
|
||||||
|
"required": ["page_content"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert stuff_chain.output_schema.schema() == {
|
||||||
|
"title": "CombineDocumentsOutput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"output_text": {"title": "Output Text", "type": "string"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
mapreduce_chain = load_summarize_chain(
|
||||||
|
model, "map_reduce", return_intermediate_steps=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mapreduce_chain.input_schema.schema() == {
|
||||||
|
"title": "CombineDocumentsInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"input_documents": {
|
||||||
|
"title": "Input Documents",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": "#/definitions/Document"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"definitions": {
|
||||||
|
"Document": {
|
||||||
|
"title": "Document",
|
||||||
|
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"page_content": {"title": "Page Content", "type": "string"},
|
||||||
|
"metadata": {"title": "Metadata", "type": "object"},
|
||||||
|
},
|
||||||
|
"required": ["page_content"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert mapreduce_chain.output_schema.schema() == {
|
||||||
|
"title": "MapReduceDocumentsOutput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"output_text": {"title": "Output Text", "type": "string"},
|
||||||
|
"intermediate_steps": {
|
||||||
|
"title": "Intermediate Steps",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
maprerank_chain = load_qa_chain(model, "map_rerank", metadata_keys=["hello"])
|
||||||
|
|
||||||
|
assert maprerank_chain.input_schema.schema() == {
|
||||||
|
"title": "CombineDocumentsInput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"input_documents": {
|
||||||
|
"title": "Input Documents",
|
||||||
|
"type": "array",
|
||||||
|
"items": {"$ref": "#/definitions/Document"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"definitions": {
|
||||||
|
"Document": {
|
||||||
|
"title": "Document",
|
||||||
|
"description": "Class for storing a piece of text and associated metadata.", # noqa: E501
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"page_content": {"title": "Page Content", "type": "string"},
|
||||||
|
"metadata": {"title": "Metadata", "type": "object"},
|
||||||
|
},
|
||||||
|
"required": ["page_content"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert maprerank_chain.output_schema.schema() == {
|
||||||
|
"title": "MapRerankOutput",
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"output_text": {"title": "Output Text", "type": "string"},
|
||||||
|
"hello": {"title": "Hello"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_with_config(mocker: MockerFixture) -> None:
|
async def test_with_config(mocker: MockerFixture) -> None:
|
||||||
fake = FakeRunnable()
|
fake = FakeRunnable()
|
||||||
@ -2160,6 +2570,7 @@ def test_runnable_branch_init_coercion(branches: Sequence[Any]) -> None:
|
|||||||
assert isinstance(body, Runnable)
|
assert isinstance(body, Runnable)
|
||||||
|
|
||||||
assert isinstance(runnable.default, Runnable)
|
assert isinstance(runnable.default, Runnable)
|
||||||
|
assert runnable.input_schema.schema() == {"title": "RunnableBranchInput"}
|
||||||
|
|
||||||
|
|
||||||
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
|
def test_runnable_branch_invoke_call_counts(mocker: MockerFixture) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user