You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/prompts/chat.py

250 lines
8.2 KiB
Python

"""Chat prompt template."""
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
from pydantic import Field, root_validator
from langchain.load.serializable import Serializable
from langchain.memory.buffer import get_buffer_string
from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
PromptValue,
SystemMessage,
)
class BaseMessagePromptTemplate(Serializable, ABC):
@property
def lc_serializable(self) -> bool:
return True
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""To messages."""
@property
@abstractmethod
def input_variables(self) -> List[str]:
"""Input variables for this prompt template."""
class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages."""
variable_name: str
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""To a BaseMessage."""
value = kwargs[self.variable_name]
if not isinstance(value, list):
raise ValueError(
f"variable {self.variable_name} should be a list of base messages, "
f"got {value}"
)
for v in value:
if not isinstance(v, BaseMessage):
raise ValueError(
f"variable {self.variable_name} should be a list of base messages,"
f" got {value}"
)
return value
@property
def input_variables(self) -> List[str]:
"""Input variables for this prompt template."""
return [self.variable_name]
MessagePromptTemplateT = TypeVar(
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
)
class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
prompt: StringPromptTemplate
additional_kwargs: dict = Field(default_factory=dict)
@classmethod
def from_template(
cls: Type[MessagePromptTemplateT],
template: str,
template_format: str = "f-string",
**kwargs: Any,
) -> MessagePromptTemplateT:
prompt = PromptTemplate.from_template(template, template_format=template_format)
return cls(prompt=prompt, **kwargs)
@classmethod
def from_template_file(
cls: Type[MessagePromptTemplateT],
template_file: Union[str, Path],
input_variables: List[str],
**kwargs: Any,
) -> MessagePromptTemplateT:
prompt = PromptTemplate.from_file(template_file, input_variables)
return cls(prompt=prompt, **kwargs)
@abstractmethod
def format(self, **kwargs: Any) -> BaseMessage:
"""To a BaseMessage."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
return [self.format(**kwargs)]
@property
def input_variables(self) -> List[str]:
return self.prompt.input_variables
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return ChatMessage(
content=text, role=self.role, additional_kwargs=self.additional_kwargs
)
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
class ChatPromptValue(PromptValue):
messages: List[BaseMessage]
def to_string(self) -> str:
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return self.messages
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
def format(self, **kwargs: Any) -> str:
return self.format_prompt(**kwargs).to_string()
def format_prompt(self, **kwargs: Any) -> PromptValue:
messages = self.format_messages(**kwargs)
return ChatPromptValue(messages=messages)
@abstractmethod
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages."""
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
input_variables: List[str]
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
@root_validator(pre=True)
def validate_input_variables(cls, values: dict) -> dict:
messages = values["messages"]
input_vars = set()
for message in messages:
if isinstance(message, BaseMessagePromptTemplate):
input_vars.update(message.input_variables)
if "input_variables" in values:
if input_vars != set(values["input_variables"]):
raise ValueError(
"Got mismatched input_variables. "
f"Expected: {input_vars}. "
f"Got: {values['input_variables']}"
)
else:
values["input_variables"] = list(input_vars)
return values
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
prompt_template = PromptTemplate.from_template(template, **kwargs)
message = HumanMessagePromptTemplate(prompt=prompt_template)
return cls.from_messages([message])
@classmethod
def from_role_strings(
cls, string_messages: List[Tuple[str, str]]
) -> ChatPromptTemplate:
messages = [
ChatMessagePromptTemplate(
prompt=PromptTemplate.from_template(template), role=role
)
for role, template in string_messages
]
return cls.from_messages(messages)
@classmethod
def from_strings(
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate:
messages = [
role(prompt=PromptTemplate.from_template(template))
for role, template in string_messages
]
return cls.from_messages(messages)
@classmethod
def from_messages(
cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]]
) -> ChatPromptTemplate:
input_vars = set()
for message in messages:
if isinstance(message, BaseMessagePromptTemplate):
input_vars.update(message.input_variables)
return cls(input_variables=list(input_vars), messages=messages)
def format(self, **kwargs: Any) -> str:
return self.format_prompt(**kwargs).to_string()
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
kwargs = self._merge_partial_and_user_variables(**kwargs)
result = []
for message_template in self.messages:
if isinstance(message_template, BaseMessage):
result.extend([message_template])
elif isinstance(message_template, BaseMessagePromptTemplate):
rel_params = {
k: v
for k, v in kwargs.items()
if k in message_template.input_variables
}
message = message_template.format_messages(**rel_params)
result.extend(message)
else:
raise ValueError(f"Unexpected input: {message_template}")
return result
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
raise NotImplementedError
@property
def _prompt_type(self) -> str:
return "chat"
def save(self, file_path: Union[Path, str]) -> None:
raise NotImplementedError