Support a few list like operations on ChatPromptTemplate (#9077)

Make it easier to work with chat prompt template
This commit is contained in:
Eugene Yurtsev 2023-08-11 14:49:51 -04:00 committed by GitHub
parent e4418d1b7e
commit 44bc89b7bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 89 additions and 24 deletions

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, overload
from pydantic import Field, root_validator
@ -317,6 +317,16 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Format kwargs into a list of messages."""
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
MessageLikeRepresentation = Union[
MessageLike,
Tuple[str, str],
Tuple[Type, str],
str,
]
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
"""A prompt template for chat models.
@ -343,9 +353,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
input_variables: List[str]
"""List of input variables in template messages. Used for validation."""
messages: List[
Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
]
messages: List[MessageLike]
"""List of messages consisting of either message prompt templates or messages."""
def __add__(self, other: Any) -> ChatPromptTemplate:
@ -364,6 +372,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
):
return ChatPromptTemplate(messages=self.messages + [other])
elif isinstance(other, (list, tuple)):
_other = ChatPromptTemplate.from_messages(other)
return ChatPromptTemplate(messages=self.messages + _other.messages)
elif isinstance(other, str):
prompt = HumanMessagePromptTemplate.from_template(other)
return ChatPromptTemplate(messages=self.messages + [prompt])
@ -457,16 +468,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@classmethod
def from_messages(
cls,
messages: Sequence[
Union[
BaseMessagePromptTemplate,
BaseChatPromptTemplate,
BaseMessage,
Tuple[str, str],
Tuple[Type, str],
str,
]
],
messages: Sequence[MessageLikeRepresentation],
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
@ -556,8 +558,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
return result
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
"""Return a new ChatPromptTemplate with some of the input variables already
filled in.
"""Get a new ChatPromptTemplate with some input variables already filled in.
Args:
**kwargs: keyword arguments to use for filling in template variables. Ought
@ -592,6 +593,41 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def append(self, message: MessageLikeRepresentation) -> None:
"""Append message to the end of the chat template.
Args:
message: representation of a message to append.
"""
self.messages.append(_convert_to_message(message))
def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
"""Extend the chat template with a sequence of messages."""
self.messages.extend([_convert_to_message(message) for message in messages])
@overload
def __getitem__(self, index: int) -> MessageLike:
...
@overload
def __getitem__(self, index: slice) -> ChatPromptTemplate:
...
def __getitem__(
self, index: Union[int, slice]
) -> Union[MessageLike, ChatPromptTemplate]:
"""Use to index into the chat template."""
if isinstance(index, slice):
start, stop, step = index.indices(len(self.messages))
messages = self.messages[start:stop:step]
return ChatPromptTemplate.from_messages(messages)
else:
return self.messages[index]
def __len__(self) -> int:
"""Get the length of the chat template."""
return len(self.messages)
@property
def _prompt_type(self) -> str:
"""Name of prompt type."""
@ -635,14 +671,7 @@ def _create_template_from_message_type(
def _convert_to_message(
message: Union[
BaseMessagePromptTemplate,
BaseChatPromptTemplate,
BaseMessage,
Tuple[str, str],
Tuple[Type, str],
str,
]
message: MessageLikeRepresentation,
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats.

View File

@ -282,6 +282,42 @@ def test_convert_to_message(
assert _convert_to_message(args) == expected
def test_chat_prompt_template_indexing() -> None:
message1 = SystemMessage(content="foo")
message2 = HumanMessage(content="bar")
message3 = HumanMessage(content="baz")
template = ChatPromptTemplate.from_messages([message1, message2, message3])
assert template[0] == message1
assert template[1] == message2
# Slice starting from index 1
slice_template = template[1:]
assert slice_template[0] == message2
assert len(slice_template) == 2
def test_chat_prompt_template_append_and_extend() -> None:
"""Test append and extend methods of ChatPromptTemplate."""
message1 = SystemMessage(content="foo")
message2 = HumanMessage(content="bar")
message3 = HumanMessage(content="baz")
template = ChatPromptTemplate.from_messages([message1])
template.append(message2)
template.append(message3)
assert len(template) == 3
template.extend([message2, message3])
assert len(template) == 5
assert template.messages == [
message1,
message2,
message3,
message2,
message3,
]
template.append(("system", "hello!"))
assert template[-1] == SystemMessagePromptTemplate.from_template("hello!")
def test_convert_to_message_is_strict() -> None:
"""Verify that _convert_to_message is strict."""
with pytest.raises(ValueError):