manual mapping (#14422)

pull/14476/head
Harrison Chase 6 months ago committed by GitHub
parent c24f277b7c
commit f5befe3b89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
from __future__ import annotations
import json
from typing import Any, Literal, Sequence, Union
from typing import Any, List, Literal, Sequence, Union
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
@ -40,6 +40,11 @@ class AgentAction(Serializable):
"""Return whether or not the class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema agent".split()
@property
def messages(self) -> Sequence[BaseMessage]:
"""Return the messages that correspond to this action."""
@ -98,6 +103,11 @@ class AgentFinish(Serializable):
"""Return whether or not the class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema agent".split()
@property
def messages(self) -> Sequence[BaseMessage]:
"""Return the messages that correspond to this observation."""

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Literal
from typing import List, Literal
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field
@ -21,3 +21,8 @@ class Document(Serializable):
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema document".split()

@ -3,6 +3,7 @@ import json
import os
from typing import Any, Dict, List, Optional
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
from langchain_core.load.serializable import Serializable
DEFAULT_NAMESPACES = ["langchain", "langchain_core"]
@ -62,8 +63,21 @@ class Reviver:
if len(namespace) == 1 and namespace[0] == "langchain":
raise ValueError(f"Invalid namespace: {value}")
mod = importlib.import_module(".".join(namespace))
cls = getattr(mod, name)
# Get the importable path
key = tuple(namespace + [name])
if key not in SERIALIZABLE_MAPPING:
raise ValueError(
"Trying to deserialize something that cannot "
"be deserialized in current version of langchain-core: "
f"{key}"
)
import_path = SERIALIZABLE_MAPPING[key]
# Split into module and name
import_dir, import_obj = import_path[:-1], import_path[-1]
# Import module
mod = importlib.import_module(".".join(import_dir))
# Import class
cls = getattr(mod, import_obj)
# The class must be a subclass of Serializable.
if not issubclass(cls, Serializable):

@ -0,0 +1,478 @@
# Maps the namespace an object is *serialized* under (the key, a tuple of
# path segments ending in the class name) to the import path it should be
# *loaded* from (the value).  In each value tuple the last element is the
# class name and the preceding elements form the module path.
#
# Keys are frozen legacy "langchain.*" paths kept for backward
# compatibility with payloads produced before the package split; values
# point at the current home of each class ("langchain_core" where the
# class has moved, plain "langchain" where it still lives there).
# First value is the value that it is serialized as
# Second value is the path to load it from
SERIALIZABLE_MAPPING = {
    ("langchain", "schema", "messages", "AIMessage"): (
        "langchain_core",
        "messages",
        "ai",
        "AIMessage",
    ),
    ("langchain", "schema", "messages", "AIMessageChunk"): (
        "langchain_core",
        "messages",
        "ai",
        "AIMessageChunk",
    ),
    ("langchain", "schema", "messages", "BaseMessage"): (
        "langchain_core",
        "messages",
        "base",
        "BaseMessage",
    ),
    ("langchain", "schema", "messages", "BaseMessageChunk"): (
        "langchain_core",
        "messages",
        "base",
        "BaseMessageChunk",
    ),
    ("langchain", "schema", "messages", "ChatMessage"): (
        "langchain_core",
        "messages",
        "chat",
        "ChatMessage",
    ),
    ("langchain", "schema", "messages", "FunctionMessage"): (
        "langchain_core",
        "messages",
        "function",
        "FunctionMessage",
    ),
    ("langchain", "schema", "messages", "HumanMessage"): (
        "langchain_core",
        "messages",
        "human",
        "HumanMessage",
    ),
    ("langchain", "schema", "messages", "SystemMessage"): (
        "langchain_core",
        "messages",
        "system",
        "SystemMessage",
    ),
    ("langchain", "schema", "messages", "ToolMessage"): (
        "langchain_core",
        "messages",
        "tool",
        "ToolMessage",
    ),
    ("langchain", "schema", "agent", "AgentAction"): (
        "langchain_core",
        "agents",
        "AgentAction",
    ),
    ("langchain", "schema", "agent", "AgentFinish"): (
        "langchain_core",
        "agents",
        "AgentFinish",
    ),
    ("langchain", "schema", "prompt_template", "BasePromptTemplate"): (
        "langchain_core",
        "prompts",
        "base",
        "BasePromptTemplate",
    ),
    # LLMChain still lives in the langchain package (not langchain_core).
    ("langchain", "chains", "llm", "LLMChain"): (
        "langchain",
        "chains",
        "llm",
        "LLMChain",
    ),
    ("langchain", "prompts", "prompt", "PromptTemplate"): (
        "langchain_core",
        "prompts",
        "prompt",
        "PromptTemplate",
    ),
    ("langchain", "prompts", "chat", "MessagesPlaceholder"): (
        "langchain_core",
        "prompts",
        "chat",
        "MessagesPlaceholder",
    ),
    ("langchain", "llms", "openai", "OpenAI"): (
        "langchain",
        "llms",
        "openai",
        "OpenAI",
    ),
    ("langchain", "prompts", "chat", "ChatPromptTemplate"): (
        "langchain_core",
        "prompts",
        "chat",
        "ChatPromptTemplate",
    ),
    ("langchain", "prompts", "chat", "HumanMessagePromptTemplate"): (
        "langchain_core",
        "prompts",
        "chat",
        "HumanMessagePromptTemplate",
    ),
    ("langchain", "prompts", "chat", "SystemMessagePromptTemplate"): (
        "langchain_core",
        "prompts",
        "chat",
        "SystemMessagePromptTemplate",
    ),
    ("langchain", "schema", "agent", "AgentActionMessageLog"): (
        "langchain_core",
        "agents",
        "AgentActionMessageLog",
    ),
    ("langchain", "schema", "agent", "OpenAIToolAgentAction"): (
        "langchain",
        "agents",
        "output_parsers",
        "openai_tools",
        "OpenAIToolAgentAction",
    ),
    ("langchain", "prompts", "chat", "BaseMessagePromptTemplate"): (
        "langchain_core",
        "prompts",
        "chat",
        "BaseMessagePromptTemplate",
    ),
    ("langchain", "schema", "output", "ChatGeneration"): (
        "langchain_core",
        "outputs",
        "chat_generation",
        "ChatGeneration",
    ),
    ("langchain", "schema", "output", "Generation"): (
        "langchain_core",
        "outputs",
        "generation",
        "Generation",
    ),
    ("langchain", "schema", "document", "Document"): (
        "langchain_core",
        "documents",
        "base",
        "Document",
    ),
    ("langchain", "output_parsers", "fix", "OutputFixingParser"): (
        "langchain",
        "output_parsers",
        "fix",
        "OutputFixingParser",
    ),
    ("langchain", "prompts", "chat", "AIMessagePromptTemplate"): (
        "langchain_core",
        "prompts",
        "chat",
        "AIMessagePromptTemplate",
    ),
    ("langchain", "output_parsers", "regex", "RegexParser"): (
        "langchain",
        "output_parsers",
        "regex",
        "RegexParser",
    ),
    ("langchain", "schema", "runnable", "DynamicRunnable"): (
        "langchain_core",
        "runnables",
        "configurable",
        "DynamicRunnable",
    ),
    ("langchain", "schema", "prompt", "PromptValue"): (
        "langchain_core",
        "prompt_values",
        "PromptValue",
    ),
    ("langchain", "schema", "runnable", "RunnableBinding"): (
        "langchain_core",
        "runnables",
        "base",
        "RunnableBinding",
    ),
    ("langchain", "schema", "runnable", "RunnableBranch"): (
        "langchain_core",
        "runnables",
        "branch",
        "RunnableBranch",
    ),
    ("langchain", "schema", "runnable", "RunnableWithFallbacks"): (
        "langchain_core",
        "runnables",
        "fallbacks",
        "RunnableWithFallbacks",
    ),
    ("langchain", "schema", "output_parser", "StrOutputParser"): (
        "langchain_core",
        "output_parsers",
        "string",
        "StrOutputParser",
    ),
    ("langchain", "chat_models", "openai", "ChatOpenAI"): (
        "langchain",
        "chat_models",
        "openai",
        "ChatOpenAI",
    ),
    ("langchain", "output_parsers", "list", "CommaSeparatedListOutputParser"): (
        "langchain_core",
        "output_parsers",
        "list",
        "CommaSeparatedListOutputParser",
    ),
    ("langchain", "schema", "runnable", "RunnableParallel"): (
        "langchain_core",
        "runnables",
        "base",
        "RunnableParallel",
    ),
    # Model integrations (chat_models / llms) remain in langchain proper.
    ("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"): (
        "langchain",
        "chat_models",
        "azure_openai",
        "AzureChatOpenAI",
    ),
    ("langchain", "chat_models", "bedrock", "BedrockChat"): (
        "langchain",
        "chat_models",
        "bedrock",
        "BedrockChat",
    ),
    ("langchain", "chat_models", "anthropic", "ChatAnthropic"): (
        "langchain",
        "chat_models",
        "anthropic",
        "ChatAnthropic",
    ),
    ("langchain", "chat_models", "fireworks", "ChatFireworks"): (
        "langchain",
        "chat_models",
        "fireworks",
        "ChatFireworks",
    ),
    ("langchain", "chat_models", "google_palm", "ChatGooglePalm"): (
        "langchain",
        "chat_models",
        "google_palm",
        "ChatGooglePalm",
    ),
    ("langchain", "chat_models", "vertexai", "ChatVertexAI"): (
        "langchain",
        "chat_models",
        "vertexai",
        "ChatVertexAI",
    ),
    ("langchain", "schema", "output", "ChatGenerationChunk"): (
        "langchain_core",
        "outputs",
        "chat_generation",
        "ChatGenerationChunk",
    ),
    ("langchain", "schema", "messages", "ChatMessageChunk"): (
        "langchain_core",
        "messages",
        "chat",
        "ChatMessageChunk",
    ),
    ("langchain", "schema", "messages", "HumanMessageChunk"): (
        "langchain_core",
        "messages",
        "human",
        "HumanMessageChunk",
    ),
    ("langchain", "schema", "messages", "FunctionMessageChunk"): (
        "langchain_core",
        "messages",
        "function",
        "FunctionMessageChunk",
    ),
    ("langchain", "schema", "messages", "SystemMessageChunk"): (
        "langchain_core",
        "messages",
        "system",
        "SystemMessageChunk",
    ),
    ("langchain", "schema", "messages", "ToolMessageChunk"): (
        "langchain_core",
        "messages",
        "tool",
        "ToolMessageChunk",
    ),
    ("langchain", "schema", "output", "GenerationChunk"): (
        "langchain_core",
        "outputs",
        "generation",
        "GenerationChunk",
    ),
    ("langchain", "llms", "openai", "BaseOpenAI"): (
        "langchain",
        "llms",
        "openai",
        "BaseOpenAI",
    ),
    ("langchain", "llms", "bedrock", "Bedrock"): (
        "langchain",
        "llms",
        "bedrock",
        "Bedrock",
    ),
    ("langchain", "llms", "fireworks", "Fireworks"): (
        "langchain",
        "llms",
        "fireworks",
        "Fireworks",
    ),
    ("langchain", "llms", "google_palm", "GooglePalm"): (
        "langchain",
        "llms",
        "google_palm",
        "GooglePalm",
    ),
    ("langchain", "llms", "openai", "AzureOpenAI"): (
        "langchain",
        "llms",
        "openai",
        "AzureOpenAI",
    ),
    ("langchain", "llms", "replicate", "Replicate"): (
        "langchain",
        "llms",
        "replicate",
        "Replicate",
    ),
    ("langchain", "llms", "vertexai", "VertexAI"): (
        "langchain",
        "llms",
        "vertexai",
        "VertexAI",
    ),
    ("langchain", "output_parsers", "combining", "CombiningOutputParser"): (
        "langchain",
        "output_parsers",
        "combining",
        "CombiningOutputParser",
    ),
    ("langchain", "schema", "prompt_template", "BaseChatPromptTemplate"): (
        "langchain_core",
        "prompts",
        "chat",
        "BaseChatPromptTemplate",
    ),
    ("langchain", "prompts", "chat", "ChatMessagePromptTemplate"): (
        "langchain_core",
        "prompts",
        "chat",
        "ChatMessagePromptTemplate",
    ),
    ("langchain", "prompts", "few_shot_with_templates", "FewShotPromptWithTemplates"): (
        "langchain_core",
        "prompts",
        "few_shot_with_templates",
        "FewShotPromptWithTemplates",
    ),
    ("langchain", "prompts", "pipeline", "PipelinePromptTemplate"): (
        "langchain_core",
        "prompts",
        "pipeline",
        "PipelinePromptTemplate",
    ),
    ("langchain", "prompts", "base", "StringPromptTemplate"): (
        "langchain_core",
        "prompts",
        "string",
        "StringPromptTemplate",
    ),
    ("langchain", "prompts", "base", "StringPromptValue"): (
        "langchain_core",
        "prompt_values",
        "StringPromptValue",
    ),
    ("langchain", "prompts", "chat", "BaseStringMessagePromptTemplate"): (
        "langchain_core",
        "prompts",
        "chat",
        "BaseStringMessagePromptTemplate",
    ),
    ("langchain", "prompts", "chat", "ChatPromptValue"): (
        "langchain_core",
        "prompt_values",
        "ChatPromptValue",
    ),
    ("langchain", "prompts", "chat", "ChatPromptValueConcrete"): (
        "langchain_core",
        "prompt_values",
        "ChatPromptValueConcrete",
    ),
    ("langchain", "schema", "runnable", "HubRunnable"): (
        "langchain",
        "runnables",
        "hub",
        "HubRunnable",
    ),
    ("langchain", "schema", "runnable", "RunnableBindingBase"): (
        "langchain_core",
        "runnables",
        "base",
        "RunnableBindingBase",
    ),
    ("langchain", "schema", "runnable", "OpenAIFunctionsRouter"): (
        "langchain",
        "runnables",
        "openai_functions",
        "OpenAIFunctionsRouter",
    ),
    ("langchain", "schema", "runnable", "RouterRunnable"): (
        "langchain_core",
        "runnables",
        "router",
        "RouterRunnable",
    ),
    ("langchain", "schema", "runnable", "RunnablePassthrough"): (
        "langchain_core",
        "runnables",
        "passthrough",
        "RunnablePassthrough",
    ),
    ("langchain", "schema", "runnable", "RunnableSequence"): (
        "langchain_core",
        "runnables",
        "base",
        "RunnableSequence",
    ),
    ("langchain", "schema", "runnable", "RunnableEach"): (
        "langchain_core",
        "runnables",
        "base",
        "RunnableEach",
    ),
    ("langchain", "schema", "runnable", "RunnableEachBase"): (
        "langchain_core",
        "runnables",
        "base",
        "RunnableEachBase",
    ),
    ("langchain", "schema", "runnable", "RunnableConfigurableAlternatives"): (
        "langchain_core",
        "runnables",
        "configurable",
        "RunnableConfigurableAlternatives",
    ),
    ("langchain", "schema", "runnable", "RunnableConfigurableFields"): (
        "langchain_core",
        "runnables",
        "configurable",
        "RunnableConfigurableFields",
    ),
    ("langchain", "schema", "runnable", "RunnableWithMessageHistory"): (
        "langchain_core",
        "runnables",
        "history",
        "RunnableWithMessageHistory",
    ),
    ("langchain", "schema", "runnable", "RunnableAssign"): (
        "langchain_core",
        "runnables",
        "passthrough",
        "RunnableAssign",
    ),
    ("langchain", "schema", "runnable", "RunnableRetry"): (
        "langchain_core",
        "runnables",
        "retry",
        "RunnableRetry",
    ),
}

@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Any, List, Literal
from langchain_core.messages.base import (
BaseMessage,
@ -17,6 +17,11 @@ class AIMessage(BaseMessage):
type: Literal["ai"] = "ai"
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
AIMessage.update_forward_refs()
@ -29,6 +34,11 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
# non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:

@ -31,6 +31,11 @@ class BaseMessage(Serializable):
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain_core.prompts.chat import ChatPromptTemplate
@ -68,6 +73,11 @@ def merge_content(
class BaseMessageChunk(BaseMessage):
"""A Message chunk, which can be concatenated with other Message chunks."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:

@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Any, List, Literal
from langchain_core.messages.base import (
BaseMessage,
@ -15,6 +15,11 @@ class ChatMessage(BaseMessage):
type: Literal["chat"] = "chat"
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
ChatMessage.update_forward_refs()
@ -27,6 +32,11 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
# non-chunk variant.
type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ChatMessageChunk):
if self.role != other.role:

@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Any, List, Literal
from langchain_core.messages.base import (
BaseMessage,
@ -15,6 +15,11 @@ class FunctionMessage(BaseMessage):
type: Literal["function"] = "function"
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
FunctionMessage.update_forward_refs()
@ -27,6 +32,11 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
# non-chunk variant.
type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment]
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, FunctionMessageChunk):
if self.name != other.name:

@ -1,4 +1,4 @@
from typing import Literal
from typing import List, Literal
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
@ -13,6 +13,11 @@ class HumanMessage(BaseMessage):
type: Literal["human"] = "human"
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
HumanMessage.update_forward_refs()
@ -24,3 +29,8 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk):
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()

@ -1,4 +1,4 @@
from typing import Literal
from typing import List, Literal
from langchain_core.messages.base import BaseMessage, BaseMessageChunk
@ -10,6 +10,11 @@ class SystemMessage(BaseMessage):
type: Literal["system"] = "system"
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
SystemMessage.update_forward_refs()
@ -21,3 +26,8 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk):
# to make sure that the chunk variant can be discriminated from the
# non-chunk variant.
type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()

@ -1,4 +1,4 @@
from typing import Any, Literal
from typing import Any, List, Literal
from langchain_core.messages.base import (
BaseMessage,
@ -15,6 +15,11 @@ class ToolMessage(BaseMessage):
type: Literal["tool"] = "tool"
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
ToolMessage.update_forward_refs()
@ -27,6 +32,11 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
# non-chunk variant.
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema messages".split()
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ToolMessageChunk):
if self.tool_call_id != other.tool_call_id:

@ -26,6 +26,11 @@ class CommaSeparatedListOutputParser(ListOutputParser):
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain output_parsers list".split()
def get_format_instructions(self) -> str:
return (
"Your response should be a list of comma separated values, "

@ -1,3 +1,5 @@
from typing import List
from langchain_core.output_parsers.transform import BaseTransformOutputParser
@ -9,6 +11,11 @@ class StrOutputParser(BaseTransformOutputParser[str]):
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema output_parser".split()
@property
def _type(self) -> str:
"""Return the output parser type for serialization."""

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict, Literal
from typing import Any, Dict, List, Literal
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
@ -27,6 +27,11 @@ class ChatGeneration(Generation):
raise ValueError("Error while initializing ChatGeneration") from e
return values
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema output".split()
class ChatGenerationChunk(ChatGeneration):
"""A ChatGeneration chunk, which can be concatenated with other
@ -41,6 +46,11 @@ class ChatGenerationChunk(ChatGeneration):
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
"""Type is used exclusively for serialization purposes."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema output".split()
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):
generation_info = (

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict, Literal, Optional
from typing import Any, Dict, List, Literal, Optional
from langchain_core.load import Serializable
@ -24,10 +24,20 @@ class Generation(Serializable):
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema output".split()
class GenerationChunk(Generation):
"""A Generation chunk, which can be concatenated with other Generation chunks."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema output".split()
def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk):
generation_info = (

@ -24,6 +24,11 @@ class PromptValue(Serializable, ABC):
"""Return whether this class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema prompt".split()
@abstractmethod
def to_string(self) -> str:
"""Return prompt value as string."""
@ -40,6 +45,11 @@ class StringPromptValue(PromptValue):
"""Prompt text."""
type: Literal["StringPromptValue"] = "StringPromptValue"
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts base".split()
def to_string(self) -> str:
"""Return prompt as string."""
return self.text
@ -66,6 +76,11 @@ class ChatPromptValue(PromptValue):
"""Return prompt as a list of messages."""
return list(self.messages)
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts chat".split()
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
@ -74,3 +89,8 @@ class ChatPromptValueConcrete(ChatPromptValue):
messages: Sequence[AnyMessage]
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts chat".split()

@ -44,6 +44,11 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
default_factory=dict
)
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema prompt_template".split()
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""

@ -43,6 +43,11 @@ class BaseMessagePromptTemplate(Serializable, ABC):
"""Return whether or not the class is serializable."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts chat".split()
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs. Should return a list of BaseMessages.
@ -82,6 +87,11 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
variable_name: str
"""Name of variable to use as messages."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts chat".split()
def __init__(self, variable_name: str, **kwargs: Any):
return super().__init__(variable_name=variable_name, **kwargs)
@ -132,6 +142,11 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts chat".split()
@classmethod
def from_template(
cls: Type[MessagePromptTemplateT],
@ -221,6 +236,11 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str
"""Role of the message."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts chat".split()
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
@ -239,6 +259,11 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts chat".split()
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
@ -255,6 +280,11 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts chat".split()
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
@ -273,6 +303,11 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
This is a message that is not sent to the user.
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts chat".split()
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
@ -368,6 +403,11 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
validate_template: bool = False
"""Whether or not to try validating the template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts chat".split()
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.

@ -42,6 +42,11 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
validate_template: bool = False
"""Whether or not to try validating the template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts few_shot_with_templates".split()
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
"""Check that one and only one of examples/example_selector are provided."""

@ -28,6 +28,11 @@ class PipelinePromptTemplate(BasePromptTemplate):
pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
"""A list of tuples, consisting of a string (`name`) and a Prompt Template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts pipeline".split()
@root_validator(pre=True)
def get_input_variables(cls, values: Dict) -> Dict:
"""Get input variables."""

@ -54,6 +54,11 @@ class PromptTemplate(StringPromptTemplate):
"template_format": self.template_format,
}
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts prompt".split()
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""

@ -151,6 +151,11 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt that exposes the format method, returning a prompt."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain prompts base".split()
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs))

@ -1349,6 +1349,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
last: Runnable[Any, Output]
"""The last runnable in the sequence."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema runnable".split()
@property
def steps(self) -> List[Runnable[Any, Any]]:
"""All the runnables that make up the sequence in order."""
@ -1358,10 +1363,6 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
class Config:
arbitrary_types_allowed = True
@ -1939,7 +1940,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
class Config:
arbitrary_types_allowed = True
@ -2705,7 +2707,8 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def _invoke(
self,
@ -2746,6 +2749,11 @@ class RunnableEach(RunnableEachBase[Input, Output]):
with each element of the input sequence.
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema runnable".split()
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.bind(**kwargs))
@ -2910,7 +2918,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = merge_configs(self.config, *configs)
@ -3086,6 +3095,11 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
runnable_binding.invoke('Say "Parrot-MAGIC"') # Should return `Parrot`
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema runnable".split()
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
"""Bind additional kwargs to a Runnable, returning a new Runnable.

@ -132,8 +132,8 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def get_input_schema(
self, config: Optional[RunnableConfig] = None

@ -53,7 +53,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property
def InputType(self) -> Type[Input]:
@ -217,6 +218,11 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
fields: Dict[str, AnyConfigurableField]
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema runnable".split()
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
@ -318,6 +324,11 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
of the form <which.id>==<alternative_key>, eg. a key named "temperature" used by
the alternative named "gpt3" becomes "model==gpt3/temperature"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Return the namespace this object serializes under."""
    return "langchain schema runnable".split()
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
with _enums_for_spec_lock:

@ -125,7 +125,8 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property
def runnables(self) -> Iterator[Runnable[Input, Output]]:

@ -86,6 +86,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
output_messages_key: Optional[str] = None
history_messages_key: Optional[str] = None
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "schema", "runnable")
    return list(parts)
def __init__(
self,
runnable: Runnable[

@ -167,7 +167,8 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property
def InputType(self) -> Any:
@ -312,7 +313,8 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def get_input_schema(
self, config: Optional[RunnableConfig] = None

@ -114,6 +114,11 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
max_attempt_number: int = 3
"""The maximum number of attempts to retry the runnable."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "schema", "runnable")
    return list(parts)
@property
def _kwargs_retrying(self) -> Dict[str, Any]:
kwargs: Dict[str, Any] = dict()

@ -77,7 +77,8 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None

File diff suppressed because one or more lines are too long

@ -2029,7 +2029,7 @@ async def test_prompt_with_llm(
):
del op["value"]["id"]
assert stream_log == [
expected = [
RunLogPatch(
{
"op": "replace",
@ -2113,6 +2113,7 @@ async def test_prompt_with_llm(
{"op": "replace", "path": "/final_output", "value": "foo"},
),
]
assert stream_log == expected
@freeze_time("2023-01-01")

@ -105,6 +105,11 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
"""Return whether this model can be serialized by Langchain."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "chat_models", "anthropic")
    return list(parts)
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
"""Format a list of messages into a full prompt for the Anthropic model
Args:

@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import os
import warnings
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, List, Union
from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@ -94,6 +94,11 @@ class AzureChatOpenAI(ChatOpenAI):
infer if it is a base_url or azure_endpoint and update accordingly.
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "chat_models", "azure_openai")
    return list(parts)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""

@ -50,6 +50,11 @@ class BedrockChat(BaseChatModel, BedrockBase):
"""Return whether this model can be serialized by Langchain."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "chat_models", "bedrock")
    return list(parts)
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}

@ -101,6 +101,11 @@ class ChatFireworks(BaseChatModel):
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "chat_models", "fireworks")
    return list(parts)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key in environment."""

@ -256,6 +256,11 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
def is_lc_serializable(self) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "chat_models", "google_palm")
    return list(parts)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k."""

@ -160,6 +160,11 @@ class ChatOpenAI(BaseChatModel):
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "chat_models", "openai")
    return list(parts)
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}

@ -127,6 +127,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
def is_lc_serializable(self) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "chat_models", "vertexai")
    return list(parts)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""

