Partial formatting for chat messages (#8450)

This commit is contained in:
William FH 2023-07-28 23:08:33 -07:00 committed by GitHub
parent 3314f54383
commit d935573362
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 129 additions and 6 deletions

View File

@ -57,6 +57,14 @@ class BaseMessagePromptTemplate(Serializable, ABC):
"""
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
Args:
other: Another prompt template.
Returns:
Combined prompt template.
"""
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
@ -68,7 +76,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Name of variable to use as messages."""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""To a BaseMessage.
"""Format messages from kwargs.
Args:
**kwargs: Keyword arguments to use for formatting.
@ -156,10 +164,17 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
@abstractmethod
def format(self, **kwargs: Any) -> BaseMessage:
"""To a BaseMessage."""
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs. Should return a list of BaseMessages.
"""Format messages from kwargs.
Args:
**kwargs: Keyword arguments to use for formatting.
@ -187,6 +202,14 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Role of the message."""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return ChatMessage(
content=text, role=self.role, additional_kwargs=self.additional_kwargs
@ -197,6 +220,14 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""Human message prompt template. This is a message that is sent to the user."""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
@ -205,6 +236,14 @@ class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""AI message prompt template. This is a message that is not sent to the user."""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
@ -215,6 +254,14 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
@ -241,6 +288,15 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Base class for chat prompt templates."""
def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
Args:
**kwargs: keyword arguments to use for filling in template variables
in all the template messages in this chat template.
Returns:
formatted string
"""
return self.format_prompt(**kwargs).to_string()
def format_prompt(self, **kwargs: Any) -> PromptValue:
@ -261,7 +317,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
"""A prompt template for a chat models.
"""A prompt template for chat models.
Use to create flexible templated prompts for chat models.
@ -292,6 +348,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
"""List of messages consisting of either message prompt templates or messages."""
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
Args:
other: Another prompt template.
Returns:
Combined prompt template.
"""
# Allow for easy combining
if isinstance(other, ChatPromptTemplate):
return ChatPromptTemplate(messages=self.messages + other.messages)
@ -497,8 +561,42 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
raise ValueError(f"Unexpected input: {message_template}")
return result
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
raise NotImplementedError
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
"""Return a new ChatPromptTemplate with some of the input variables already
filled in.
Args:
**kwargs: keyword arguments to use for filling in template variables. Ought
to be a subset of the input variables.
Returns:
A new ChatPromptTemplate.
Example:
.. code-block:: python
from langchain.prompts import ChatPromptTemplate
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
("human", "Hi I'm {user}"),
("ai", "Hi there, {user}, I'm {name}."),
("human", "{input}"),
]
)
template2 = template.partial(user="Lucy", name="R2D2")
template2.format_messages(input="hello")
"""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
@property
def _prompt_type(self) -> str:

View File

@ -20,6 +20,7 @@ from langchain.schema.messages import (
BaseMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
)
@ -288,3 +289,27 @@ def test_convert_to_message_is_strict() -> None:
# this test is here to ensure that functionality to interpret `meow`
# as a role is NOT added.
_convert_to_message(("meow", "question"))
def test_chat_message_partial() -> None:
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
("human", "Hi I'm {user}"),
("ai", "Hi there, {user}, I'm {name}."),
("human", "{input}"),
]
)
template2 = template.partial(user="Lucy", name="R2D2")
with pytest.raises(KeyError):
template.format_messages(input="hello")
res = template2.format_messages(input="hello")
expected = [
SystemMessage(content="You are an AI assistant named R2D2."),
HumanMessage(content="Hi I'm Lucy"),
AIMessage(content="Hi there, Lucy, I'm R2D2."),
HumanMessage(content="hello"),
]
assert res == expected
assert template2.format(input="hello") == get_buffer_string(expected)