|
|
|
@ -40,24 +40,75 @@ class BaseMessage(BaseModel):
|
|
|
|
|
content: str
|
|
|
|
|
additional_kwargs: dict = Field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def type(self) -> str:
|
|
|
|
|
"""Type of the message, used for serialization."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HumanMessage(BaseMessage):
|
|
|
|
|
"""Type of message that is spoken by the human."""
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def type(self) -> str:
|
|
|
|
|
"""Type of the message, used for serialization."""
|
|
|
|
|
return "human"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AIMessage(BaseMessage):
|
|
|
|
|
"""Type of message that is spoken by the AI."""
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def type(self) -> str:
|
|
|
|
|
"""Type of the message, used for serialization."""
|
|
|
|
|
return "ai"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SystemMessage(BaseMessage):
|
|
|
|
|
"""Type of message that is a system message."""
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def type(self) -> str:
|
|
|
|
|
"""Type of the message, used for serialization."""
|
|
|
|
|
return "system"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatMessage(BaseMessage):
|
|
|
|
|
"""Type of message with arbitrary speaker."""
|
|
|
|
|
|
|
|
|
|
role: str
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def type(self) -> str:
|
|
|
|
|
"""Type of the message, used for serialization."""
|
|
|
|
|
return "chat"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _message_to_json(message: BaseMessage) -> dict:
|
|
|
|
|
return {"type": message.type, "data": message.dict()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def messages_to_json(messages: List[BaseMessage]) -> List[dict]:
|
|
|
|
|
return [_message_to_json(m) for m in messages]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _message_from_json(message: dict) -> BaseMessage:
|
|
|
|
|
_type = message["type"]
|
|
|
|
|
if _type == "human":
|
|
|
|
|
return HumanMessage(**message["data"])
|
|
|
|
|
elif _type == "ai":
|
|
|
|
|
return AIMessage(**message["data"])
|
|
|
|
|
elif _type == "system":
|
|
|
|
|
return SystemMessage(**message["data"])
|
|
|
|
|
elif _type == "chat":
|
|
|
|
|
return ChatMessage(**message["data"])
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Got unexpected type: {_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def messages_from_json(messages: List[dict]) -> List[BaseMessage]:
|
|
|
|
|
return [_message_from_json(m) for m in messages]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatGeneration(Generation):
|
|
|
|
|
"""Output of a single generation."""
|
|
|
|
|