manual mapping (#14422)

pull/14476/head
Harrison Chase 7 months ago committed by GitHub
parent c24f277b7c
commit f5befe3b89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import json import json
from typing import Any, Literal, Sequence, Union from typing import Any, List, Literal, Sequence, Union
from langchain_core.load.serializable import Serializable from langchain_core.load.serializable import Serializable
from langchain_core.messages import ( from langchain_core.messages import (
@ -40,6 +40,11 @@ class AgentAction(Serializable):
"""Return whether or not the class is serializable.""" """Return whether or not the class is serializable."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "agent"]
@property @property
def messages(self) -> Sequence[BaseMessage]: def messages(self) -> Sequence[BaseMessage]:
"""Return the messages that correspond to this action.""" """Return the messages that correspond to this action."""
@ -98,6 +103,11 @@ class AgentFinish(Serializable):
"""Return whether or not the class is serializable.""" """Return whether or not the class is serializable."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "agent"]
@property @property
def messages(self) -> Sequence[BaseMessage]: def messages(self) -> Sequence[BaseMessage]:
"""Return the messages that correspond to this observation.""" """Return the messages that correspond to this observation."""

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Literal from typing import List, Literal
from langchain_core.load.serializable import Serializable from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Field from langchain_core.pydantic_v1 import Field
@ -21,3 +21,8 @@ class Document(Serializable):
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable.""" """Return whether this class is serializable."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "document"]

@ -3,6 +3,7 @@ import json
import os import os
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
from langchain_core.load.serializable import Serializable from langchain_core.load.serializable import Serializable
DEFAULT_NAMESPACES = ["langchain", "langchain_core"] DEFAULT_NAMESPACES = ["langchain", "langchain_core"]
@ -62,8 +63,21 @@ class Reviver:
if len(namespace) == 1 and namespace[0] == "langchain": if len(namespace) == 1 and namespace[0] == "langchain":
raise ValueError(f"Invalid namespace: {value}") raise ValueError(f"Invalid namespace: {value}")
mod = importlib.import_module(".".join(namespace)) # Get the importable path
cls = getattr(mod, name) key = tuple(namespace + [name])
if key not in SERIALIZABLE_MAPPING:
raise ValueError(
"Trying to deserialize something that cannot "
"be deserialized in current version of langchain-core: "
f"{key}"
)
import_path = SERIALIZABLE_MAPPING[key]
# Split into module and name
import_dir, import_obj = import_path[:-1], import_path[-1]
# Import module
mod = importlib.import_module(".".join(import_dir))
# Import class
cls = getattr(mod, import_obj)
# The class must be a subclass of Serializable. # The class must be a subclass of Serializable.
if not issubclass(cls, Serializable): if not issubclass(cls, Serializable):

@ -0,0 +1,478 @@
# First value is the value that it is serialized as
# Second value is the path to load it from
SERIALIZABLE_MAPPING = {
("langchain", "schema", "messages", "AIMessage"): (
"langchain_core",
"messages",
"ai",
"AIMessage",
),
("langchain", "schema", "messages", "AIMessageChunk"): (
"langchain_core",
"messages",
"ai",
"AIMessageChunk",
),
("langchain", "schema", "messages", "BaseMessage"): (
"langchain_core",
"messages",
"base",
"BaseMessage",
),
("langchain", "schema", "messages", "BaseMessageChunk"): (
"langchain_core",
"messages",
"base",
"BaseMessageChunk",
),
("langchain", "schema", "messages", "ChatMessage"): (
"langchain_core",
"messages",
"chat",
"ChatMessage",
),
("langchain", "schema", "messages", "FunctionMessage"): (
"langchain_core",
"messages",
"function",
"FunctionMessage",
),
("langchain", "schema", "messages", "HumanMessage"): (
"langchain_core",
"messages",
"human",
"HumanMessage",
),
("langchain", "schema", "messages", "SystemMessage"): (
"langchain_core",
"messages",
"system",
"SystemMessage",
),
("langchain", "schema", "messages", "ToolMessage"): (
"langchain_core",
"messages",
"tool",
"ToolMessage",
),
("langchain", "schema", "agent", "AgentAction"): (
"langchain_core",
"agents",
"AgentAction",
),
("langchain", "schema", "agent", "AgentFinish"): (
"langchain_core",
"agents",
"AgentFinish",
),
("langchain", "schema", "prompt_template", "BasePromptTemplate"): (
"langchain_core",
"prompts",
"base",
"BasePromptTemplate",
),
("langchain", "chains", "llm", "LLMChain"): (
"langchain",
"chains",
"llm",
"LLMChain",
),
("langchain", "prompts", "prompt", "PromptTemplate"): (
"langchain_core",
"prompts",
"prompt",
"PromptTemplate",
),
("langchain", "prompts", "chat", "MessagesPlaceholder"): (
"langchain_core",
"prompts",
"chat",
"MessagesPlaceholder",
),
("langchain", "llms", "openai", "OpenAI"): (
"langchain",
"llms",
"openai",
"OpenAI",
),
("langchain", "prompts", "chat", "ChatPromptTemplate"): (
"langchain_core",
"prompts",
"chat",
"ChatPromptTemplate",
),
("langchain", "prompts", "chat", "HumanMessagePromptTemplate"): (
"langchain_core",
"prompts",
"chat",
"HumanMessagePromptTemplate",
),
("langchain", "prompts", "chat", "SystemMessagePromptTemplate"): (
"langchain_core",
"prompts",
"chat",
"SystemMessagePromptTemplate",
),
("langchain", "schema", "agent", "AgentActionMessageLog"): (
"langchain_core",
"agents",
"AgentActionMessageLog",
),
("langchain", "schema", "agent", "OpenAIToolAgentAction"): (
"langchain",
"agents",
"output_parsers",
"openai_tools",
"OpenAIToolAgentAction",
),
("langchain", "prompts", "chat", "BaseMessagePromptTemplate"): (
"langchain_core",
"prompts",
"chat",
"BaseMessagePromptTemplate",
),
("langchain", "schema", "output", "ChatGeneration"): (
"langchain_core",
"outputs",
"chat_generation",
"ChatGeneration",
),
("langchain", "schema", "output", "Generation"): (
"langchain_core",
"outputs",
"generation",
"Generation",
),
("langchain", "schema", "document", "Document"): (
"langchain_core",
"documents",
"base",
"Document",
),
("langchain", "output_parsers", "fix", "OutputFixingParser"): (
"langchain",
"output_parsers",
"fix",
"OutputFixingParser",
),
("langchain", "prompts", "chat", "AIMessagePromptTemplate"): (
"langchain_core",
"prompts",
"chat",
"AIMessagePromptTemplate",
),
("langchain", "output_parsers", "regex", "RegexParser"): (
"langchain",
"output_parsers",
"regex",
"RegexParser",
),
("langchain", "schema", "runnable", "DynamicRunnable"): (
"langchain_core",
"runnables",
"configurable",
"DynamicRunnable",
),
("langchain", "schema", "prompt", "PromptValue"): (
"langchain_core",
"prompt_values",
"PromptValue",
),
("langchain", "schema", "runnable", "RunnableBinding"): (
"langchain_core",
"runnables",
"base",
"RunnableBinding",
),
("langchain", "schema", "runnable", "RunnableBranch"): (
"langchain_core",
"runnables",
"branch",
"RunnableBranch",
),
("langchain", "schema", "runnable", "RunnableWithFallbacks"): (
"langchain_core",
"runnables",
"fallbacks",
"RunnableWithFallbacks",
),
("langchain", "schema", "output_parser", "StrOutputParser"): (
"langchain_core",
"output_parsers",
"string",
"StrOutputParser",
),
("langchain", "chat_models", "openai", "ChatOpenAI"): (
"langchain",
"chat_models",
"openai",
"ChatOpenAI",
),
("langchain", "output_parsers", "list", "CommaSeparatedListOutputParser"): (
"langchain_core",
"output_parsers",
"list",
"CommaSeparatedListOutputParser",
),
("langchain", "schema", "runnable", "RunnableParallel"): (
"langchain_core",
"runnables",
"base",
"RunnableParallel",
),
("langchain", "chat_models", "azure_openai", "AzureChatOpenAI"): (
"langchain",
"chat_models",
"azure_openai",
"AzureChatOpenAI",
),
("langchain", "chat_models", "bedrock", "BedrockChat"): (
"langchain",
"chat_models",
"bedrock",
"BedrockChat",
),
("langchain", "chat_models", "anthropic", "ChatAnthropic"): (
"langchain",
"chat_models",
"anthropic",
"ChatAnthropic",
),
("langchain", "chat_models", "fireworks", "ChatFireworks"): (
"langchain",
"chat_models",
"fireworks",
"ChatFireworks",
),
("langchain", "chat_models", "google_palm", "ChatGooglePalm"): (
"langchain",
"chat_models",
"google_palm",
"ChatGooglePalm",
),
("langchain", "chat_models", "vertexai", "ChatVertexAI"): (
"langchain",
"chat_models",
"vertexai",
"ChatVertexAI",
),
("langchain", "schema", "output", "ChatGenerationChunk"): (
"langchain_core",
"outputs",
"chat_generation",
"ChatGenerationChunk",
),
("langchain", "schema", "messages", "ChatMessageChunk"): (
"langchain_core",
"messages",
"chat",
"ChatMessageChunk",
),
("langchain", "schema", "messages", "HumanMessageChunk"): (
"langchain_core",
"messages",
"human",
"HumanMessageChunk",
),
("langchain", "schema", "messages", "FunctionMessageChunk"): (
"langchain_core",
"messages",
"function",
"FunctionMessageChunk",
),
("langchain", "schema", "messages", "SystemMessageChunk"): (
"langchain_core",
"messages",
"system",
"SystemMessageChunk",
),
("langchain", "schema", "messages", "ToolMessageChunk"): (
"langchain_core",
"messages",
"tool",
"ToolMessageChunk",
),
("langchain", "schema", "output", "GenerationChunk"): (
"langchain_core",
"outputs",
"generation",
"GenerationChunk",
),
("langchain", "llms", "openai", "BaseOpenAI"): (
"langchain",
"llms",
"openai",
"BaseOpenAI",
),
("langchain", "llms", "bedrock", "Bedrock"): (
"langchain",
"llms",
"bedrock",
"Bedrock",
),
("langchain", "llms", "fireworks", "Fireworks"): (
"langchain",
"llms",
"fireworks",
"Fireworks",
),
("langchain", "llms", "google_palm", "GooglePalm"): (
"langchain",
"llms",
"google_palm",
"GooglePalm",
),
("langchain", "llms", "openai", "AzureOpenAI"): (
"langchain",
"llms",
"openai",
"AzureOpenAI",
),
("langchain", "llms", "replicate", "Replicate"): (
"langchain",
"llms",
"replicate",
"Replicate",
),
("langchain", "llms", "vertexai", "VertexAI"): (
"langchain",
"llms",
"vertexai",
"VertexAI",
),
("langchain", "output_parsers", "combining", "CombiningOutputParser"): (
"langchain",
"output_parsers",
"combining",
"CombiningOutputParser",
),
("langchain", "schema", "prompt_template", "BaseChatPromptTemplate"): (
"langchain_core",
"prompts",
"chat",
"BaseChatPromptTemplate",
),
("langchain", "prompts", "chat", "ChatMessagePromptTemplate"): (
"langchain_core",
"prompts",
"chat",
"ChatMessagePromptTemplate",
),
("langchain", "prompts", "few_shot_with_templates", "FewShotPromptWithTemplates"): (
"langchain_core",
"prompts",
"few_shot_with_templates",
"FewShotPromptWithTemplates",
),
("langchain", "prompts", "pipeline", "PipelinePromptTemplate"): (
"langchain_core",
"prompts",
"pipeline",
"PipelinePromptTemplate",
),
("langchain", "prompts", "base", "StringPromptTemplate"): (
"langchain_core",
"prompts",
"string",
"StringPromptTemplate",
),
("langchain", "prompts", "base", "StringPromptValue"): (
"langchain_core",
"prompt_values",
"StringPromptValue",
),
("langchain", "prompts", "chat", "BaseStringMessagePromptTemplate"): (
"langchain_core",
"prompts",
"chat",
"BaseStringMessagePromptTemplate",
),
("langchain", "prompts", "chat", "ChatPromptValue"): (
"langchain_core",
"prompt_values",
"ChatPromptValue",
),
("langchain", "prompts", "chat", "ChatPromptValueConcrete"): (
"langchain_core",
"prompt_values",
"ChatPromptValueConcrete",
),
("langchain", "schema", "runnable", "HubRunnable"): (
"langchain",
"runnables",
"hub",
"HubRunnable",
),
("langchain", "schema", "runnable", "RunnableBindingBase"): (
"langchain_core",
"runnables",
"base",
"RunnableBindingBase",
),
("langchain", "schema", "runnable", "OpenAIFunctionsRouter"): (
"langchain",
"runnables",
"openai_functions",
"OpenAIFunctionsRouter",
),
("langchain", "schema", "runnable", "RouterRunnable"): (
"langchain_core",
"runnables",
"router",
"RouterRunnable",
),
("langchain", "schema", "runnable", "RunnablePassthrough"): (
"langchain_core",
"runnables",
"passthrough",
"RunnablePassthrough",
),
("langchain", "schema", "runnable", "RunnableSequence"): (
"langchain_core",
"runnables",
"base",
"RunnableSequence",
),
("langchain", "schema", "runnable", "RunnableEach"): (
"langchain_core",
"runnables",
"base",
"RunnableEach",
),
("langchain", "schema", "runnable", "RunnableEachBase"): (
"langchain_core",
"runnables",
"base",
"RunnableEachBase",
),
("langchain", "schema", "runnable", "RunnableConfigurableAlternatives"): (
"langchain_core",
"runnables",
"configurable",
"RunnableConfigurableAlternatives",
),
("langchain", "schema", "runnable", "RunnableConfigurableFields"): (
"langchain_core",
"runnables",
"configurable",
"RunnableConfigurableFields",
),
("langchain", "schema", "runnable", "RunnableWithMessageHistory"): (
"langchain_core",
"runnables",
"history",
"RunnableWithMessageHistory",
),
("langchain", "schema", "runnable", "RunnableAssign"): (
"langchain_core",
"runnables",
"passthrough",
"RunnableAssign",
),
("langchain", "schema", "runnable", "RunnableRetry"): (
"langchain_core",
"runnables",
"retry",
"RunnableRetry",
),
}

@ -1,4 +1,4 @@
from typing import Any, Literal from typing import Any, List, Literal
from langchain_core.messages.base import ( from langchain_core.messages.base import (
BaseMessage, BaseMessage,
@ -17,6 +17,11 @@ class AIMessage(BaseMessage):
type: Literal["ai"] = "ai" type: Literal["ai"] = "ai"
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
AIMessage.update_forward_refs() AIMessage.update_forward_refs()
@ -29,6 +34,11 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
# non-chunk variant. # non-chunk variant.
type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501 type: Literal["AIMessageChunk"] = "AIMessageChunk" # type: ignore[assignment] # noqa: E501
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk): if isinstance(other, AIMessageChunk):
if self.example != other.example: if self.example != other.example:

@ -31,6 +31,11 @@ class BaseMessage(Serializable):
"""Return whether this class is serializable.""" """Return whether this class is serializable."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> ChatPromptTemplate: def __add__(self, other: Any) -> ChatPromptTemplate:
from langchain_core.prompts.chat import ChatPromptTemplate from langchain_core.prompts.chat import ChatPromptTemplate
@ -68,6 +73,11 @@ def merge_content(
class BaseMessageChunk(BaseMessage): class BaseMessageChunk(BaseMessage):
"""A Message chunk, which can be concatenated with other Message chunks.""" """A Message chunk, which can be concatenated with other Message chunks."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
def _merge_kwargs_dict( def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any] self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]: ) -> Dict[str, Any]:

@ -1,4 +1,4 @@
from typing import Any, Literal from typing import Any, List, Literal
from langchain_core.messages.base import ( from langchain_core.messages.base import (
BaseMessage, BaseMessage,
@ -15,6 +15,11 @@ class ChatMessage(BaseMessage):
type: Literal["chat"] = "chat" type: Literal["chat"] = "chat"
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
ChatMessage.update_forward_refs() ChatMessage.update_forward_refs()
@ -27,6 +32,11 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
# non-chunk variant. # non-chunk variant.
type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore type: Literal["ChatMessageChunk"] = "ChatMessageChunk" # type: ignore
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ChatMessageChunk): if isinstance(other, ChatMessageChunk):
if self.role != other.role: if self.role != other.role:

@ -1,4 +1,4 @@
from typing import Any, Literal from typing import Any, List, Literal
from langchain_core.messages.base import ( from langchain_core.messages.base import (
BaseMessage, BaseMessage,
@ -15,6 +15,11 @@ class FunctionMessage(BaseMessage):
type: Literal["function"] = "function" type: Literal["function"] = "function"
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
FunctionMessage.update_forward_refs() FunctionMessage.update_forward_refs()
@ -27,6 +32,11 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
# non-chunk variant. # non-chunk variant.
type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment] type: Literal["FunctionMessageChunk"] = "FunctionMessageChunk" # type: ignore[assignment]
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, FunctionMessageChunk): if isinstance(other, FunctionMessageChunk):
if self.name != other.name: if self.name != other.name:

@ -1,4 +1,4 @@
from typing import Literal from typing import List, Literal
from langchain_core.messages.base import BaseMessage, BaseMessageChunk from langchain_core.messages.base import BaseMessage, BaseMessageChunk
@ -13,6 +13,11 @@ class HumanMessage(BaseMessage):
type: Literal["human"] = "human" type: Literal["human"] = "human"
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
HumanMessage.update_forward_refs() HumanMessage.update_forward_refs()
@ -24,3 +29,8 @@ class HumanMessageChunk(HumanMessage, BaseMessageChunk):
# to make sure that the chunk variant can be discriminated from the # to make sure that the chunk variant can be discriminated from the
# non-chunk variant. # non-chunk variant.
type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501 type: Literal["HumanMessageChunk"] = "HumanMessageChunk" # type: ignore[assignment] # noqa: E501
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]

@ -1,4 +1,4 @@
from typing import Literal from typing import List, Literal
from langchain_core.messages.base import BaseMessage, BaseMessageChunk from langchain_core.messages.base import BaseMessage, BaseMessageChunk
@ -10,6 +10,11 @@ class SystemMessage(BaseMessage):
type: Literal["system"] = "system" type: Literal["system"] = "system"
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
SystemMessage.update_forward_refs() SystemMessage.update_forward_refs()
@ -21,3 +26,8 @@ class SystemMessageChunk(SystemMessage, BaseMessageChunk):
# to make sure that the chunk variant can be discriminated from the # to make sure that the chunk variant can be discriminated from the
# non-chunk variant. # non-chunk variant.
type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501 type: Literal["SystemMessageChunk"] = "SystemMessageChunk" # type: ignore[assignment] # noqa: E501
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]

@ -1,4 +1,4 @@
from typing import Any, Literal from typing import Any, List, Literal
from langchain_core.messages.base import ( from langchain_core.messages.base import (
BaseMessage, BaseMessage,
@ -15,6 +15,11 @@ class ToolMessage(BaseMessage):
type: Literal["tool"] = "tool" type: Literal["tool"] = "tool"
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
ToolMessage.update_forward_refs() ToolMessage.update_forward_refs()
@ -27,6 +32,11 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
# non-chunk variant. # non-chunk variant.
type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment] type: Literal["ToolMessageChunk"] = "ToolMessageChunk" # type: ignore[assignment]
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "messages"]
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, ToolMessageChunk): if isinstance(other, ToolMessageChunk):
if self.tool_call_id != other.tool_call_id: if self.tool_call_id != other.tool_call_id:

