2023-12-11 21:53:30 +00:00
|
|
|
from typing import List
|
|
|
|
|
|
|
|
from langchain_core.messages import (
|
|
|
|
AIMessage,
|
|
|
|
BaseMessage,
|
|
|
|
ChatMessage,
|
|
|
|
HumanMessage,
|
|
|
|
SystemMessage,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_one_message_to_text_llama(message: BaseMessage) -> str:
    """Render a single message as its Llama prompt fragment.

    The fragment depends on the concrete message class, checked in this
    order:

    * ``ChatMessage``   -> two newlines, the capitalized role, a colon and
      the content.
    * ``HumanMessage``  -> the content wrapped in ``[INST] ... [/INST]``.
    * ``AIMessage``     -> the content formatted as-is.
    * ``SystemMessage`` -> the content wrapped in ``<<SYS>> ... <</SYS>>``.

    Args:
        message: The message to convert.

    Returns:
        The text fragment for ``message``.

    Raises:
        ValueError: If ``message`` is none of the handled message classes.
    """
    # Each guard returns immediately; the check order is significant and
    # must stay ChatMessage -> Human -> AI -> System.
    if isinstance(message, ChatMessage):
        return f"\n\n{message.role.capitalize()}: {message.content}"
    if isinstance(message, HumanMessage):
        return f"[INST] {message.content} [/INST]"
    if isinstance(message, AIMessage):
        return f"{message.content}"
    if isinstance(message, SystemMessage):
        return f"<<SYS>> {message.content} <</SYS>>"

    # Nothing matched: report the offending message to the caller.
    raise ValueError(f"Got unknown type {message}")
|
|
|
|
|
|
|
|
|
|
|
|
def convert_messages_to_prompt_llama(messages: List[BaseMessage]) -> str:
    """Build a Llama prompt string from a list of messages.

    Every message is converted with ``_convert_one_message_to_text_llama``
    and the resulting fragments are joined with a single newline.

    Args:
        messages: The messages to convert, in prompt order.

    Returns:
        The newline-joined prompt text; an empty string when ``messages``
        is empty.
    """
    fragments = map(_convert_one_message_to_text_llama, messages)
    return "\n".join(fragments)
|