From d9355733621dffcc870b15cae93c4c99fb4b4c43 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Fri, 28 Jul 2023 23:08:33 -0700 Subject: [PATCH] Partial formatting for chat messages (#8450) --- libs/langchain/langchain/prompts/chat.py | 110 +++++++++++++++++- .../tests/unit_tests/prompts/test_chat.py | 25 ++++ 2 files changed, 129 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index f71ac34135..f9b64ec24e 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -57,6 +57,14 @@ class BaseMessagePromptTemplate(Serializable, ABC): """ def __add__(self, other: Any) -> ChatPromptTemplate: + """Combine two prompt templates. + + Args: + other: Another prompt template. + + Returns: + Combined prompt template. + """ prompt = ChatPromptTemplate(messages=[self]) return prompt + other @@ -68,7 +76,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): """Name of variable to use as messages.""" def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """To a BaseMessage. + """Format messages from kwargs. Args: **kwargs: Keyword arguments to use for formatting. @@ -156,10 +164,17 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): @abstractmethod def format(self, **kwargs: Any) -> BaseMessage: - """To a BaseMessage.""" + """Format the prompt template. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ def format_messages(self, **kwargs: Any) -> List[BaseMessage]: - """Format messages from kwargs. Should return a list of BaseMessages. + """Format messages from kwargs. Args: **kwargs: Keyword arguments to use for formatting. @@ -187,6 +202,14 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): """Role of the message.""" def format(self, **kwargs: Any) -> BaseMessage: + """Format the prompt template. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ text = self.prompt.format(**kwargs) return ChatMessage( content=text, role=self.role, additional_kwargs=self.additional_kwargs @@ -197,6 +220,14 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate): """Human message prompt template. This is a message that is sent to the user.""" def format(self, **kwargs: Any) -> BaseMessage: + """Format the prompt template. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ text = self.prompt.format(**kwargs) return HumanMessage(content=text, additional_kwargs=self.additional_kwargs) @@ -205,6 +236,14 @@ class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): """AI message prompt template. This is a message that is not sent to the user.""" def format(self, **kwargs: Any) -> BaseMessage: + """Format the prompt template. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ text = self.prompt.format(**kwargs) return AIMessage(content=text, additional_kwargs=self.additional_kwargs) @@ -215,6 +254,14 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate): """ def format(self, **kwargs: Any) -> BaseMessage: + """Format the prompt template. + + Args: + **kwargs: Keyword arguments to use for formatting. + + Returns: + Formatted message. + """ text = self.prompt.format(**kwargs) return SystemMessage(content=text, additional_kwargs=self.additional_kwargs) @@ -241,6 +288,15 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): """Base class for chat prompt templates.""" def format(self, **kwargs: Any) -> str: + """Format the chat template into a string. + + Args: + **kwargs: keyword arguments to use for filling in template variables + in all the template messages in this chat template. + + Returns: + formatted string + """ return self.format_prompt(**kwargs).to_string() def format_prompt(self, **kwargs: Any) -> PromptValue: @@ -261,7 +317,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): class ChatPromptTemplate(BaseChatPromptTemplate, ABC): - """A prompt template for a chat models. + """A prompt template for chat models. Use to create flexible templated prompts for chat models. @@ -292,6 +348,14 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): """List of messages consisting of either message prompt templates or messages.""" def __add__(self, other: Any) -> ChatPromptTemplate: + """Combine two prompt templates. + + Args: + other: Another prompt template. + + Returns: + Combined prompt template. + """ # Allow for easy combining if isinstance(other, ChatPromptTemplate): return ChatPromptTemplate(messages=self.messages + other.messages) @@ -497,8 +561,42 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): raise ValueError(f"Unexpected input: {message_template}") return result - def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: - raise NotImplementedError + def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate: + """Return a new ChatPromptTemplate with some of the input variables already + filled in. + + Args: + **kwargs: keyword arguments to use for filling in template variables. Ought + to be a subset of the input variables. + + Returns: + A new ChatPromptTemplate. + + + Example: + + .. code-block:: python + + from langchain.prompts import ChatPromptTemplate + + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are an AI assistant named {name}."), + ("human", "Hi I'm {user}"), + ("ai", "Hi there, {user}, I'm {name}."), + ("human", "{input}"), + ] + ) + template2 = template.partial(user="Lucy", name="R2D2") + + template2.format_messages(input="hello") + """ + prompt_dict = self.__dict__.copy() + prompt_dict["input_variables"] = list( + set(self.input_variables).difference(kwargs) + ) + prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} + return type(self)(**prompt_dict) @property def _prompt_type(self) -> str: diff --git a/libs/langchain/tests/unit_tests/prompts/test_chat.py b/libs/langchain/tests/unit_tests/prompts/test_chat.py index 9769ddd7ec..517741151d 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_chat.py +++ b/libs/langchain/tests/unit_tests/prompts/test_chat.py @@ -20,6 +20,7 @@ from langchain.schema.messages import ( BaseMessage, HumanMessage, SystemMessage, + get_buffer_string, ) @@ -288,3 +289,27 @@ def test_convert_to_message_is_strict() -> None: # this test is here to ensure that functionality to interpret `meow` # as a role is NOT added. _convert_to_message(("meow", "question")) + + +def test_chat_message_partial() -> None: + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are an AI assistant named {name}."), + ("human", "Hi I'm {user}"), + ("ai", "Hi there, {user}, I'm {name}."), + ("human", "{input}"), + ] + ) + template2 = template.partial(user="Lucy", name="R2D2") + with pytest.raises(KeyError): + template.format_messages(input="hello") + + res = template2.format_messages(input="hello") + expected = [ + SystemMessage(content="You are an AI assistant named R2D2."), + HumanMessage(content="Hi I'm Lucy"), + AIMessage(content="Hi there, Lucy, I'm R2D2."), + HumanMessage(content="hello"), + ] + assert res == expected + assert template2.format(input="hello") == get_buffer_string(expected)