ChatPromptTemplate: Update doc-strings, update from_role_strings behavior (#8308)

* Update doc-strings in ChatPromptTemplate
* Update from_role_strings classmethod to use well known roles
Eugene Yurtsev 2023-07-26 15:02:36 -04:00 committed by GitHub
parent 2c2fd9ff13
commit 862e9aed66
2 changed files with 139 additions and 33 deletions

View File

@ -220,7 +220,10 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
class ChatPromptValue(PromptValue):
"""Chat prompt value."""
"""Chat prompt value.
A type of a prompt value that is built from messages.
"""
messages: List[BaseMessage]
"""List of messages."""
@ -258,12 +261,65 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
"""Chat prompt template. This is a prompt that is sent to the user."""
"""Use to create flexible templated prompts for chat models.
Examples:
Instantiation from role strings:
.. code-block:: python
from langchain.prompts import ChatPromptTemplate
prompt_template = ChatPromptTemplate.from_role_strings(
[
('system', "You are a helpful bot. Your name is {bot_name}."),
('human', "{user_input}")
]
)
prompt_template.format_messages(
bot_name="bobby",
user_input="Hello! What is your name?"
)
Instantiation from messages:
This is useful if it's important to distinguish between messages that
are templates and messages that are already formatted.
.. code-block:: python
from langchain.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import AIMessage
prompt_template = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(
"You are a helpful bot. Your name is {bot_name}."
),
AIMessage(content="Hello!"), # Already formatted message
HumanMessagePromptTemplate.from_template(
"{user_input}"
),
]
)
prompt_template.format_messages(
bot_name="bobby",
user_input="Hello! What is your name?"
)
"""
input_variables: List[str]
"""List of input variables."""
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
"""List of messages."""
"""List of messages consisting of either message prompt templates or messages."""
def __add__(self, other: Any) -> ChatPromptTemplate:
# Allow for easy combining
@ -279,9 +335,10 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@root_validator(pre=True)
def validate_input_variables(cls, values: dict) -> dict:
"""
Validate input variables. If input_variables is not set, it will be set to
the union of all input variables in the messages.
"""Validate input variables.
If input_variables is not set, it will be set to the union of
all input variables in the messages.
Args:
values: values to validate.
@ -309,10 +366,13 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
"""Create a class from a template.
"""Create a chat prompt template from a template string.
Creates a chat template consisting of a single message assumed to be from
the human.
Args:
template: template string.
template: template string
**kwargs: keyword arguments to pass to the constructor.
Returns:
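A hedged usage sketch of the single-message case described in that docstring; the template string and variable name here are illustrative, not taken from the diff:

from langchain.prompts import ChatPromptTemplate

# from_template builds a chat template with a single human message template.
prompt = ChatPromptTemplate.from_template("Tell me a joke about {topic}")
prompt.format_messages(topic="bears")
# -> [HumanMessage(content="Tell me a joke about bears")]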
@ -328,31 +388,41 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
) -> ChatPromptTemplate:
"""Create a class from a list of (role, template) tuples.
The roles "human", "ai", and "system" are special and will be converted
to the appropriate message class. All other roles will be converted to a
generic ChatMessagePromptTemplate.
Args:
string_messages: list of (role, template) tuples.
Returns:
A new instance of this class.
a chat prompt template
"""
messages = [
ChatMessagePromptTemplate(
prompt=PromptTemplate.from_template(template), role=role
)
for role, template in string_messages
]
messages: List[BaseMessagePromptTemplate] = []
message: BaseMessagePromptTemplate
for role, template in string_messages:
if role == "human":
message = HumanMessagePromptTemplate.from_template(template)
elif role == "ai":
message = AIMessagePromptTemplate.from_template(template)
elif role == "system":
message = SystemMessagePromptTemplate.from_template(template)
else:
message = ChatMessagePromptTemplate.from_template(template, role=role)
messages.append(message)
return cls.from_messages(messages)
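A hedged illustration of the role mapping implemented above; the "critic" role is an arbitrary example of a role that falls through to the generic ChatMessagePromptTemplate branch:

from langchain.prompts import ChatPromptTemplate

template = ChatPromptTemplate.from_role_strings(
    [
        ("system", "You are a helpful bot."),  # -> SystemMessagePromptTemplate
        ("ai", "How can I help?"),             # -> AIMessagePromptTemplate
        ("human", "{question}"),               # -> HumanMessagePromptTemplate
        ("critic", "Critique: {answer}"),      # -> ChatMessagePromptTemplate(role="critic")
    ]
)
template.format_messages(question="Hi!", answer="Hello there")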
@classmethod
def from_strings(
cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]]
) -> ChatPromptTemplate:
"""Create a class from a list of (role, template) tuples.
"""Create a class from a list of (role class, template) tuples.
Args:
string_messages: list of (role, template) tuples.
string_messages: list of (role class, template) tuples.
Returns:
A new instance of this class.
a chat prompt template
"""
messages = [
role(prompt=PromptTemplate.from_template(template))
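By contrast, from_strings takes the message prompt template classes themselves rather than role strings; a hedged sketch:

from langchain.prompts import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

template = ChatPromptTemplate.from_strings(
    [
        (SystemMessagePromptTemplate, "You are a helpful bot."),
        (HumanMessagePromptTemplate, "{question}"),
    ]
)
template.format_messages(question="Hi!")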
@ -364,14 +434,13 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
def from_messages(
cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]]
) -> ChatPromptTemplate:
"""
Create a class from a list of messages.
"""Create a chat template from a sequence of messages.
Args:
messages: list of messages.
messages: sequence of templated or regular messages
Returns:
A new instance of this class.
a chat prompt template
"""
input_vars = set()
for message in messages:
@ -380,17 +449,26 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
return cls(input_variables=list(input_vars), messages=messages)
def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
formatted string
"""
return self.format_prompt(**kwargs).to_string()
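A hedged side-by-side of the two formatting paths documented here; the output shapes in the comments are approximate:

from langchain.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_role_strings(
    [("system", "You are {bot_name}."), ("human", "{user_input}")]
)

prompt.format(bot_name="bobby", user_input="Hi!")
# -> one buffer string, roughly "System: You are bobby.\nHuman: Hi!"

prompt.format_messages(bot_name="bobby", user_input="Hi!")
# -> [SystemMessage(content="You are bobby."), HumanMessage(content="Hi!")]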
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""
Format kwargs into a list of messages.
"""Format the chat template into a list of finalized messages.
Args:
**kwargs: keyword arguments to use for formatting.
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
List of messages.
list of formatted messages
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
result = []
@ -414,11 +492,11 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@property
def _prompt_type(self) -> str:
"""Name of prompt type."""
return "chat"
def save(self, file_path: Union[Path, str]) -> None:
"""
Save prompt to file.
"""Save prompt to file.
Args:
file_path: path to file.

View File

@ -1,5 +1,5 @@
from pathlib import Path
from typing import List
from typing import List, Union
import pytest
@ -7,13 +7,19 @@ from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
AIMessagePromptTemplate,
BaseMessagePromptTemplate,
ChatMessage,
ChatMessagePromptTemplate,
ChatPromptTemplate,
ChatPromptValue,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
)
def create_messages() -> List[BaseMessagePromptTemplate]:
@ -133,7 +139,9 @@ def test_chat_prompt_template_from_messages() -> None:
def test_chat_prompt_template_with_messages() -> None:
messages = create_messages() + [HumanMessage(content="foo")]
messages: List[
Union[BaseMessagePromptTemplate, BaseMessage]
] = create_messages() + [HumanMessage(content="foo")]
chat_prompt_template = ChatPromptTemplate.from_messages(messages)
assert sorted(chat_prompt_template.input_variables) == sorted(
["context", "foo", "bar"]
@ -175,7 +183,7 @@ def test_chat_valid_with_partial_variables() -> None:
input_variables=["question", "context"],
partial_variables={"formatins": "some structure"},
)
assert set(prompt.input_variables) == set(["question", "context"])
assert set(prompt.input_variables) == {"question", "context"}
assert prompt.partial_variables == {"formatins": "some structure"}
@ -188,5 +196,25 @@ def test_chat_valid_infer_variables() -> None:
prompt = ChatPromptTemplate(
messages=messages, partial_variables={"formatins": "some structure"}
)
assert set(prompt.input_variables) == set(["question", "context"])
assert set(prompt.input_variables) == {"question", "context"}
assert prompt.partial_variables == {"formatins": "some structure"}
def test_chat_from_role_strings() -> None:
"""Test instantiation of chat template from role strings."""
template = ChatPromptTemplate.from_role_strings(
[
("system", "You are a bot."),
("ai", "hello!"),
("human", "{question}"),
("other", "{quack}"),
]
)
messages = template.format_messages(question="How are you?", quack="duck")
assert messages == [
SystemMessage(content="You are a bot."),
AIMessage(content="hello!"),
HumanMessage(content="How are you?"),
ChatMessage(content="duck", role="other"),
]