Support inference of input_variables from jinja2 template (#3013)

`langchain.prompts.PromptTemplate` is unable to infer `input_variables`
from jinja2 template.

```python
# Using langchain v0.0.141
template_string = """\
Hello world
Your variable: {{ var }}
{# This will not get rendered #}

{% if verbose %}
Congrats! You just turned on verbose mode and got extra messages!
{% endif %}
"""

template = PromptTemplate.from_template(template_string, template_format="jinja2")
print(template.input_variables) # Output ['# This will not get rendered #', '% endif %', '% if verbose %']
```

---------

Co-authored-by: engkheng <ongengkheng929@example.com>
This commit is contained in:
engkheng 2023-04-18 11:31:03 +08:00 committed by GitHub
parent dac32c59e5
commit 19febc77d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 85 additions and 4 deletions

View File

@ -3,8 +3,9 @@ from __future__ import annotations
from pathlib import Path from pathlib import Path
from string import Formatter from string import Formatter
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Set, Union
from jinja2 import Environment, meta
from pydantic import Extra, root_validator from pydantic import Extra, root_validator
from langchain.prompts.base import ( from langchain.prompts.base import (
@ -14,6 +15,13 @@ from langchain.prompts.base import (
) )
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
class PromptTemplate(StringPromptTemplate): class PromptTemplate(StringPromptTemplate):
"""Schema to represent a prompt for an LLM. """Schema to represent a prompt for an LLM.
@ -125,9 +133,15 @@ class PromptTemplate(StringPromptTemplate):
@classmethod @classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template.""" """Load a prompt template from a template."""
if "template_format" in kwargs and kwargs["template_format"] == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
else:
input_variables = { input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None v for _, v, _, _ in Formatter().parse(template) if v is not None
} }
return cls( return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs input_variables=list(sorted(input_variables)), template=template, **kwargs
) )

View File

@ -145,3 +145,70 @@ def test_partial() -> None:
assert new_result == "This is a 3 test." assert new_result == "This is a 3 test."
result = prompt.format(foo="foo") result = prompt.format(foo="foo")
assert result == "This is a foo test." assert result == "This is a foo test."
def test_prompt_from_jinja2_template() -> None:
"""Test prompts can be constructed from a jinja2 template."""
# Empty input variable.
template = """Hello there
There is no variable here {
Will it get confused{ }?
"""
prompt = PromptTemplate.from_template(template, template_format="jinja2")
expected_prompt = PromptTemplate(
template=template, input_variables=[], template_format="jinja2"
)
assert prompt == expected_prompt
# Multiple input variables.
template = """\
Hello world
Your variable: {{ foo }}
{# This will not get rendered #}
{% if bar %}
You just set bar boolean variable to true
{% endif %}
{% for i in foo_list %}
{{ i }}
{% endfor %}
"""
prompt = PromptTemplate.from_template(template, template_format="jinja2")
expected_prompt = PromptTemplate(
template=template,
input_variables=["bar", "foo", "foo_list"],
template_format="jinja2",
)
assert prompt == expected_prompt
# Multiple input variables with repeats.
template = """\
Hello world
Your variable: {{ foo }}
{# This will not get rendered #}
{% if bar %}
You just set bar boolean variable to true
{% endif %}
{% for i in foo_list %}
{{ i }}
{% endfor %}
{% if bar %}
Your variable again: {{ foo }}
{% endif %}
"""
prompt = PromptTemplate.from_template(template, template_format="jinja2")
expected_prompt = PromptTemplate(
template=template,
input_variables=["bar", "foo", "foo_list"],
template_format="jinja2",
)
assert prompt == expected_prompt