diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index 0d77aa62..18a18514 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -3,8 +3,9 @@ from __future__ import annotations from pathlib import Path from string import Formatter -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Set, Union +from jinja2 import Environment, meta from pydantic import Extra, root_validator from langchain.prompts.base import ( @@ -14,6 +15,13 @@ from langchain.prompts.base import ( ) +def _get_jinja2_variables_from_template(template: str) -> Set[str]: + env = Environment() + ast = env.parse(template) + variables = meta.find_undeclared_variables(ast) + return variables + + class PromptTemplate(StringPromptTemplate): """Schema to represent a prompt for an LLM. @@ -125,9 +133,15 @@ class PromptTemplate(StringPromptTemplate): @classmethod def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: """Load a prompt template from a template.""" - input_variables = { - v for _, v, _, _ in Formatter().parse(template) if v is not None - } + if "template_format" in kwargs and kwargs["template_format"] == "jinja2": + # Get the variables for the template + input_variables = _get_jinja2_variables_from_template(template) + + else: + input_variables = { + v for _, v, _, _ in Formatter().parse(template) if v is not None + } + return cls( input_variables=list(sorted(input_variables)), template=template, **kwargs ) diff --git a/tests/unit_tests/prompts/test_prompt.py b/tests/unit_tests/prompts/test_prompt.py index d7f38874..db55fb11 100644 --- a/tests/unit_tests/prompts/test_prompt.py +++ b/tests/unit_tests/prompts/test_prompt.py @@ -145,3 +145,70 @@ def test_partial() -> None: assert new_result == "This is a 3 test." result = prompt.format(foo="foo") assert result == "This is a foo test." + + +def test_prompt_from_jinja2_template() -> None: + """Test prompts can be constructed from a jinja2 template.""" + # Empty input variable. + template = """Hello there +There is no variable here { +Will it get confused{ }? + """ + prompt = PromptTemplate.from_template(template, template_format="jinja2") + expected_prompt = PromptTemplate( + template=template, input_variables=[], template_format="jinja2" + ) + assert prompt == expected_prompt + + # Multiple input variables. + template = """\ +Hello world + +Your variable: {{ foo }} + +{# This will not get rendered #} + +{% if bar %} +You just set bar boolean variable to true +{% endif %} + +{% for i in foo_list %} +{{ i }} +{% endfor %} +""" + prompt = PromptTemplate.from_template(template, template_format="jinja2") + expected_prompt = PromptTemplate( + template=template, + input_variables=["bar", "foo", "foo_list"], + template_format="jinja2", + ) + + assert prompt == expected_prompt + + # Multiple input variables with repeats. + template = """\ +Hello world + +Your variable: {{ foo }} + +{# This will not get rendered #} + +{% if bar %} +You just set bar boolean variable to true +{% endif %} + +{% for i in foo_list %} +{{ i }} +{% endfor %} + +{% if bar %} +Your variable again: {{ foo }} +{% endif %} +""" + prompt = PromptTemplate.from_template(template, template_format="jinja2") + expected_prompt = PromptTemplate( + template=template, + input_variables=["bar", "foo", "foo_list"], + template_format="jinja2", + ) + assert prompt == expected_prompt