@ -26,6 +26,11 @@ class CommaSeparatedListOutputParser(ListOutputParser):
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "output_parsers", "list"]
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
return ( return (
"Your response should be a list of comma separated values, " "Your response should be a list of comma separated values, "

@ -1,3 +1,5 @@
from typing import List
from langchain_core.output_parsers.transform import BaseTransformOutputParser from langchain_core.output_parsers.transform import BaseTransformOutputParser
@ -9,6 +11,11 @@ class StrOutputParser(BaseTransformOutputParser[str]):
"""Return whether this class is serializable.""" """Return whether this class is serializable."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output_parser"]
@property @property
def _type(self) -> str: def _type(self) -> str:
"""Return the output parser type for serialization.""" """Return the output parser type for serialization."""

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Literal from typing import Any, Dict, List, Literal
from langchain_core.messages import BaseMessage, BaseMessageChunk from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation from langchain_core.outputs.generation import Generation
@ -27,6 +27,11 @@ class ChatGeneration(Generation):
raise ValueError("Error while initializing ChatGeneration") from e raise ValueError("Error while initializing ChatGeneration") from e
return values return values
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output"]
class ChatGenerationChunk(ChatGeneration): class ChatGenerationChunk(ChatGeneration):
"""A ChatGeneration chunk, which can be concatenated with other """A ChatGeneration chunk, which can be concatenated with other
@ -41,6 +46,11 @@ class ChatGenerationChunk(ChatGeneration):
type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501 type: Literal["ChatGenerationChunk"] = "ChatGenerationChunk" # type: ignore[assignment] # noqa: E501
"""Type is used exclusively for serialization purposes.""" """Type is used exclusively for serialization purposes."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output"]
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk: def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk): if isinstance(other, ChatGenerationChunk):
generation_info = ( generation_info = (

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, Literal, Optional from typing import Any, Dict, List, Literal, Optional
from langchain_core.load import Serializable from langchain_core.load import Serializable
@ -24,10 +24,20 @@ class Generation(Serializable):
"""Return whether this class is serializable.""" """Return whether this class is serializable."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output"]
class GenerationChunk(Generation): class GenerationChunk(Generation):
"""A Generation chunk, which can be concatenated with other Generation chunks.""" """A Generation chunk, which can be concatenated with other Generation chunks."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output"]
def __add__(self, other: GenerationChunk) -> GenerationChunk: def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk): if isinstance(other, GenerationChunk):
generation_info = ( generation_info = (

@ -24,6 +24,11 @@ class PromptValue(Serializable, ABC):
"""Return whether this class is serializable.""" """Return whether this class is serializable."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "prompt"]
@abstractmethod @abstractmethod
def to_string(self) -> str: def to_string(self) -> str:
"""Return prompt value as string.""" """Return prompt value as string."""
@ -40,6 +45,11 @@ class StringPromptValue(PromptValue):
"""Prompt text.""" """Prompt text."""
type: Literal["StringPromptValue"] = "StringPromptValue" type: Literal["StringPromptValue"] = "StringPromptValue"
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "base"]
def to_string(self) -> str: def to_string(self) -> str:
"""Return prompt as string.""" """Return prompt as string."""
return self.text return self.text
@ -66,6 +76,11 @@ class ChatPromptValue(PromptValue):
"""Return prompt as a list of messages.""" """Return prompt as a list of messages."""
return list(self.messages) return list(self.messages)
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
class ChatPromptValueConcrete(ChatPromptValue): class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts. """Chat prompt value which explicitly lists out the message types it accepts.
@ -74,3 +89,8 @@ class ChatPromptValueConcrete(ChatPromptValue):
messages: Sequence[AnyMessage] messages: Sequence[AnyMessage]
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete" type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]

@ -44,6 +44,11 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
default_factory=dict default_factory=dict
) )
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "prompt_template"]
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable.""" """Return whether this class is serializable."""

@ -43,6 +43,11 @@ class BaseMessagePromptTemplate(Serializable, ABC):
"""Return whether or not the class is serializable.""" """Return whether or not the class is serializable."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@abstractmethod @abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]: def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs. Should return a list of BaseMessages. """Format messages from kwargs. Should return a list of BaseMessages.
@ -82,6 +87,11 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
variable_name: str variable_name: str
"""Name of variable to use as messages.""" """Name of variable to use as messages."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def __init__(self, variable_name: str, **kwargs: Any): def __init__(self, variable_name: str, **kwargs: Any):
return super().__init__(variable_name=variable_name, **kwargs) return super().__init__(variable_name=variable_name, **kwargs)
@ -132,6 +142,11 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
additional_kwargs: dict = Field(default_factory=dict) additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template.""" """Additional keyword arguments to pass to the prompt template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
@classmethod @classmethod
def from_template( def from_template(
cls: Type[MessagePromptTemplateT], cls: Type[MessagePromptTemplateT],
@ -221,6 +236,11 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str role: str
"""Role of the message.""" """Role of the message."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template. """Format the prompt template.
@ -239,6 +259,11 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate): class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user.""" """Human message prompt template. This is a message sent from the user."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template. """Format the prompt template.
@ -255,6 +280,11 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI.""" """AI message prompt template. This is a message sent from the AI."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template. """Format the prompt template.
@ -273,6 +303,11 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
This is a message that is not sent to the user. This is a message that is not sent to the user.
""" """
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template. """Format the prompt template.
@ -368,6 +403,11 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
validate_template: bool = False validate_template: bool = False
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"]
def __add__(self, other: Any) -> ChatPromptTemplate: def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates. """Combine two prompt templates.

