You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/client/utils.py

43 lines
1.2 KiB
Python

"""Client Utils."""
import re
from typing import Dict, List, Optional, Sequence, Type, Union
from langchain.schema import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]]
_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = {
"Human": HumanMessage,
"AI": AIMessage,
"System": SystemMessage,
}
def parse_chat_messages(
input_text: str, roles: Optional[Sequence[str]] = None
) -> List[BaseMessage]:
"""Parse chat messages from a string. This is not robust."""
roles = roles or ["Human", "AI", "System"]
roles_pattern = "|".join(roles)
pattern = (
rf"(?P<entity>{roles_pattern}): (?P<message>"
rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)"
)
matches = re.finditer(pattern, input_text, re.MULTILINE)
results: List[BaseMessage] = []
for match in matches:
entity = match.group("entity")
message = match.group("message").rstrip("\n")
if entity in _RESOLUTION_MAP:
results.append(_RESOLUTION_MAP[entity](content=message))
else:
results.append(ChatMessage(role=entity, content=message))
return results