diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index aff7b6da4a..01292dbfdd 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -3,7 +3,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, overload from pydantic import Field, root_validator @@ -317,6 +317,16 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): """Format kwargs into a list of messages.""" +MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate] + +MessageLikeRepresentation = Union[ + MessageLike, + Tuple[str, str], + Tuple[Type, str], + str, +] + + class ChatPromptTemplate(BaseChatPromptTemplate, ABC): """A prompt template for chat models. @@ -343,9 +353,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): input_variables: List[str] """List of input variables in template messages. Used for validation.""" - messages: List[ - Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate] - ] + messages: List[MessageLike] """List of messages consisting of either message prompt templates or messages.""" def __add__(self, other: Any) -> ChatPromptTemplate: @@ -364,6 +372,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate) ): return ChatPromptTemplate(messages=self.messages + [other]) + elif isinstance(other, (list, tuple)): + _other = ChatPromptTemplate.from_messages(other) + return ChatPromptTemplate(messages=self.messages + _other.messages) elif isinstance(other, str): prompt = HumanMessagePromptTemplate.from_template(other) return ChatPromptTemplate(messages=self.messages + [prompt]) @@ -457,16 +468,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): @classmethod def from_messages( cls, - messages: Sequence[ - Union[ - BaseMessagePromptTemplate, - BaseChatPromptTemplate, - BaseMessage, - Tuple[str, str], - Tuple[Type, str], - str, - ] - ], + messages: Sequence[MessageLikeRepresentation], ) -> ChatPromptTemplate: """Create a chat prompt template from a variety of message formats. @@ -556,8 +558,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): return result def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate: - """Return a new ChatPromptTemplate with some of the input variables already - filled in. + """Get a new ChatPromptTemplate with some input variables already filled in. Args: **kwargs: keyword arguments to use for filling in template variables. Ought @@ -592,6 +593,41 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs} return type(self)(**prompt_dict) + def append(self, message: MessageLikeRepresentation) -> None: + """Append message to the end of the chat template. + + Args: + message: representation of a message to append. + """ + self.messages.append(_convert_to_message(message)) + + def extend(self, messages: Sequence[MessageLikeRepresentation]) -> None: + """Extend the chat template with a sequence of messages.""" + self.messages.extend([_convert_to_message(message) for message in messages]) + + @overload + def __getitem__(self, index: int) -> MessageLike: + ... + + @overload + def __getitem__(self, index: slice) -> ChatPromptTemplate: + ... + + def __getitem__( + self, index: Union[int, slice] + ) -> Union[MessageLike, ChatPromptTemplate]: + """Use to index into the chat template.""" + if isinstance(index, slice): + start, stop, step = index.indices(len(self.messages)) + messages = self.messages[start:stop:step] + return ChatPromptTemplate.from_messages(messages) + else: + return self.messages[index] + + def __len__(self) -> int: + """Get the length of the chat template.""" + return len(self.messages) + @property def _prompt_type(self) -> str: """Name of prompt type.""" @@ -635,14 +671,7 @@ def _create_template_from_message_type( def _convert_to_message( - message: Union[ - BaseMessagePromptTemplate, - BaseChatPromptTemplate, - BaseMessage, - Tuple[str, str], - Tuple[Type, str], - str, - ] + message: MessageLikeRepresentation, ) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]: """Instantiate a message from a variety of message formats. diff --git a/libs/langchain/tests/unit_tests/prompts/test_chat.py b/libs/langchain/tests/unit_tests/prompts/test_chat.py index 517741151d..7b089c687c 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_chat.py +++ b/libs/langchain/tests/unit_tests/prompts/test_chat.py @@ -282,6 +282,42 @@ def test_convert_to_message( assert _convert_to_message(args) == expected +def test_chat_prompt_template_indexing() -> None: + message1 = SystemMessage(content="foo") + message2 = HumanMessage(content="bar") + message3 = HumanMessage(content="baz") + template = ChatPromptTemplate.from_messages([message1, message2, message3]) + assert template[0] == message1 + assert template[1] == message2 + + # Slice starting from index 1 + slice_template = template[1:] + assert slice_template[0] == message2 + assert len(slice_template) == 2 + + +def test_chat_prompt_template_append_and_extend() -> None: + """Test append and extend methods of ChatPromptTemplate.""" + message1 = SystemMessage(content="foo") + message2 = HumanMessage(content="bar") + message3 = HumanMessage(content="baz") + template = ChatPromptTemplate.from_messages([message1]) + template.append(message2) + template.append(message3) + assert len(template) == 3 + template.extend([message2, message3]) + assert len(template) == 5 + assert template.messages == [ + message1, + message2, + message3, + message2, + message3, + ] + template.append(("system", "hello!")) + assert template[-1] == SystemMessagePromptTemplate.from_template("hello!") + + def test_convert_to_message_is_strict() -> None: """Verify that _convert_to_message is strict.""" with pytest.raises(ValueError):