You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/core/langchain_core/prompts/chat.py

806 lines
26 KiB
Python

"""Chat prompt template."""
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from langchain_core._api import deprecated
from langchain_core.load import Serializable
from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
class BaseMessagePromptTemplate(Serializable, ABC):
    """Abstract base for templates that render into chat messages."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report whether the class is serializable (always ``True`` here)."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the namespace path of the langchain object."""
        return ["langchain", "prompts", "chat"]

    @abstractmethod
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Render this template into messages.

        Subclasses must implement this and return a list of BaseMessages.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            List of BaseMessages.
        """

    @property
    @abstractmethod
    def input_variables(self) -> List[str]:
        """Names of the variables this template needs in order to format.

        Returns:
            List of input variables.
        """

    def pretty_repr(self, html: bool = False) -> str:
        """Human-readable representation.

        The base class provides no implementation and always raises.

        Args:
            html: Flag forwarded by ``pretty_print``; unused here.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

    def pretty_print(self) -> None:
        """Print ``pretty_repr``, passing ``html=is_interactive_env()``."""
        use_html = is_interactive_env()
        print(self.pretty_repr(html=use_html))

    def __add__(self, other: Any) -> ChatPromptTemplate:
        """Combine this message template with another prompt-like object.

        This template is first wrapped in a single-message
        ``ChatPromptTemplate``; the addition is then delegated to that
        template's own ``+`` operator.

        Args:
            other: Another prompt template.

        Returns:
            Combined prompt template.
        """
        return ChatPromptTemplate(messages=[self]) + other
class MessagesPlaceholder(BaseMessagePromptTemplate):
    """Prompt template slot whose variable is already a list of messages."""

    variable_name: str
    """Name of variable to use as messages."""

    # When True, a missing variable formats to an empty list and the variable
    # is not reported by ``input_variables``.
    optional: bool = False

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the namespace path of the langchain object."""
        return ["langchain", "prompts", "chat"]

    def __init__(self, variable_name: str, *, optional: bool = False, **kwargs: Any):
        """Create the placeholder; ``variable_name`` may be given positionally."""
        super().__init__(variable_name=variable_name, optional=optional, **kwargs)

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Look up the placeholder variable and return it after validation.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            List of BaseMessage.

        Raises:
            KeyError: If the variable is missing and the placeholder is not
                optional.
            ValueError: If the value is not a list, or contains an element
                that is not a BaseMessage.
        """
        if self.optional:
            supplied = kwargs.get(self.variable_name, [])
        else:
            supplied = kwargs[self.variable_name]

        if not isinstance(supplied, list):
            raise ValueError(
                f"variable {self.variable_name} should be a list of base messages, "
                f"got {supplied}"
            )
        if not all(isinstance(item, BaseMessage) for item in supplied):
            raise ValueError(
                f"variable {self.variable_name} should be a list of base messages,"
                f" got {supplied}"
            )
        return supplied

    @property
    def input_variables(self) -> List[str]:
        """Input variables for this prompt template.

        Returns:
            A one-element list with the variable name, or an empty list when
            the placeholder is optional.
        """
        return [] if self.optional else [self.variable_name]

    def pretty_repr(self, html: bool = False) -> str:
        """Render a title line followed by the ``{variable_name}`` marker.

        Args:
            html: When truthy, the title is requested in bold and the marker
                is wrapped with ``get_colored_text(..., "yellow")``.

        Returns:
            The title and the marker separated by a blank line.
        """
        marker = "{" + self.variable_name + "}"
        if html:
            heading = get_msg_title_repr("Messages Placeholder", bold=True)
            marker = get_colored_text(marker, "yellow")
        else:
            heading = get_msg_title_repr("Messages Placeholder")
        return f"{heading}\n\n{marker}"
