Support inference of `input_variables` from `jinja2` template (#3013)

`langchain.prompts.PromptTemplate` is unable to infer `input_variables`
from jinja2 template.

```python
# Using langchain v0.0.141
template_string = """\
Hello world
Your variable: {{ var }}
{# This will not get rendered #}

{% if verbose %}
Congrats! You just turned on verbose mode and got extra messages!
{% endif %}
"""

template = PromptTemplate.from_template(template_string, template_format="jinja2")
print(template.input_variables) # Output ['# This will not get rendered #', '% endif %', '% if verbose %']
```

---------

Co-authored-by: engkheng <ongengkheng929@example.com>
fix_agent_callbacks
engkheng 1 year ago committed by GitHub
parent dac32c59e5
commit 19febc77d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,8 +3,9 @@ from __future__ import annotations
from pathlib import Path
from string import Formatter
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Set, Union
from jinja2 import Environment, meta
from pydantic import Extra, root_validator
from langchain.prompts.base import (
@ -14,6 +15,13 @@ from langchain.prompts.base import (
)
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
class PromptTemplate(StringPromptTemplate):
"""Schema to represent a prompt for an LLM.
@ -125,9 +133,15 @@ class PromptTemplate(StringPromptTemplate):
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
if "template_format" in kwargs and kwargs["template_format"] == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
else:
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)

@ -145,3 +145,70 @@ def test_partial() -> None:
assert new_result == "This is a 3 test."
result = prompt.format(foo="foo")
assert result == "This is a foo test."
def test_prompt_from_jinja2_template() -> None:
"""Test prompts can be constructed from a jinja2 template."""
# Empty input variable.
template = """Hello there
There is no variable here {
Will it get confused{ }?
"""
prompt = PromptTemplate.from_template(template, template_format="jinja2")
expected_prompt = PromptTemplate(
template=template, input_variables=[], template_format="jinja2"
)
assert prompt == expected_prompt
# Multiple input variables.
template = """\
Hello world
Your variable: {{ foo }}
{# This will not get rendered #}
{% if bar %}
You just set bar boolean variable to true
{% endif %}
{% for i in foo_list %}
{{ i }}
{% endfor %}
"""
prompt = PromptTemplate.from_template(template, template_format="jinja2")
expected_prompt = PromptTemplate(
template=template,
input_variables=["bar", "foo", "foo_list"],
template_format="jinja2",
)
assert prompt == expected_prompt
# Multiple input variables with repeats.
template = """\
Hello world
Your variable: {{ foo }}
{# This will not get rendered #}
{% if bar %}
You just set bar boolean variable to true
{% endif %}
{% for i in foo_list %}
{{ i }}
{% endfor %}
{% if bar %}
Your variable again: {{ foo }}
{% endif %}
"""
prompt = PromptTemplate.from_template(template, template_format="jinja2")
expected_prompt = PromptTemplate(
template=template,
input_variables=["bar", "foo", "foo_list"],
template_format="jinja2",
)
assert prompt == expected_prompt

Loading…
Cancel
Save