From 5b48ab8db3755fe91e6d2839037024f7da8f819c Mon Sep 17 00:00:00 2001 From: Edmar Ferreira Date: Sun, 20 Nov 2022 12:20:15 -0300 Subject: [PATCH] add mako template --- langchain/formatting.py | 6 ++++++ langchain/prompts/base.py | 1 + langchain/prompts/prompt.py | 22 ++++++++++++++++++++++ tests/unit_tests/data/mako_prompt.txt | 1 + tests/unit_tests/prompts/test_prompt.py | 9 +++++++++ 5 files changed, 39 insertions(+) create mode 100644 tests/unit_tests/data/mako_prompt.txt diff --git a/langchain/formatting.py b/langchain/formatting.py index 61c7c116..e2d86047 100644 --- a/langchain/formatting.py +++ b/langchain/formatting.py @@ -1,6 +1,7 @@ """Utilities for formatting strings.""" from string import Formatter from typing import Any, Mapping, Sequence, Union +from mako.template import Template class StrictFormatter(Formatter): @@ -28,5 +29,10 @@ class StrictFormatter(Formatter): ) return super().vformat(format_string, args, kwargs) + def mako_format(self, format_string: str, **kwargs: Any) -> str: + """Format a string using mako.""" + template = Template(format_string) + return template.render(**kwargs) + formatter = StrictFormatter() diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 44cdad9f..e1e56dde 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -6,6 +6,7 @@ from langchain.formatting import formatter DEFAULT_FORMATTER_MAPPING = { "f-string": formatter.format, + "mako": formatter.mako_format, } diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index 6eccaaa3..a89cd168 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List from pydantic import BaseModel, Extra, root_validator + from langchain.prompts.base import ( DEFAULT_FORMATTER_MAPPING, BasePromptTemplate, @@ -106,6 +107,27 @@ class PromptTemplate(BaseModel, BasePromptTemplate): template = f.read() return cls(input_variables=input_variables, template=template) + @classmethod + def from_mako_template( + cls, template_file: str, input_variables: List[str] + ) -> "PromptTemplate": + """Load a prompt from a mako template file. + + Args: + template_file: The path to the file containing the prompt template. + input_variables: A list of variable names the final prompt template + will expect. + Returns: + The prompt loaded from the mako template file. + """ + with open(template_file, "r") as f: + template = f.read() + return cls( + input_variables=input_variables, + template=template, + template_format="mako", + ) + # For backwards compatibility. Prompt = PromptTemplate diff --git a/tests/unit_tests/data/mako_prompt.txt b/tests/unit_tests/data/mako_prompt.txt new file mode 100644 index 00000000..25bb9f75 --- /dev/null +++ b/tests/unit_tests/data/mako_prompt.txt @@ -0,0 +1 @@ +This is a ${foo} test. \ No newline at end of file diff --git a/tests/unit_tests/prompts/test_prompt.py b/tests/unit_tests/prompts/test_prompt.py index cec597e0..0b01d3ea 100644 --- a/tests/unit_tests/prompts/test_prompt.py +++ b/tests/unit_tests/prompts/test_prompt.py @@ -87,3 +87,12 @@ def test_prompt_from_file() -> None: input_variables = ["question"] prompt = PromptTemplate.from_file(template_file, input_variables) assert prompt.template == "Question: {question}\nAnswer:" + + +def test_mako_template() -> None: + """Test mako template can be used.""" + template_file = "tests/unit_tests/data/mako_prompt.txt" + input_variables = ["foo"] + prompt = PromptTemplate.from_mako_template(template_file, input_variables) + assert prompt.template == "This is a ${foo} test." + assert prompt.format(foo="bar") == "This is a bar test."