save messages (#1653)

@yakigac this is my alternative to
https://github.com/hwchase17/langchain/pull/1648 - thoughts?
tool-patch
Harrison Chase 1 year ago committed by GitHub
parent 63aa28e2a6
commit 362586fe8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -40,24 +40,75 @@ class BaseMessage(BaseModel):
content: str
additional_kwargs: dict = Field(default_factory=dict)
@property
@abstractmethod
def type(self) -> str:
"""Type of the message, used for serialization."""
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "human"
class AIMessage(BaseMessage):
"""Type of message that is spoken by the AI."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "ai"
class SystemMessage(BaseMessage):
"""Type of message that is a system message."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "system"
class ChatMessage(BaseMessage):
"""Type of message with arbitrary speaker."""
role: str
@property
def type(self) -> str:
"""Type of the message, used for serialization."""
return "chat"
def _message_to_json(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_json(messages: List[BaseMessage]) -> List[dict]:
return [_message_to_json(m) for m in messages]
def _message_from_json(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
elif _type == "ai":
return AIMessage(**message["data"])
elif _type == "system":
return SystemMessage(**message["data"])
elif _type == "chat":
return ChatMessage(**message["data"])
else:
raise ValueError(f"Got unexpected type: {_type}")
def messages_from_json(messages: List[dict]) -> List[BaseMessage]:
return [_message_from_json(m) for m in messages]
class ChatGeneration(Generation):
"""Output of a single generation."""

Loading…
Cancel
Save