forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
from typing import Any, Dict, List
|
|
|
|
from langchain.schema import (
|
|
AIMessage,
|
|
BaseMessage,
|
|
ChatMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
|
|
|
|
def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render chat messages as newline-separated ``role: content`` lines.

    Args:
        messages: Messages to render, in order.
        human_prefix: Label used for ``HumanMessage`` entries.
        ai_prefix: Label used for ``AIMessage`` entries.

    Returns:
        One line per message, formatted as ``"<role>: <content>"`` and joined
        by newlines.

    Raises:
        ValueError: If a message is not a Human, AI, System or Chat message.
    """

    def _label(message: BaseMessage) -> str:
        # The checks run in a fixed order and the first match wins.
        if isinstance(message, HumanMessage):
            return human_prefix
        if isinstance(message, AIMessage):
            return ai_prefix
        if isinstance(message, SystemMessage):
            return "System"
        if isinstance(message, ChatMessage):
            # Chat messages carry their own free-form role label.
            return message.role
        raise ValueError(f"Got unsupported message type: {message}")

    lines = [f"{_label(msg)}: {msg.content}" for msg in messages]
    return "\n".join(lines)
|
|
|
|
|
|
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
    """Return the single input key that feeds the prompt.

    A key in ``inputs`` is a prompt input if it is neither one of
    ``memory_variables`` nor the special ``"stop"`` key. Exactly one such key
    must exist.

    Args:
        inputs: Mapping of input names to values.
        memory_variables: Names supplied by memory. Any iterable of strings is
            accepted (list, tuple, set, ...).

    Returns:
        The one remaining input key.

    Raises:
        ValueError: If no candidate key remains, or if more than one does.
    """
    # "stop" is a special key that can be passed as input but is not used to
    # format the prompt. set.difference accepts several iterables, so
    # memory_variables does not have to be a list; the previous
    # `memory_variables + ["stop"]` raised TypeError for tuples and sets.
    # Sorting keeps the error message stable regardless of set iteration order.
    prompt_input_keys = sorted(set(inputs).difference(memory_variables, ("stop",)))
    if len(prompt_input_keys) != 1:
        raise ValueError(f"One input key expected got {prompt_input_keys}")
    return prompt_input_keys[0]
|