forked from Archives/langchain
save messages (#1653)
@yakigac this is my alternative to https://github.com/hwchase17/langchain/pull/1648 - thoughts?
This commit is contained in:
parent
63aa28e2a6
commit
362586fe8b
@ -40,24 +40,75 @@ class BaseMessage(BaseModel):
|
||||
content: str
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
|
||||
|
||||
class HumanMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the human."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "human"
|
||||
|
||||
|
||||
class AIMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the AI."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "ai"
|
||||
|
||||
|
||||
class SystemMessage(BaseMessage):
|
||||
"""Type of message that is a system message."""
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "system"
|
||||
|
||||
|
||||
class ChatMessage(BaseMessage):
|
||||
"""Type of message with arbitrary speaker."""
|
||||
|
||||
role: str
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
return "chat"
|
||||
|
||||
|
||||
def _message_to_json(message: BaseMessage) -> dict:
|
||||
return {"type": message.type, "data": message.dict()}
|
||||
|
||||
|
||||
def messages_to_json(messages: List[BaseMessage]) -> List[dict]:
|
||||
return [_message_to_json(m) for m in messages]
|
||||
|
||||
|
||||
def _message_from_json(message: dict) -> BaseMessage:
|
||||
_type = message["type"]
|
||||
if _type == "human":
|
||||
return HumanMessage(**message["data"])
|
||||
elif _type == "ai":
|
||||
return AIMessage(**message["data"])
|
||||
elif _type == "system":
|
||||
return SystemMessage(**message["data"])
|
||||
elif _type == "chat":
|
||||
return ChatMessage(**message["data"])
|
||||
else:
|
||||
raise ValueError(f"Got unexpected type: {_type}")
|
||||
|
||||
|
||||
def messages_from_json(messages: List[dict]) -> List[BaseMessage]:
|
||||
return [_message_from_json(m) for m in messages]
|
||||
|
||||
|
||||
class ChatGeneration(Generation):
|
||||
"""Output of a single generation."""
|
||||
|
Loading…
Reference in New Issue
Block a user