diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index ce15d1add0..465e698516 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -261,63 +261,31 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): class ChatPromptTemplate(BaseChatPromptTemplate, ABC): - """Use to create flexible templated prompts for chat models. + """A prompt template for a chat models. + + Use to create flexible templated prompts for chat models. Examples: - Instantiation from role strings: - .. code-block:: python from langchain.prompts import ChatPromptTemplate - prompt_template = ChatPromptTemplate.from_role_strings( - [ - ('system', "You are a helpful bot. Your name is {bot_name}."), - ('human', "{user_input}") - ] - ) + template = ChatPromptTemplate.from_messages([ + ("system", "You are a helpful AI bot. Your name is {name}."), + ("human", "Hello, how are you doing?"), + ("assistant", "I'm doing well, thanks!"), + ("human", "{user_input}"), + ]) - prompt_template.format_messages( - bot_name="bobby", - user_input="Hello! What is your name?" - ) - - Instantiation from messages: - - This is useful if it's important to distinguish between messages that - are templates and messages that are already formatted. - - .. code-block:: python - - from langchain.prompts import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - SystemMessagePromptTemplate, - ) - - from langchain.schema import AIMessage - - prompt_template = ChatPromptTemplate.from_messages( - [ - SystemMessagePromptTemplate.from_template( - "You are a helpful bot. Your name is {bot_name}." - ), - AIMessage(content="Hello!"), # Already formatted message - HumanMessagePromptTemplate.from_template( - "{user_input}" - ), - ] - ) - - prompt_template.format_messages( - bot_name="bobby", - user_input="Hello! What is your name?" + messages = template.format_messages( + name="Bob", + user_input="What is your name?" ) """ input_variables: List[str] - """List of input variables.""" + """List of input variables in template messages. Used for validation.""" messages: List[ Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate] ] @@ -390,11 +358,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): def from_role_strings( cls, string_messages: List[Tuple[str, str]] ) -> ChatPromptTemplate: - """Create a class from a list of (role, template) tuples. - - The roles "human", "ai", and "system" are special and will be converted - to the appropriate message class. All other roles will be converted to a - generic ChatMessagePromptTemplate. + """Create a chat prompt template from a list of (role, template) tuples. Args: string_messages: list of (role, template) tuples. @@ -402,25 +366,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): Returns: a chat prompt template """ - messages: List[BaseMessagePromptTemplate] = [] - message: BaseMessagePromptTemplate - for role, template in string_messages: - if role == "human": - message = HumanMessagePromptTemplate.from_template(template) - elif role == "ai": - message = AIMessagePromptTemplate.from_template(template) - elif role == "system": - message = SystemMessagePromptTemplate.from_template(template) - else: - message = ChatMessagePromptTemplate.from_template(template, role=role) - messages.append(message) - return cls.from_messages(messages) + return cls( + messages=[ + ChatMessagePromptTemplate.from_template(template, role=role) + for role, template in string_messages + ] + ) @classmethod def from_strings( cls, string_messages: List[Tuple[Type[BaseMessagePromptTemplate], str]] ) -> ChatPromptTemplate: - """Create a class from a list of (role class, template) tuples. + """Create a chat prompt template from a list of (role class, template) tuples. Args: string_messages: list of (role class, template) tuples. @@ -428,29 +385,76 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): Returns: a chat prompt template """ - messages = [ - role(prompt=PromptTemplate.from_template(template)) - for role, template in string_messages - ] - return cls.from_messages(messages) + return cls.from_messages(string_messages) @classmethod def from_messages( - cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]] + cls, + messages: Sequence[ + Union[ + BaseMessagePromptTemplate, + BaseMessage, + Tuple[str, str], + Tuple[Type, str], + str, + ] + ], ) -> ChatPromptTemplate: - """Create a chat template from a sequence of messages. + """Create a chat prompt template from a variety of message formats. + + Examples: + + Instantiation from a list of role strings and templates: + + .. code-block:: python + + template = ChatPromptTemplate.from_messages([ + ("human", "Hello, how are you?"), + ("ai", "I'm doing well, thanks!"), + ("human", "That's good to hear."), + ]) + + Instantiation from mixed message formats: + + .. code-block:: python + + template = ChatPromptTemplate.from_messages([ + SystemMessage(content="hello"), + ("human", "Hello, how are you?"), + ]) + + Instantiation from a list message templates: + + + .. code-block:: python + + template = ChatPromptTemplate.from_messages([ + ("human", "Hello, how are you?"), + ("ai", "I'm doing well, thanks!"), + ("human", "That's good to hear."), + ]) + Args: - messages: sequence of templated or regular messages + messages: sequence of message representations. + A message can be represented using the following formats: + (1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of + (message type, template); e.g., ("human", "{user_input}"), + (4) 2-tuple of (message class, template), (4) a string which is + shorthand for ("human", template); e.g., "{user_input}" Returns: a chat prompt template """ + _messages = [_convert_to_message(message) for message in messages] + + # Automatically infer input variables from messages input_vars = set() - for message in messages: - if isinstance(message, BaseMessagePromptTemplate): - input_vars.update(message.input_variables) - return cls(input_variables=sorted(input_vars), messages=messages) + for _message in _messages: + if isinstance(_message, BaseMessagePromptTemplate): + input_vars.update(_message.input_variables) + + return cls(input_variables=sorted(input_vars), messages=_messages) def format(self, **kwargs: Any) -> str: """Format the chat template into a string. @@ -507,4 +511,77 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): Args: file_path: path to file. """ - raise NotImplementedError + raise NotImplementedError() + + +def _create_template_from_message_type( + message_type: str, template: str +) -> BaseMessagePromptTemplate: + """Create a message prompt template from a message type and template string. + + Args: + message_type: str the type of the message template (e.g., "human", "ai", etc.) + template: str the template string. + + Returns: + a message prompt template of the appropriate type. + """ + if message_type == "human": + message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template( + template + ) + elif message_type == "ai": + message = AIMessagePromptTemplate.from_template(template) + elif message_type == "system": + message = SystemMessagePromptTemplate.from_template(template) + else: + raise ValueError( + f"Unexpected message type: {message_type}. Use one of 'human', 'ai', " + f"or 'system'." + ) + return message + + +def _convert_to_message( + message: Union[ + BaseMessagePromptTemplate, + BaseMessage, + Tuple[str, str], + Tuple[Type, str], + str, + ] +) -> Union[BaseMessage, BaseMessagePromptTemplate]: + """Instantiate a message from a variety of message formats. + + The message format can be one of the following: + + - BaseMessagePromptTemplate + - BaseMessage + - 2-tuple of (role string, template); e.g., ("human", "{user_input}") + - 2-tuple of (message class, template) + - string: shorthand for ("human", template); e.g., "{user_input}" + + Args: + message: a representation of a message in one of the supported formats + + Returns: + an instance of a message or a message template + """ + if isinstance(message, BaseMessagePromptTemplate): + _message: Union[BaseMessage, BaseMessagePromptTemplate] = message + elif isinstance(message, BaseMessage): + _message = message + elif isinstance(message, str): + _message = _create_template_from_message_type("human", message) + elif isinstance(message, tuple): + if len(message) != 2: + raise ValueError(f"Expected 2-tuple of (role, template), got {message}") + message_type_str, template = message + if isinstance(message_type_str, str): + _message = _create_template_from_message_type(message_type_str, template) + else: + _message = message_type_str(prompt=PromptTemplate.from_template(template)) + else: + raise NotImplementedError(f"Unsupported message type: {type(message)}") + + return _message diff --git a/libs/langchain/tests/unit_tests/prompts/test_chat.py b/libs/langchain/tests/unit_tests/prompts/test_chat.py index e1a4a047df..9769ddd7ec 100644 --- a/libs/langchain/tests/unit_tests/prompts/test_chat.py +++ b/libs/langchain/tests/unit_tests/prompts/test_chat.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Union +from typing import Any, List, Union import pytest @@ -13,6 +13,7 @@ from langchain.prompts.chat import ( ChatPromptValue, HumanMessagePromptTemplate, SystemMessagePromptTemplate, + _convert_to_message, ) from langchain.schema.messages import ( AIMessage, @@ -138,6 +139,33 @@ def test_chat_prompt_template_from_messages() -> None: assert len(chat_prompt_template.messages) == 4 +def test_chat_prompt_template_from_messages_using_role_strings() -> None: + """Test creating a chat prompt template from role string messages.""" + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful AI bot. Your name is {name}."), + ("human", "Hello, how are you doing?"), + ("ai", "I'm doing well, thanks!"), + ("human", "{user_input}"), + ] + ) + + messages = template.format_messages(name="Bob", user_input="What is your name?") + + assert messages == [ + SystemMessage( + content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={} + ), + HumanMessage( + content="Hello, how are you doing?", additional_kwargs={}, example=False + ), + AIMessage( + content="I'm doing well, thanks!", additional_kwargs={}, example=False + ), + HumanMessage(content="What is your name?", additional_kwargs={}, example=False), + ] + + def test_chat_prompt_template_with_messages() -> None: messages: List[ Union[BaseMessagePromptTemplate, BaseMessage] @@ -205,7 +233,7 @@ def test_chat_from_role_strings() -> None: template = ChatPromptTemplate.from_role_strings( [ ("system", "You are a bot."), - ("ai", "hello!"), + ("assistant", "hello!"), ("human", "{question}"), ("other", "{quack}"), ] @@ -213,8 +241,50 @@ def test_chat_from_role_strings() -> None: messages = template.format_messages(question="How are you?", quack="duck") assert messages == [ - SystemMessage(content="You are a bot."), - AIMessage(content="hello!"), - HumanMessage(content="How are you?"), + ChatMessage(content="You are a bot.", role="system"), + ChatMessage(content="hello!", role="assistant"), + ChatMessage(content="How are you?", role="human"), ChatMessage(content="duck", role="other"), ] + + +@pytest.mark.parametrize( + "args,expected", + [ + ( + ("human", "{question}"), + HumanMessagePromptTemplate( + prompt=PromptTemplate.from_template("{question}") + ), + ), + ( + "{question}", + HumanMessagePromptTemplate( + prompt=PromptTemplate.from_template("{question}") + ), + ), + (HumanMessage(content="question"), HumanMessage(content="question")), + ( + HumanMessagePromptTemplate( + prompt=PromptTemplate.from_template("{question}") + ), + HumanMessagePromptTemplate( + prompt=PromptTemplate.from_template("{question}") + ), + ), + ], +) +def test_convert_to_message( + args: Any, expected: Union[BaseMessage, BaseMessagePromptTemplate] +) -> None: + """Test convert to message.""" + assert _convert_to_message(args) == expected + + +def test_convert_to_message_is_strict() -> None: + """Verify that _convert_to_message is strict.""" + with pytest.raises(ValueError): + # meow does not correspond to a valid message type. + # this test is here to ensure that functionality to interpret `meow` + # as a role is NOT added. + _convert_to_message(("meow", "question"))