mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Support inference of input_variables
from jinja2
template (#3013)
`langchain.prompts.PromptTemplate` is unable to infer `input_variables` from jinja2 template. ```python # Using langchain v0.0.141 template_string = """\ Hello world Your variable: {{ var }} {# This will not get rendered #} {% if verbose %} Congrats! You just turned on verbose mode and got extra messages! {% endif %} """ template = PromptTemplate.from_template(template_string, template_format="jinja2") print(template.input_variables) # Output ['# This will not get rendered #', '% endif %', '% if verbose %'] ``` --------- Co-authored-by: engkheng <ongengkheng929@example.com>
This commit is contained in:
parent
dac32c59e5
commit
19febc77d6
@ -3,8 +3,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Set, Union
|
||||||
|
|
||||||
|
from jinja2 import Environment, meta
|
||||||
from pydantic import Extra, root_validator
|
from pydantic import Extra, root_validator
|
||||||
|
|
||||||
from langchain.prompts.base import (
|
from langchain.prompts.base import (
|
||||||
@ -14,6 +15,13 @@ from langchain.prompts.base import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||||
|
env = Environment()
|
||||||
|
ast = env.parse(template)
|
||||||
|
variables = meta.find_undeclared_variables(ast)
|
||||||
|
return variables
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplate(StringPromptTemplate):
|
class PromptTemplate(StringPromptTemplate):
|
||||||
"""Schema to represent a prompt for an LLM.
|
"""Schema to represent a prompt for an LLM.
|
||||||
|
|
||||||
@ -125,9 +133,15 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
|
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
|
||||||
"""Load a prompt template from a template."""
|
"""Load a prompt template from a template."""
|
||||||
|
if "template_format" in kwargs and kwargs["template_format"] == "jinja2":
|
||||||
|
# Get the variables for the template
|
||||||
|
input_variables = _get_jinja2_variables_from_template(template)
|
||||||
|
|
||||||
|
else:
|
||||||
input_variables = {
|
input_variables = {
|
||||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||||
}
|
}
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
input_variables=list(sorted(input_variables)), template=template, **kwargs
|
input_variables=list(sorted(input_variables)), template=template, **kwargs
|
||||||
)
|
)
|
||||||
|
@ -145,3 +145,70 @@ def test_partial() -> None:
|
|||||||
assert new_result == "This is a 3 test."
|
assert new_result == "This is a 3 test."
|
||||||
result = prompt.format(foo="foo")
|
result = prompt.format(foo="foo")
|
||||||
assert result == "This is a foo test."
|
assert result == "This is a foo test."
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_from_jinja2_template() -> None:
|
||||||
|
"""Test prompts can be constructed from a jinja2 template."""
|
||||||
|
# Empty input variable.
|
||||||
|
template = """Hello there
|
||||||
|
There is no variable here {
|
||||||
|
Will it get confused{ }?
|
||||||
|
"""
|
||||||
|
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
||||||
|
expected_prompt = PromptTemplate(
|
||||||
|
template=template, input_variables=[], template_format="jinja2"
|
||||||
|
)
|
||||||
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
# Multiple input variables.
|
||||||
|
template = """\
|
||||||
|
Hello world
|
||||||
|
|
||||||
|
Your variable: {{ foo }}
|
||||||
|
|
||||||
|
{# This will not get rendered #}
|
||||||
|
|
||||||
|
{% if bar %}
|
||||||
|
You just set bar boolean variable to true
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% for i in foo_list %}
|
||||||
|
{{ i }}
|
||||||
|
{% endfor %}
|
||||||
|
"""
|
||||||
|
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
||||||
|
expected_prompt = PromptTemplate(
|
||||||
|
template=template,
|
||||||
|
input_variables=["bar", "foo", "foo_list"],
|
||||||
|
template_format="jinja2",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
# Multiple input variables with repeats.
|
||||||
|
template = """\
|
||||||
|
Hello world
|
||||||
|
|
||||||
|
Your variable: {{ foo }}
|
||||||
|
|
||||||
|
{# This will not get rendered #}
|
||||||
|
|
||||||
|
{% if bar %}
|
||||||
|
You just set bar boolean variable to true
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% for i in foo_list %}
|
||||||
|
{{ i }}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
{% if bar %}
|
||||||
|
Your variable again: {{ foo }}
|
||||||
|
{% endif %}
|
||||||
|
"""
|
||||||
|
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
||||||
|
expected_prompt = PromptTemplate(
|
||||||
|
template=template,
|
||||||
|
input_variables=["bar", "foo", "foo_list"],
|
||||||
|
template_format="jinja2",
|
||||||
|
)
|
||||||
|
assert prompt == expected_prompt
|
||||||
|
Loading…
Reference in New Issue
Block a user