diff --git a/langchain/schema.py b/langchain/schema.py
index afb4fd4b..286af79e 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -40,24 +40,89 @@ class BaseMessage(BaseModel):
     content: str
     additional_kwargs: dict = Field(default_factory=dict)
 
+    @property
+    @abstractmethod
+    def type(self) -> str:
+        """Tag naming the message kind; written under "type" on serialization."""
+
 
 class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""
 
+    @property
+    def type(self) -> str:
+        """Serialization tag for human messages."""
+        return "human"
+
 
 class AIMessage(BaseMessage):
     """Type of message that is spoken by the AI."""
 
+    @property
+    def type(self) -> str:
+        """Serialization tag for AI messages."""
+        return "ai"
+
 
 class SystemMessage(BaseMessage):
     """Type of message that is a system message."""
 
+    @property
+    def type(self) -> str:
+        """Serialization tag for system messages."""
+        return "system"
+
 
 class ChatMessage(BaseMessage):
     """Type of message with arbitrary speaker."""
 
     role: str
 
+    @property
+    def type(self) -> str:
+        """Serialization tag for chat messages with an arbitrary role."""
+        return "chat"
+
+
+def _message_to_json(message: BaseMessage) -> dict:
+    """Convert one message into a ``{"type": ..., "data": ...}`` dict."""
+    tag = message.type
+    payload = message.dict()
+    return {"type": tag, "data": payload}
+
+
+def messages_to_json(messages: List[BaseMessage]) -> List[dict]:
+    """Convert messages to a list of dicts, one per message, in order.
+
+    Despite the name, the result is plain Python dicts, not a JSON string.
+    """
+    return list(map(_message_to_json, messages))
+
+
+def _message_from_json(message: dict) -> BaseMessage:
+    """Rebuild a message from a dict with "type" and "data" keys.
+
+    Raises:
+        ValueError: if "type" is not one of the known message tags.
+    """
+    _type = message["type"]
+    known = (
+        ("human", HumanMessage),
+        ("ai", AIMessage),
+        ("system", SystemMessage),
+        ("chat", ChatMessage),
+    )
+    for tag, message_cls in known:
+        if _type == tag:
+            # "data" is only read once the tag has matched.
+            return message_cls(**message["data"])
+    raise ValueError(f"Got unexpected type: {_type}")
+
+
+def messages_from_json(messages: List[dict]) -> List[BaseMessage]:
+    """Rebuild messages from a list of dicts, preserving order."""
+    return list(map(_message_from_json, messages))
+
 
 class ChatGeneration(Generation):
     """Output of a single generation."""