From 362586fe8b5ea350684f96c6eae3c1be6247ce71 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Tue, 14 Mar 2023 18:15:55 -0700 Subject: [PATCH] save messages (#1653) @yakigac this is my alternative to https://github.com/hwchase17/langchain/pull/1648 - thoughts? --- langchain/schema.py | 51 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/langchain/schema.py b/langchain/schema.py index afb4fd4b..286af79e 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -40,24 +40,75 @@ class BaseMessage(BaseModel): content: str additional_kwargs: dict = Field(default_factory=dict) + @property + @abstractmethod + def type(self) -> str: + """Type of the message, used for serialization.""" + class HumanMessage(BaseMessage): """Type of message that is spoken by the human.""" + @property + def type(self) -> str: + """Type of the message, used for serialization.""" + return "human" + class AIMessage(BaseMessage): """Type of message that is spoken by the AI.""" + @property + def type(self) -> str: + """Type of the message, used for serialization.""" + return "ai" + class SystemMessage(BaseMessage): """Type of message that is a system message.""" + @property + def type(self) -> str: + """Type of the message, used for serialization.""" + return "system" + class ChatMessage(BaseMessage): """Type of message with arbitrary speaker.""" role: str + @property + def type(self) -> str: + """Type of the message, used for serialization.""" + return "chat" + + +def _message_to_json(message: BaseMessage) -> dict: + return {"type": message.type, "data": message.dict()} + + +def messages_to_json(messages: List[BaseMessage]) -> List[dict]: + return [_message_to_json(m) for m in messages] + + +def _message_from_json(message: dict) -> BaseMessage: + _type = message["type"] + if _type == "human": + return HumanMessage(**message["data"]) + elif _type == "ai": + return AIMessage(**message["data"]) + elif _type == "system": + return SystemMessage(**message["data"]) + elif _type == "chat": + return ChatMessage(**message["data"]) + else: + raise ValueError(f"Got unexpected type: {_type}") + + +def messages_from_json(messages: List[dict]) -> List[BaseMessage]: + return [_message_from_json(m) for m in messages] + class ChatGeneration(Generation): """Output of a single generation."""