You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/langchain/memory/utils.py

39 lines
1.2 KiB
Python

from typing import Any, Dict, List
from langchain.schema import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render a list of messages as newline-separated ``"<role>: <content>"`` lines.

    Args:
        messages: Messages to render, processed in the order given.
        human_prefix: Role label written for ``HumanMessage`` entries.
        ai_prefix: Role label written for ``AIMessage`` entries.

    Returns:
        One ``"<role>: <content>"`` line per message, joined with ``"\\n"``.
        An empty ``messages`` list yields an empty string.

    Raises:
        ValueError: If a message is not a ``HumanMessage``, ``AIMessage``,
            ``SystemMessage`` or ``ChatMessage``.
    """

    def _role_label(message: BaseMessage) -> str:
        # The checks run in a fixed order; the first matching type wins.
        if isinstance(message, HumanMessage):
            return human_prefix
        if isinstance(message, AIMessage):
            return ai_prefix
        if isinstance(message, SystemMessage):
            return "System"
        if isinstance(message, ChatMessage):
            # Chat messages carry their own role label.
            return message.role
        raise ValueError(f"Got unsupported message type: {message}")

    return "\n".join(
        f"{_role_label(message)}: {message.content}" for message in messages
    )
def get_prompt_input_key(inputs: Dict[str, Any], memory_variables: List[str]) -> str:
    """Return the one key of ``inputs`` that is not a memory variable or ``"stop"``.

    Args:
        inputs: Mapping whose keys are the candidate input names.
        memory_variables: Key names to exclude from consideration.

    Returns:
        The single key left after the exclusions.

    Raises:
        ValueError: If the number of remaining keys is not exactly one.
    """
    # "stop" is a special key that can be passed as input but is not used to
    # format the prompt, so it is excluded together with the memory variables.
    excluded_keys = memory_variables + ["stop"]
    candidate_keys = list(set(inputs).difference(excluded_keys))
    if len(candidate_keys) == 1:
        return candidate_keys[0]
    raise ValueError(f"One input key expected got {candidate_keys}")