mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Partial formatting for chat messages (#8450)
This commit is contained in:
parent
3314f54383
commit
d935573362
@ -57,6 +57,14 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
"""
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
"""Combine two prompt templates.
|
||||
|
||||
Args:
|
||||
other: Another prompt template.
|
||||
|
||||
Returns:
|
||||
Combined prompt template.
|
||||
"""
|
||||
prompt = ChatPromptTemplate(messages=[self])
|
||||
return prompt + other
|
||||
|
||||
@ -68,7 +76,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
"""Name of variable to use as messages."""
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""To a BaseMessage.
|
||||
"""Format messages from kwargs.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
@ -156,10 +164,17 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
|
||||
@abstractmethod
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""To a BaseMessage."""
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format messages from kwargs. Should return a list of BaseMessages.
|
||||
"""Format messages from kwargs.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
@ -187,6 +202,14 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""Role of the message."""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return ChatMessage(
|
||||
content=text, role=self.role, additional_kwargs=self.additional_kwargs
|
||||
@ -197,6 +220,14 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""Human message prompt template. This is a message that is sent to the user."""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
@ -205,6 +236,14 @@ class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""AI message prompt template. This is a message that is not sent to the user."""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
@ -215,6 +254,14 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
"""
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
"""Format the prompt template.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
Formatted message.
|
||||
"""
|
||||
text = self.prompt.format(**kwargs)
|
||||
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
@ -241,6 +288,15 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""Base class for chat prompt templates."""
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the chat template into a string.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in template variables
|
||||
in all the template messages in this chat template.
|
||||
|
||||
Returns:
|
||||
formatted string
|
||||
"""
|
||||
return self.format_prompt(**kwargs).to_string()
|
||||
|
||||
def format_prompt(self, **kwargs: Any) -> PromptValue:
|
||||
@ -261,7 +317,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
|
||||
|
||||
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
"""A prompt template for a chat models.
|
||||
"""A prompt template for chat models.
|
||||
|
||||
Use to create flexible templated prompts for chat models.
|
||||
|
||||
@ -292,6 +348,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
"""Combine two prompt templates.
|
||||
|
||||
Args:
|
||||
other: Another prompt template.
|
||||
|
||||
Returns:
|
||||
Combined prompt template.
|
||||
"""
|
||||
# Allow for easy combining
|
||||
if isinstance(other, ChatPromptTemplate):
|
||||
return ChatPromptTemplate(messages=self.messages + other.messages)
|
||||
@ -497,8 +561,42 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
raise ValueError(f"Unexpected input: {message_template}")
|
||||
return result
|
||||
|
||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
|
||||
raise NotImplementedError
|
||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
|
||||
"""Return a new ChatPromptTemplate with some of the input variables already
|
||||
filled in.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in template variables. Ought
|
||||
to be a subset of the input variables.
|
||||
|
||||
Returns:
|
||||
A new ChatPromptTemplate.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an AI assistant named {name}."),
|
||||
("human", "Hi I'm {user}"),
|
||||
("ai", "Hi there, {user}, I'm {name}."),
|
||||
("human", "{input}"),
|
||||
]
|
||||
)
|
||||
template2 = template.partial(user="Lucy", name="R2D2")
|
||||
|
||||
template2.format_messages(input="hello")
|
||||
"""
|
||||
prompt_dict = self.__dict__.copy()
|
||||
prompt_dict["input_variables"] = list(
|
||||
set(self.input_variables).difference(kwargs)
|
||||
)
|
||||
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
|
||||
return type(self)(**prompt_dict)
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
|
@ -20,6 +20,7 @@ from langchain.schema.messages import (
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
get_buffer_string,
|
||||
)
|
||||
|
||||
|
||||
@ -288,3 +289,27 @@ def test_convert_to_message_is_strict() -> None:
|
||||
# this test is here to ensure that functionality to interpret `meow`
|
||||
# as a role is NOT added.
|
||||
_convert_to_message(("meow", "question"))
|
||||
|
||||
|
||||
def test_chat_message_partial() -> None:
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an AI assistant named {name}."),
|
||||
("human", "Hi I'm {user}"),
|
||||
("ai", "Hi there, {user}, I'm {name}."),
|
||||
("human", "{input}"),
|
||||
]
|
||||
)
|
||||
template2 = template.partial(user="Lucy", name="R2D2")
|
||||
with pytest.raises(KeyError):
|
||||
template.format_messages(input="hello")
|
||||
|
||||
res = template2.format_messages(input="hello")
|
||||
expected = [
|
||||
SystemMessage(content="You are an AI assistant named R2D2."),
|
||||
HumanMessage(content="Hi I'm Lucy"),
|
||||
AIMessage(content="Hi there, Lucy, I'm R2D2."),
|
||||
HumanMessage(content="hello"),
|
||||
]
|
||||
assert res == expected
|
||||
assert template2.format(input="hello") == get_buffer_string(expected)
|
||||
|
Loading…
Reference in New Issue
Block a user