@ -357,6 +357,11 @@ class Bedrock(LLM, BedrockBase):
"""Return whether this model can be serialized by Langchain."""
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "llms", "bedrock")
    return list(parts)
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}

@ -51,6 +51,11 @@ class Fireworks(BaseLLM):
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "llms", "fireworks")
    return list(parts)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key in environment."""

@ -75,6 +75,11 @@ class GooglePalm(BaseLLM, BaseModel):
def is_lc_serializable(self) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "llms", "google_palm")
    return list(parts)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists."""

@ -149,6 +149,11 @@ class BaseOpenAI(BaseLLM):
def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"}
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "llms", "openai")
    return list(parts)
@property
def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {}
@ -736,6 +741,11 @@ class OpenAI(BaseOpenAI):
openai = OpenAI(model_name="text-davinci-003")
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "llms", "openai")
    return list(parts)
@property
def _invocation_params(self) -> Dict[str, Any]:
return {**{"model": self.model_name}, **super()._invocation_params}
@ -794,6 +804,11 @@ class AzureOpenAI(BaseOpenAI):
infer if it is a base_url or azure_endpoint and update accordingly.
"""
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "llms", "openai")
    return list(parts)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""

@ -70,6 +70,11 @@ class Replicate(LLM):
def is_lc_serializable(cls) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "llms", "replicate")
    return list(parts)
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""

@ -224,6 +224,11 @@ class VertexAI(_VertexAICommon, BaseLLM):
def is_lc_serializable(self) -> bool:
return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
    """Serialization namespace: the id-path prefix used by langchain load/dump."""
    parts = ("langchain", "llms", "vertexai")
    return list(parts)
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment."""

