diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index 5081408c..eb39b52e 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -120,7 +120,7 @@ class ChatPromptValue(PromptValue): class ChatPromptTemplate(BasePromptTemplate, ABC): input_variables: List[str] - messages: List[BaseMessagePromptTemplate] + messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] @classmethod def from_role_strings( @@ -146,11 +146,12 @@ class ChatPromptTemplate(BasePromptTemplate, ABC): @classmethod def from_messages( - cls, messages: Sequence[BaseMessagePromptTemplate] + cls, messages: Sequence[Union[BaseMessagePromptTemplate, BaseMessage]] ) -> ChatPromptTemplate: input_vars = set() for message in messages: - input_vars.update(message.input_variables) + if isinstance(message, BaseMessagePromptTemplate): + input_vars.update(message.input_variables) return cls(input_variables=list(input_vars), messages=messages) def format(self, **kwargs: Any) -> str: @@ -159,11 +160,18 @@ class ChatPromptTemplate(BasePromptTemplate, ABC): def format_prompt(self, **kwargs: Any) -> PromptValue: result = [] for message_template in self.messages: - rel_params = { - k: v for k, v in kwargs.items() if k in message_template.input_variables - } - message = message_template.format_messages(**rel_params) - result.extend(message) + if isinstance(message_template, BaseMessage): + result.extend([message_template]) + elif isinstance(message_template, BaseMessagePromptTemplate): + rel_params = { + k: v + for k, v in kwargs.items() + if k in message_template.input_variables + } + message = message_template.format_messages(**rel_params) + result.extend(message) + else: + raise ValueError(f"Unexpected input: {message_template}") return ChatPromptValue(messages=result) def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate: diff --git a/tests/unit_tests/prompts/test_chat.py b/tests/unit_tests/prompts/test_chat.py index 9a49a155..04174b15 100644 --- a/tests/unit_tests/prompts/test_chat.py +++ b/tests/unit_tests/prompts/test_chat.py @@ -10,6 +10,7 @@ from langchain.prompts.chat import ( HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) +from langchain.schema import HumanMessage def create_messages() -> List[BaseMessagePromptTemplate]: @@ -89,3 +90,17 @@ def test_chat_prompt_template_from_messages() -> None: ["context", "foo", "bar"] ) assert len(chat_prompt_template.messages) == 4 + + +def test_chat_prompt_template_with_messages() -> None: + messages = create_messages() + [HumanMessage(content="foo")] + chat_prompt_template = ChatPromptTemplate.from_messages(messages) + assert sorted(chat_prompt_template.input_variables) == sorted( + ["context", "foo", "bar"] + ) + assert len(chat_prompt_template.messages) == 5 + prompt_value = chat_prompt_template.format_prompt( + context="see", foo="this", bar="magic" + ) + prompt_value_messages = prompt_value.to_messages() + assert prompt_value_messages[-1] == HumanMessage(content="foo")