@ -42,6 +42,11 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
validate_template: bool = False validate_template: bool = False
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "few_shot_with_templates"]
@root_validator(pre=True) @root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict: def check_examples_and_selector(cls, values: Dict) -> Dict:
"""Check that one and only one of examples/example_selector are provided.""" """Check that one and only one of examples/example_selector are provided."""

@ -28,6 +28,11 @@ class PipelinePromptTemplate(BasePromptTemplate):
pipeline_prompts: List[Tuple[str, BasePromptTemplate]] pipeline_prompts: List[Tuple[str, BasePromptTemplate]]
"""A list of tuples, consisting of a string (`name`) and a Prompt Template.""" """A list of tuples, consisting of a string (`name`) and a Prompt Template."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "pipeline"]
@root_validator(pre=True) @root_validator(pre=True)
def get_input_variables(cls, values: Dict) -> Dict: def get_input_variables(cls, values: Dict) -> Dict:
"""Get input variables.""" """Get input variables."""

@ -54,6 +54,11 @@ class PromptTemplate(StringPromptTemplate):
"template_format": self.template_format, "template_format": self.template_format,
} }
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "prompt"]
input_variables: List[str] input_variables: List[str]
"""A list of the names of the variables the prompt template expects.""" """A list of the names of the variables the prompt template expects."""

@ -151,6 +151,11 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
class StringPromptTemplate(BasePromptTemplate, ABC): class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt that exposes the format method, returning a prompt.""" """String prompt that exposes the format method, returning a prompt."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "prompts", "base"]
def format_prompt(self, **kwargs: Any) -> PromptValue: def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages.""" """Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs)) return StringPromptValue(text=self.format(**kwargs))

@ -1349,6 +1349,11 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
last: Runnable[Any, Output] last: Runnable[Any, Output]
"""The last runnable in the sequence.""" """The last runnable in the sequence."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property @property
def steps(self) -> List[Runnable[Any, Any]]: def steps(self) -> List[Runnable[Any, Any]]:
"""All the runnables that make up the sequence in order.""" """All the runnables that make up the sequence in order."""
@ -1358,10 +1363,6 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -1939,7 +1940,8 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True
@ -2705,7 +2707,8 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def _invoke( def _invoke(
self, self,
@ -2746,6 +2749,11 @@ class RunnableEach(RunnableEachBase[Input, Output]):
with each element of the input sequence. with each element of the input sequence.
""" """
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.bind(**kwargs)) return RunnableEach(bound=self.bound.bind(**kwargs))
@ -2910,7 +2918,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig: def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
config = merge_configs(self.config, *configs) config = merge_configs(self.config, *configs)
@ -3086,6 +3095,11 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
runnable_binding.invoke('Say "Parrot-MAGIC"') # Should return `Parrot` runnable_binding.invoke('Say "Parrot-MAGIC"') # Should return `Parrot`
""" """
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def bind(self, **kwargs: Any) -> Runnable[Input, Output]: def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
"""Bind additional kwargs to a Runnable, returning a new Runnable. """Bind additional kwargs to a Runnable, returning a new Runnable.

@ -132,8 +132,8 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch.""" """Get the namespace of the langchain object."""
return cls.__module__.split(".")[:-1] return ["langchain", "schema", "runnable"]
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None

@ -53,7 +53,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property @property
def InputType(self) -> Type[Input]: def InputType(self) -> Type[Input]:
@ -217,6 +218,11 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
fields: Dict[str, AnyConfigurableField] fields: Dict[str, AnyConfigurableField]
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs( return get_unique_config_specs(
@ -318,6 +324,11 @@ class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
of the form <which.id>==<alternative_key>, eg. a key named "temperature" used by of the form <which.id>==<alternative_key>, eg. a key named "temperature" used by
the alternative named "gpt3" becomes "model==gpt3/temperature".""" the alternative named "gpt3" becomes "model==gpt3/temperature"."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property @property
def config_specs(self) -> List[ConfigurableFieldSpec]: def config_specs(self) -> List[ConfigurableFieldSpec]:
with _enums_for_spec_lock: with _enums_for_spec_lock:

@ -125,7 +125,8 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property @property
def runnables(self) -> Iterator[Runnable[Input, Output]]: def runnables(self) -> Iterator[Runnable[Input, Output]]:

@ -86,6 +86,11 @@ class RunnableWithMessageHistory(RunnableBindingBase):
output_messages_key: Optional[str] = None output_messages_key: Optional[str] = None
history_messages_key: Optional[str] = None history_messages_key: Optional[str] = None
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def __init__( def __init__(
self, self,
runnable: Runnable[ runnable: Runnable[

@ -167,7 +167,8 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property @property
def InputType(self) -> Any: def InputType(self) -> Any:
@ -312,7 +313,8 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def get_input_schema( def get_input_schema(
self, config: Optional[RunnableConfig] = None self, config: Optional[RunnableConfig] = None

@ -114,6 +114,11 @@ class RunnableRetry(RunnableBindingBase[Input, Output]):
max_attempt_number: int = 3 max_attempt_number: int = 3
"""The maximum number of attempts to retry the runnable.""" """The maximum number of attempts to retry the runnable."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
@property @property
def _kwargs_retrying(self) -> Dict[str, Any]: def _kwargs_retrying(self) -> Dict[str, Any]:
kwargs: Dict[str, Any] = dict() kwargs: Dict[str, Any] = dict()

@ -77,7 +77,8 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1] """Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def invoke( def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None self, input: RouterInput, config: Optional[RunnableConfig] = None

File diff suppressed because one or more lines are too long

@ -2029,7 +2029,7 @@ async def test_prompt_with_llm(
): ):
del op["value"]["id"] del op["value"]["id"]
assert stream_log == [ expected = [
RunLogPatch( RunLogPatch(
{ {
"op": "replace", "op": "replace",
@ -2113,6 +2113,7 @@ async def test_prompt_with_llm(
{"op": "replace", "path": "/final_output", "value": "foo"}, {"op": "replace", "path": "/final_output", "value": "foo"},
), ),
] ]
assert stream_log == expected
@freeze_time("2023-01-01") @freeze_time("2023-01-01")

@ -105,6 +105,11 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
"""Return whether this model can be serialized by Langchain.""" """Return whether this model can be serialized by Langchain."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "anthropic"]
def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str: def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
"""Format a list of messages into a full prompt for the Anthropic model """Format a list of messages into a full prompt for the Anthropic model
Args: Args:

@ -4,7 +4,7 @@ from __future__ import annotations
import logging import logging
import os import os
import warnings import warnings
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, List, Union
from langchain_core.outputs import ChatResult from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@ -94,6 +94,11 @@ class AzureChatOpenAI(ChatOpenAI):
infer if it is a base_url or azure_endpoint and update accordingly. infer if it is a base_url or azure_endpoint and update accordingly.
""" """
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "azure_openai"]
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""

@ -50,6 +50,11 @@ class BedrockChat(BaseChatModel, BedrockBase):
"""Return whether this model can be serialized by Langchain.""" """Return whether this model can be serialized by Langchain."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "bedrock"]
@property @property
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {} attributes: Dict[str, Any] = {}

@ -101,6 +101,11 @@ class ChatFireworks(BaseChatModel):
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "fireworks"]
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key in environment.""" """Validate that api key in environment."""

@ -256,6 +256,11 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
def is_lc_serializable(self) -> bool: def is_lc_serializable(self) -> bool:
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "google_palm"]
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists, temperature, top_p, and top_k.""" """Validate api key, python package exists, temperature, top_p, and top_k."""

@ -160,6 +160,11 @@ class ChatOpenAI(BaseChatModel):
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"} return {"openai_api_key": "OPENAI_API_KEY"}
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "openai"]
@property @property
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {} attributes: Dict[str, Any] = {}

@ -127,6 +127,11 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
def is_lc_serializable(self) -> bool: def is_lc_serializable(self) -> bool:
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "vertexai"]
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment.""" """Validate that the python package exists in environment."""

@ -357,6 +357,11 @@ class Bedrock(LLM, BedrockBase):
"""Return whether this model can be serialized by Langchain.""" """Return whether this model can be serialized by Langchain."""
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "bedrock"]
@property @property
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {} attributes: Dict[str, Any] = {}

@ -51,6 +51,11 @@ class Fireworks(BaseLLM):
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "fireworks"]
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key in environment.""" """Validate that api key in environment."""

@ -75,6 +75,11 @@ class GooglePalm(BaseLLM, BaseModel):
def is_lc_serializable(self) -> bool: def is_lc_serializable(self) -> bool:
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "google_palm"]
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate api key, python package exists.""" """Validate api key, python package exists."""

@ -149,6 +149,11 @@ class BaseOpenAI(BaseLLM):
def lc_secrets(self) -> Dict[str, str]: def lc_secrets(self) -> Dict[str, str]:
return {"openai_api_key": "OPENAI_API_KEY"} return {"openai_api_key": "OPENAI_API_KEY"}
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "openai"]
@property @property
def lc_attributes(self) -> Dict[str, Any]: def lc_attributes(self) -> Dict[str, Any]:
attributes: Dict[str, Any] = {} attributes: Dict[str, Any] = {}
@ -736,6 +741,11 @@ class OpenAI(BaseOpenAI):
openai = OpenAI(model_name="text-davinci-003") openai = OpenAI(model_name="text-davinci-003")
""" """
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "openai"]
@property @property
def _invocation_params(self) -> Dict[str, Any]: def _invocation_params(self) -> Dict[str, Any]:
return {**{"model": self.model_name}, **super()._invocation_params} return {**{"model": self.model_name}, **super()._invocation_params}
@ -794,6 +804,11 @@ class AzureOpenAI(BaseOpenAI):
infer if it is a base_url or azure_endpoint and update accordingly. infer if it is a base_url or azure_endpoint and update accordingly.
""" """
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "openai"]
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""

@ -70,6 +70,11 @@ class Replicate(LLM):
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "replicate"]
@root_validator(pre=True) @root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""

@ -224,6 +224,11 @@ class VertexAI(_VertexAICommon, BaseLLM):
def is_lc_serializable(self) -> bool: def is_lc_serializable(self) -> bool:
return True return True
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "llms", "vertexai"]
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the python package exists in environment.""" """Validate that the python package exists in environment."""

@ -9,12 +9,12 @@ from langchain_core.pydantic_v1 import root_validator
class CombiningOutputParser(BaseOutputParser): class CombiningOutputParser(BaseOutputParser):
"""Combine multiple output parsers into one.""" """Combine multiple output parsers into one."""
parsers: List[BaseOutputParser]
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:
return True return True
parsers: List[BaseOutputParser]
@root_validator() @root_validator()
def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]: def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate the parsers.""" """Validate the parsers."""

@ -97,7 +97,7 @@
"lc": 1, "lc": 1,
"type": "constructor", "type": "constructor",
"id": [ "id": [
"langchain_core", "langchain",
"prompts", "prompts",
"prompt", "prompt",
"PromptTemplate" "PromptTemplate"
@ -152,7 +152,7 @@
"lc": 1, "lc": 1,
"type": "constructor", "type": "constructor",
"id": [ "id": [
"langchain_core", "langchain",
"prompts", "prompts",
"chat", "chat",
"ChatPromptTemplate" "ChatPromptTemplate"
@ -166,7 +166,7 @@
"lc": 1, "lc": 1,
"type": "constructor", "type": "constructor",
"id": [ "id": [
"langchain_core", "langchain",
"prompts", "prompts",
"chat", "chat",
"HumanMessagePromptTemplate" "HumanMessagePromptTemplate"
@ -176,7 +176,7 @@
"lc": 1, "lc": 1,
"type": "constructor", "type": "constructor",
"id": [ "id": [
"langchain_core", "langchain",
"prompts", "prompts",
"prompt", "prompt",
"PromptTemplate" "PromptTemplate"
@ -236,7 +236,7 @@
"lc": 1, "lc": 1,
"type": "constructor", "type": "constructor",
"id": [ "id": [
"langchain_core", "langchain",
"prompts", "prompts",
"prompt", "prompt",
"PromptTemplate" "PromptTemplate"

@ -0,0 +1,55 @@
import importlib
import pkgutil
from langchain_core.load.mapping import SERIALIZABLE_MAPPING
def import_all_modules(package_name: str) -> dict:
package = importlib.import_module(package_name)
classes: dict = {}
for attribute_name in dir(package):
attribute = getattr(package, attribute_name)
if hasattr(attribute, "is_lc_serializable") and isinstance(attribute, type):
if (
isinstance(attribute.is_lc_serializable(), bool) # type: ignore
and attribute.is_lc_serializable() # type: ignore
):
key = tuple(attribute.lc_id()) # type: ignore
value = tuple(attribute.__module__.split(".") + [attribute.__name__])
if key in classes and classes[key] != value:
raise ValueError
classes[key] = value
if hasattr(package, "__path__"):
for loader, module_name, is_pkg in pkgutil.walk_packages(
package.__path__, package_name + "."
):
if module_name not in (
"langchain.chains.llm_bash",
"langchain.chains.llm_symbolic_math",
"langchain.tools.python",
"langchain.vectorstores._pgvector_data_models",
):
importlib.import_module(module_name)
new_classes = import_all_modules(module_name)
for k, v in new_classes.items():
if k in classes and classes[k] != v:
raise ValueError
classes[k] = v
return classes
def test_serializable_mapping() -> None:
serializable_modules = import_all_modules("langchain")
missing = set(SERIALIZABLE_MAPPING).difference(serializable_modules)
assert missing == set()
extra = set(serializable_modules).difference(SERIALIZABLE_MAPPING)
assert extra == set()
for k, import_path in serializable_modules.items():
import_dir, import_obj = import_path[:-1], import_path[-1]
# Import module
mod = importlib.import_module(".".join(import_dir))
# Import class
cls = getattr(mod, import_obj)
assert list(k) == cls.lc_id()
Loading…
Cancel
Save