diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index af1fbef4..9df7cece 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -74,6 +74,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): prompt = PromptTemplate.from_template(template) return cls(prompt=prompt, **kwargs) + @classmethod + def from_template_file( + cls: Type[MessagePromptTemplateT], + template_file: Union[str, Path], + input_variables: List[str], + **kwargs: Any, + ) -> MessagePromptTemplateT: + prompt = PromptTemplate.from_file(template_file, input_variables) + return cls(prompt=prompt, **kwargs) + @abstractmethod def format(self, **kwargs: Any) -> BaseMessage: """To a BaseMessage.""" diff --git a/tests/unit_tests/prompts/test_chat.py b/tests/unit_tests/prompts/test_chat.py index 87f64c95..a9844a78 100644 --- a/tests/unit_tests/prompts/test_chat.py +++ b/tests/unit_tests/prompts/test_chat.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import List from langchain.prompts import PromptTemplate @@ -80,6 +81,21 @@ def test_create_chat_prompt_template_from_template_partial() -> None: assert output_prompt.prompt == expected_prompt +def test_message_prompt_template_from_template_file() -> None: + expected = ChatMessagePromptTemplate( + prompt=PromptTemplate( + template="Question: {question}\nAnswer:", input_variables=["question"] + ), + role="human", + ) + actual = ChatMessagePromptTemplate.from_template_file( + Path(__file__).parent.parent / "data" / "prompt_file.txt", + ["question"], + role="human", + ) + assert expected == actual + + def test_chat_prompt_template() -> None: """Test chat prompt template.""" prompt_template = create_chat_prompt_template()