mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Support a few list like operations on ChatPromptTemplate (#9077)
Make it easier to work with chat prompt template
This commit is contained in:
parent
e4418d1b7e
commit
44bc89b7bf
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, overload
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
@ -317,6 +317,16 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""Format kwargs into a list of messages."""
|
||||
|
||||
|
||||
MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
|
||||
|
||||
MessageLikeRepresentation = Union[
|
||||
MessageLike,
|
||||
Tuple[str, str],
|
||||
Tuple[Type, str],
|
||||
str,
|
||||
]
|
||||
|
||||
|
||||
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
"""A prompt template for chat models.
|
||||
|
||||
@ -343,9 +353,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
|
||||
input_variables: List[str]
|
||||
"""List of input variables in template messages. Used for validation."""
|
||||
messages: List[
|
||||
Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
|
||||
]
|
||||
messages: List[MessageLike]
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
@ -364,6 +372,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
|
||||
):
|
||||
return ChatPromptTemplate(messages=self.messages + [other])
|
||||
elif isinstance(other, (list, tuple)):
|
||||
_other = ChatPromptTemplate.from_messages(other)
|
||||
return ChatPromptTemplate(messages=self.messages + _other.messages)
|
||||
elif isinstance(other, str):
|
||||
prompt = HumanMessagePromptTemplate.from_template(other)
|
||||
return ChatPromptTemplate(messages=self.messages + [prompt])
|
||||
@ -457,16 +468,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
@classmethod
|
||||
def from_messages(
|
||||
cls,
|
||||
messages: Sequence[
|
||||
Union[
|
||||
BaseMessagePromptTemplate,
|
||||
BaseChatPromptTemplate,
|
||||
BaseMessage,
|
||||
Tuple[str, str],
|
||||
Tuple[Type, str],
|
||||
str,
|
||||
]
|
||||
],
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a variety of message formats.
|
||||
|
||||
@ -556,8 +558,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
return result
|
||||
|
||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
|
||||
"""Return a new ChatPromptTemplate with some of the input variables already
|
||||
filled in.
|
||||
"""Get a new ChatPromptTemplate with some input variables already filled in.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in template variables. Ought
|
||||
@ -592,6 +593,41 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
|
||||
return type(self)(**prompt_dict)
|
||||
|
||||
def append(self, message: MessageLikeRepresentation) -> None:
|
||||
"""Append message to the end of the chat template.
|
||||
|
||||
Args:
|
||||
message: representation of a message to append.
|
||||
"""
|
||||
self.messages.append(_convert_to_message(message))
|
||||
|
||||
def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None:
|
||||
"""Extend the chat template with a sequence of messages."""
|
||||
self.messages.extend([_convert_to_message(message) for message in messages])
|
||||
|
||||
@overload
|
||||
def __getitem__(self, index: int) -> MessageLike:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, index: slice) -> ChatPromptTemplate:
|
||||
...
|
||||
|
||||
def __getitem__(
|
||||
self, index: Union[int, slice]
|
||||
) -> Union[MessageLike, ChatPromptTemplate]:
|
||||
"""Use to index into the chat template."""
|
||||
if isinstance(index, slice):
|
||||
start, stop, step = index.indices(len(self.messages))
|
||||
messages = self.messages[start:stop:step]
|
||||
return ChatPromptTemplate.from_messages(messages)
|
||||
else:
|
||||
return self.messages[index]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Get the length of the chat template."""
|
||||
return len(self.messages)
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
"""Name of prompt type."""
|
||||
@ -635,14 +671,7 @@ def _create_template_from_message_type(
|
||||
|
||||
|
||||
def _convert_to_message(
|
||||
message: Union[
|
||||
BaseMessagePromptTemplate,
|
||||
BaseChatPromptTemplate,
|
||||
BaseMessage,
|
||||
Tuple[str, str],
|
||||
Tuple[Type, str],
|
||||
str,
|
||||
]
|
||||
message: MessageLikeRepresentation,
|
||||
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
|
||||
"""Instantiate a message from a variety of message formats.
|
||||
|
||||
|
@ -282,6 +282,42 @@ def test_convert_to_message(
|
||||
assert _convert_to_message(args) == expected
|
||||
|
||||
|
||||
def test_chat_prompt_template_indexing() -> None:
|
||||
message1 = SystemMessage(content="foo")
|
||||
message2 = HumanMessage(content="bar")
|
||||
message3 = HumanMessage(content="baz")
|
||||
template = ChatPromptTemplate.from_messages([message1, message2, message3])
|
||||
assert template[0] == message1
|
||||
assert template[1] == message2
|
||||
|
||||
# Slice starting from index 1
|
||||
slice_template = template[1:]
|
||||
assert slice_template[0] == message2
|
||||
assert len(slice_template) == 2
|
||||
|
||||
|
||||
def test_chat_prompt_template_append_and_extend() -> None:
|
||||
"""Test append and extend methods of ChatPromptTemplate."""
|
||||
message1 = SystemMessage(content="foo")
|
||||
message2 = HumanMessage(content="bar")
|
||||
message3 = HumanMessage(content="baz")
|
||||
template = ChatPromptTemplate.from_messages([message1])
|
||||
template.append(message2)
|
||||
template.append(message3)
|
||||
assert len(template) == 3
|
||||
template.extend([message2, message3])
|
||||
assert len(template) == 5
|
||||
assert template.messages == [
|
||||
message1,
|
||||
message2,
|
||||
message3,
|
||||
message2,
|
||||
message3,
|
||||
]
|
||||
template.append(("system", "hello!"))
|
||||
assert template[-1] == SystemMessagePromptTemplate.from_template("hello!")
|
||||
|
||||
|
||||
def test_convert_to_message_is_strict() -> None:
|
||||
"""Verify that _convert_to_message is strict."""
|
||||
with pytest.raises(ValueError):
|
||||
|
Loading…
Reference in New Issue
Block a user