@ -9,12 +9,12 @@ from langchain_core.pydantic_v1 import root_validator
class CombiningOutputParser(BaseOutputParser):
"""Combine multiple output parsers into one."""
parsers: List[BaseOutputParser]
@classmethod
def is_lc_serializable(cls) -> bool:
return True
parsers: List[BaseOutputParser]
@root_validator()
def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate the parsers."""

@ -97,7 +97,7 @@
"lc": 1,
"type": "constructor",
"id": [
"langchain_core",
"langchain",
"prompts",
"prompt",
"PromptTemplate"
@ -152,7 +152,7 @@
"lc": 1,
"type": "constructor",
"id": [
"langchain_core",
"langchain",
"prompts",
"chat",
"ChatPromptTemplate"
@ -166,7 +166,7 @@
"lc": 1,
"type": "constructor",
"id": [
"langchain_core",
"langchain",
"prompts",
"chat",
"HumanMessagePromptTemplate"
@ -176,7 +176,7 @@
"lc": 1,
"type": "constructor",
"id": [
"langchain_core",
"langchain",
"prompts",
"prompt",
"PromptTemplate"
@ -236,7 +236,7 @@
"lc": 1,
"type": "constructor",
"id": [
"langchain_core",
"langchain",
"prompts",
"prompt",
"PromptTemplate"

@ -0,0 +1,55 @@
import importlib
import pkgutil
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
def import_all_modules(package_name: str) -> dict:
    """Recursively import *package_name* and collect its serializable classes.

    Scans the names exposed on the package/module and, when the module is a
    package, walks every submodule. A class is collected when it defines
    ``is_lc_serializable`` and that classmethod returns ``True``.

    Args:
        package_name: Dotted name of the package or module to scan.

    Returns:
        Dict mapping ``tuple(cls.lc_id())`` to the tuple of the class's real
        import path (module parts followed by the class name).

    Raises:
        ValueError: If the same ``lc_id`` resolves to two different import
            paths, which would make the serialization mapping ambiguous.
    """
    package = importlib.import_module(package_name)
    classes: dict = {}
    # Collect serializable classes exposed directly on this module.
    for attribute_name in dir(package):
        attribute = getattr(package, attribute_name)
        if hasattr(attribute, "is_lc_serializable") and isinstance(attribute, type):
            if (
                isinstance(attribute.is_lc_serializable(), bool)  # type: ignore
                and attribute.is_lc_serializable()  # type: ignore
            ):
                key = tuple(attribute.lc_id())  # type: ignore
                value = tuple(attribute.__module__.split(".") + [attribute.__name__])
                if key in classes and classes[key] != value:
                    raise ValueError(
                        f"Ambiguous lc_id {key!r}: "
                        f"maps to both {classes[key]!r} and {value!r}"
                    )
                classes[key] = value
    # Only packages carry __path__; plain modules have no submodules to walk.
    if hasattr(package, "__path__"):
        for _loader, module_name, _is_pkg in pkgutil.walk_packages(
            package.__path__, package_name + "."
        ):
            # Skipped modules presumably fail on import (optional/removed
            # dependencies) — TODO confirm against the langchain codebase.
            if module_name not in (
                "langchain.chains.llm_bash",
                "langchain.chains.llm_symbolic_math",
                "langchain.tools.python",
                "langchain.vectorstores._pgvector_data_models",
            ):
                importlib.import_module(module_name)
                new_classes = import_all_modules(module_name)
                for k, v in new_classes.items():
                    if k in classes and classes[k] != v:
                        raise ValueError(
                            f"Ambiguous lc_id {k!r}: "
                            f"maps to both {classes[k]!r} and {v!r}"
                        )
                    classes[k] = v
    return classes
def test_serializable_mapping() -> None:
    """Check SERIALIZABLE_MAPPING agrees with the classes found in langchain."""
    discovered = import_all_modules("langchain")
    # The mapping must not list anything that was not discovered...
    assert set(SERIALIZABLE_MAPPING).difference(discovered) == set()
    # ...and every discovered class must be listed in the mapping.
    assert set(discovered).difference(SERIALIZABLE_MAPPING) == set()
    # Each lc_id must round-trip through its recorded import path.
    for lc_id, path in discovered.items():
        module_parts, class_name = path[:-1], path[-1]
        # Import module
        module = importlib.import_module(".".join(module_parts))
        # Import class
        klass = getattr(module, class_name)
        assert list(lc_id) == klass.lc_id()
Loading…
Cancel
Save