# Bound by name (forward reference) because BaseStringMessagePromptTemplate is
# defined further down in this module.
MessagePromptTemplateT = TypeVar(
    "MessagePromptTemplateT",
    bound="BaseStringMessagePromptTemplate",
)
"""Type variable for message prompt templates."""
class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
    """Abstract message template backed by a single string prompt template."""

    prompt: StringPromptTemplate
    """String prompt template."""
    additional_kwargs: dict = Field(default_factory=dict)
    """Additional keyword arguments to pass to the prompt template."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the namespace path of the langchain object."""
        return ["langchain", "prompts", "chat"]

    @classmethod
    def from_template(
        cls: Type[MessagePromptTemplateT],
        template: str,
        template_format: str = "f-string",
        partial_variables: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> MessagePromptTemplateT:
        """Build an instance of this class from a template string.

        Args:
            template: a template.
            template_format: format of the template.
            partial_variables: A dictionary of variables that can be used to
                partially fill in the template. For example, if the template is
                `"{variable1} {variable2}"`, and `partial_variables` is
                `{"variable1": "foo"}`, then the final prompt will be
                `"foo {variable2}"`.
            **kwargs: keyword arguments to pass to the constructor.

        Returns:
            A new instance of this class.
        """
        return cls(
            prompt=PromptTemplate.from_template(
                template,
                template_format=template_format,
                partial_variables=partial_variables,
            ),
            **kwargs,
        )

    @classmethod
    def from_template_file(
        cls: Type[MessagePromptTemplateT],
        template_file: Union[str, Path],
        input_variables: List[str],
        **kwargs: Any,
    ) -> MessagePromptTemplateT:
        """Build an instance of this class from a template stored in a file.

        Args:
            template_file: path to a template file. String or Path.
            input_variables: list of input variables.
            **kwargs: keyword arguments to pass to the constructor.

        Returns:
            A new instance of this class.
        """
        loaded_prompt = PromptTemplate.from_file(template_file, input_variables)
        return cls(prompt=loaded_prompt, **kwargs)

    @abstractmethod
    def format(self, **kwargs: Any) -> BaseMessage:
        """Format the prompt template into a single message.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            Formatted message.
        """

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format messages from kwargs.

        Wraps the single message produced by ``format`` in a list.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            List of BaseMessages.
        """
        single_message = self.format(**kwargs)
        return [single_message]

    @property
    def input_variables(self) -> List[str]:
        """Input variables for this prompt template.

        Returns:
            The input variable names of the wrapped ``prompt``.
        """
        return self.prompt.input_variables

    def pretty_repr(self, html: bool = False) -> str:
        """Render a title derived from the class name above the prompt body.

        Args:
            html: Forwarded as ``bold`` to the title helper and as ``html`` to
                the wrapped prompt's ``pretty_repr``.

        Returns:
            The title and the prompt representation separated by a blank line.
        """
        # TODO: Handle partials
        # e.g. "HumanMessagePromptTemplate" -> "Human Message"
        label = self.__class__.__name__.replace("MessagePromptTemplate", " Message")
        heading = get_msg_title_repr(label, bold=html)
        body = self.prompt.pretty_repr(html=html)
        return f"{heading}\n\n{body}"
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
    """Message prompt template that produces a ``ChatMessage`` with a role."""

    role: str
    """Role of the message."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the namespace path of the langchain object."""
        return ["langchain", "prompts", "chat"]

    def format(self, **kwargs: Any) -> BaseMessage:
        """Format the prompt template.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            A ``ChatMessage`` carrying the formatted text, this template's
            ``role`` and its ``additional_kwargs``.
        """
        return ChatMessage(
            content=self.prompt.format(**kwargs),
            role=self.role,
            additional_kwargs=self.additional_kwargs,
        )
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
    """Human message prompt template. This is a message sent from the user."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the namespace path of the langchain object."""
        return ["langchain", "prompts", "chat"]

    def format(self, **kwargs: Any) -> BaseMessage:
        """Format the prompt template.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            A ``HumanMessage`` carrying the formatted text and this
            template's ``additional_kwargs``.
        """
        return HumanMessage(
            content=self.prompt.format(**kwargs),
            additional_kwargs=self.additional_kwargs,
        )
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
    """AI message prompt template. This is a message sent from the AI."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the namespace path of the langchain object."""
        return ["langchain", "prompts", "chat"]

    def format(self, **kwargs: Any) -> BaseMessage:
        """Format the prompt template.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            An ``AIMessage`` carrying the formatted text and this template's
            ``additional_kwargs``.
        """
        return AIMessage(
            content=self.prompt.format(**kwargs),
            additional_kwargs=self.additional_kwargs,
        )
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
    """System message prompt template.

    This is a message that is not sent to the user.
    """

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the namespace path of the langchain object."""
        return ["langchain", "prompts", "chat"]

    def format(self, **kwargs: Any) -> BaseMessage:
        """Format the prompt template.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            A ``SystemMessage`` carrying the formatted text and this
            template's ``additional_kwargs``.
        """
        return SystemMessage(
            content=self.prompt.format(**kwargs),
            additional_kwargs=self.additional_kwargs,
        )
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
    """Abstract base for prompt templates that format into chat messages."""

    @property
    def lc_attributes(self) -> Dict:
        """Attributes to include in the serialized kwargs.

        These attributes must be accepted by the constructor.

        Returns:
            A dict with the single key ``"input_variables"``.
        """
        return {"input_variables": self.input_variables}

    def format(self, **kwargs: Any) -> str:
        """Format the chat template into a string.

        Args:
            **kwargs: keyword arguments to use for filling in template variables
                in all the template messages in this chat template.

        Returns:
            formatted string
        """
        prompt_value = self.format_prompt(**kwargs)
        return prompt_value.to_string()

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Format the template and wrap the messages in a ``ChatPromptValue``.

        Args:
            **kwargs: Keyword arguments to use for formatting.

        Returns:
            PromptValue.
        """
        return ChatPromptValue(messages=self.format_messages(**kwargs))

    @abstractmethod
    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format kwargs into a list of messages."""

    def pretty_repr(self, html: bool = False) -> str:
        """Human-readable representation.

        The base class provides no implementation and always raises.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError

    def pretty_print(self) -> None:
        """Print ``pretty_repr``, passing ``html=is_interactive_env()``."""
        use_html = is_interactive_env()
        print(self.pretty_repr(html=use_html))
# An entry as stored in ChatPromptTemplate.messages: a message prompt
# template, a concrete message, or a nested chat prompt template.
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
# Input type of _convert_to_message: a MessageLike, a 2-tuple of
# (role string, template), a 2-tuple of (message class, template), or a bare
# template string.
MessageLikeRepresentation = Union[
MessageLike,
Tuple[str, str],
Tuple[Type, str],
str,
]
class ChatPromptTemplate(BaseChatPromptTemplate):
    """A prompt template for chat models.

    Use to create flexible templated prompts for chat models.

    Examples:

        .. code-block:: python

            from langchain_core.prompts import ChatPromptTemplate

            template = ChatPromptTemplate.from_messages([
                ("system", "You are a helpful AI bot. Your name is {name}."),
                ("human", "Hello, how are you doing?"),
                ("ai", "I'm doing well, thanks!"),
                ("human", "{user_input}"),
            ])

            messages = template.format_messages(
                name="Bob",
                user_input="What is your name?"
            )
    """

    input_variables: List[str]
    """List of input variables in template messages. Used for validation."""
    messages: List[MessageLike]
    """List of messages consisting of either message prompt templates or messages."""
    validate_template: bool = False
    """Whether or not to try validating the template."""

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "prompts", "chat"]

    def __add__(self, other: Any) -> ChatPromptTemplate:
        """Combine two prompt templates.

        Args:
            other: Another prompt template, a single message or message
                template, a list/tuple of message representations, or a
                string (treated as a human message template).

        Returns:
            Combined prompt template.

        Raises:
            NotImplementedError: If ``other`` is of an unsupported type.
        """
        # Allow for easy combining
        if isinstance(other, ChatPromptTemplate):
            return ChatPromptTemplate(messages=self.messages + other.messages)
        elif isinstance(
            other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
        ):
            return ChatPromptTemplate(messages=self.messages + [other])
        elif isinstance(other, (list, tuple)):
            _other = ChatPromptTemplate.from_messages(other)
            return ChatPromptTemplate(messages=self.messages + _other.messages)
        elif isinstance(other, str):
            prompt = HumanMessagePromptTemplate.from_template(other)
            return ChatPromptTemplate(messages=self.messages + [prompt])
        else:
            raise NotImplementedError(f"Unsupported operand type for +: {type(other)}")

    @root_validator(pre=True)
    def validate_input_variables(cls, values: dict) -> dict:
        """Validate input variables.

        The expected input variables are the union of the input variables of
        all template messages, minus any ``partial_variables``. When
        ``input_variables`` is supplied and ``validate_template`` is truthy,
        the supplied list must match that set; otherwise ``input_variables``
        is set to the sorted expected set. Each ``MessagesPlaceholder``
        variable without an entry in ``input_types`` gets the type
        ``List[AnyMessage]``.

        Args:
            values: values to validate.

        Returns:
            Validated values.

        Raises:
            ValueError: If validation is requested and the supplied
                ``input_variables`` do not match the messages.
        """
        messages = values["messages"]
        input_vars: Set[str] = set()
        input_types: Dict[str, Any] = values.get("input_types", {})
        for message in messages:
            if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
                input_vars.update(message.input_variables)
            if isinstance(message, MessagesPlaceholder):
                if message.variable_name not in input_types:
                    input_types[message.variable_name] = List[AnyMessage]
        if "partial_variables" in values:
            input_vars = input_vars - set(values["partial_variables"])
        if "input_variables" in values and values.get("validate_template"):
            if input_vars != set(values["input_variables"]):
                raise ValueError(
                    "Got mismatched input_variables. "
                    f"Expected: {input_vars}. "
                    f"Got: {values['input_variables']}"
                )
        else:
            values["input_variables"] = sorted(input_vars)
        values["input_types"] = input_types
        return values

    @classmethod
    def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
        """Create a chat prompt template from a template string.

        Creates a chat template consisting of a single message assumed to be from
        the human.

        Args:
            template: template string
            **kwargs: keyword arguments forwarded to
                ``PromptTemplate.from_template``.

        Returns:
            A new instance of this class.
        """
        prompt_template = PromptTemplate.from_template(template, **kwargs)
        message = HumanMessagePromptTemplate(prompt=prompt_template)
        return cls.from_messages([message])

    @classmethod
    @deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
    def from_role_strings(
        cls, string_messages: List[Tuple[str, str]]
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a list of (role, template) tuples.

        Args:
            string_messages: list of (role, template) tuples.

        Returns:
            a chat prompt template
        """
        return cls(
            messages=[
                ChatMessagePromptTemplate.from_template(template, role=role)
                for role, template in string_messages
            ]
        )

    @classmethod
    @deprecated("0.0.260", alternative="from_messages classmethod", pending=True)
    def from_strings(
        cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a list of (role class, template) tuples.

        Args:
            string_messages: list of (role class, template) tuples.

        Returns:
            a chat prompt template
        """
        return cls.from_messages(string_messages)

    @classmethod
    def from_messages(
        cls,
        messages: Sequence[MessageLikeRepresentation],
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a variety of message formats.

        Examples:

            Instantiation from a list of message templates:

            .. code-block:: python

                template = ChatPromptTemplate.from_messages([
                    ("human", "Hello, how are you?"),
                    ("ai", "I'm doing well, thanks!"),
                    ("human", "That's good to hear."),
                ])

            Instantiation from mixed message formats:

            .. code-block:: python

                template = ChatPromptTemplate.from_messages([
                    SystemMessage(content="hello"),
                    ("human", "Hello, how are you?"),
                ])

        Args:
            messages: sequence of message representations.
                A message can be represented using the following formats:
                (1) BaseMessagePromptTemplate or BaseChatPromptTemplate,
                (2) BaseMessage, (3) 2-tuple of (message type, template);
                e.g., ("human", "{user_input}"), (4) 2-tuple of
                (message class, template), (5) a string which is
                shorthand for ("human", template); e.g., "{user_input}"

        Returns:
            a chat prompt template
        """
        _messages = [_convert_to_message(message) for message in messages]

        # Automatically infer input variables from messages
        input_vars: Set[str] = set()
        for _message in _messages:
            if isinstance(
                _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
            ):
                input_vars.update(_message.input_variables)

        return cls(input_variables=sorted(input_vars), messages=_messages)

    def format(self, **kwargs: Any) -> str:
        """Format the chat template into a string.

        Args:
            **kwargs: keyword arguments to use for filling in template variables
                in all the template messages in this chat template.

        Returns:
            formatted string
        """
        return self.format_prompt(**kwargs).to_string()

    def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
        """Format the chat template into a list of finalized messages.

        Concrete messages are passed through unchanged; message templates and
        nested chat templates are formatted with the merged partial and
        user-supplied variables.

        Args:
            **kwargs: keyword arguments to use for filling in template variables
                in all the template messages in this chat template.

        Returns:
            list of formatted messages

        Raises:
            ValueError: If ``messages`` contains an entry of an unexpected type.
        """
        kwargs = self._merge_partial_and_user_variables(**kwargs)
        result: List[BaseMessage] = []
        for message_template in self.messages:
            if isinstance(message_template, BaseMessage):
                result.append(message_template)
            elif isinstance(
                message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
            ):
                result.extend(message_template.format_messages(**kwargs))
            else:
                raise ValueError(f"Unexpected input: {message_template}")
        return result

    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
        """Get a new ChatPromptTemplate with some input variables already filled in.

        Args:
            **kwargs: keyword arguments to use for filling in template variables. Ought
                to be a subset of the input variables.

        Returns:
            A new ChatPromptTemplate.

        Example:

            .. code-block:: python

                from langchain_core.prompts import ChatPromptTemplate

                template = ChatPromptTemplate.from_messages(
                    [
                        ("system", "You are an AI assistant named {name}."),
                        ("human", "Hi I'm {user}"),
                        ("ai", "Hi there, {user}, I'm {name}."),
                        ("human", "{input}"),
                    ]
                )
                template2 = template.partial(user="Lucy", name="R2D2")

                template2.format_messages(input="hello")
        """
        prompt_dict = self.__dict__.copy()
        prompt_dict["input_variables"] = list(
            set(self.input_variables).difference(kwargs)
        )
        # New values win over partials that were already set on this template.
        prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
        return type(self)(**prompt_dict)

    def append(self, message: MessageLikeRepresentation) -> None:
        """Append message to the end of the chat template.

        Args:
            message: representation of a message to append.
        """
        self.messages.append(_convert_to_message(message))

    def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
        """Extend the chat template with a sequence of messages.

        Args:
            messages: representations of the messages to append, in order.
        """
        self.messages.extend([_convert_to_message(message) for message in messages])

    @overload
    def __getitem__(self, index: int) -> MessageLike:
        ...

    @overload
    def __getitem__(self, index: slice) -> ChatPromptTemplate:
        ...

    def __getitem__(
        self, index: Union[int, slice]
    ) -> Union[MessageLike, ChatPromptTemplate]:
        """Use to index into the chat template.

        Args:
            index: an integer position or a slice.

        Returns:
            The message at ``index`` for an integer, or a new
            ChatPromptTemplate built from the selected messages for a slice.
        """
        if isinstance(index, slice):
            # Slice the list with the original slice object. Feeding the
            # normalized ``slice.indices()`` triple back into a slice misreads
            # a stop of -1 (produced for reversed slices such as ``[::-1]``)
            # as "last element" and would select nothing.
            return ChatPromptTemplate.from_messages(self.messages[index])
        return self.messages[index]

    def __len__(self) -> int:
        """Get the length of the chat template."""
        return len(self.messages)

    @property
    def _prompt_type(self) -> str:
        """Name of prompt type."""
        return "chat"

    def save(self, file_path: Union[Path, str]) -> None:
        """Save prompt to file.

        Args:
            file_path: path to file.

        Raises:
            NotImplementedError: Always; saving is not implemented here.
        """
        raise NotImplementedError()

    def pretty_repr(self, html: bool = False) -> str:
        """Join the ``pretty_repr`` of every message with blank lines.

        Args:
            html: forwarded to each message's ``pretty_repr``.

        Returns:
            The combined human-readable representation.
        """
        # TODO: handle partials
        return "\n\n".join(msg.pretty_repr(html=html) for msg in self.messages)
def _create_template_from_message_type(
    message_type: str, template: str
) -> BaseMessagePromptTemplate:
    """Build the message prompt template that matches a role name.

    Args:
        message_type: str the type of the message template (e.g., "human", "ai", etc.)
        template: str the template string.

    Returns:
        a message prompt template of the appropriate type.

    Raises:
        ValueError: If ``message_type`` is not one of the recognized names.
    """
    if message_type in ("human", "user"):
        return HumanMessagePromptTemplate.from_template(template)
    if message_type in ("ai", "assistant"):
        return AIMessagePromptTemplate.from_template(template)
    if message_type == "system":
        return SystemMessagePromptTemplate.from_template(template)
    raise ValueError(
        f"Unexpected message type: {message_type}. Use one of 'human',"
        f" 'user', 'ai', 'assistant', or 'system'."
    )
def _convert_to_message(
    message: MessageLikeRepresentation,
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
    """Normalize one message representation into a message or template.

    The message format can be one of the following:

    - BaseMessagePromptTemplate
    - BaseChatPromptTemplate
    - BaseMessage
    - 2-tuple of (role string, template); e.g., ("human", "{user_input}")
    - 2-tuple of (message class, template)
    - string: shorthand for ("human", template); e.g., "{user_input}"

    Args:
        message: a representation of a message in one of the supported formats

    Returns:
        an instance of a message or a message template

    Raises:
        ValueError: If a tuple does not have exactly two elements.
        NotImplementedError: If ``message`` is of an unsupported type.
    """
    # Already a message or a template: pass through untouched.
    if isinstance(
        message, (BaseMessagePromptTemplate, BaseChatPromptTemplate, BaseMessage)
    ):
        return message

    if isinstance(message, str):
        return _create_template_from_message_type("human", message)

    if not isinstance(message, tuple):
        raise NotImplementedError(f"Unsupported message type: {type(message)}")

    if len(message) != 2:
        raise ValueError(f"Expected 2-tuple of (role, template), got {message}")

    role_or_class, template = message
    if isinstance(role_or_class, str):
        return _create_template_from_message_type(role_or_class, template)
    # Otherwise the first element is called as a message template class.
    return role_or_class(prompt=PromptTemplate.from_template(template))