Validate `input_variables` when using `jinja2` templates (#3140)

`langchain.prompts.PromptTemplate` and
`langchain.prompts.FewShotPromptTemplate` do not validate
`input_variables` when initialized as `jinja2` template.

```python
# Using langchain v0.0.144
template = """"\
Your variable: {{ foo }}
{% if bar %}
You just set bar boolean variable to true
{% endif %}
"""

# Missing variable, should raise ValueError
prompt_template = PromptTemplate(template=template, 
                                 input_variables=["bar"], 
                                 template_format="jinja2", 
                                 validate_template=True)

# Extra variable, should raise ValueError
prompt_template = PromptTemplate(template=template, 
                                 input_variables=["bar", "foo", "extra", "thing"], 
                                 template_format="jinja2", 
                                 validate_template=True)
```
fix_agent_callbacks
engkheng 1 year ago committed by GitHub
parent 3e0c44bae8
commit dbbc340f25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,6 @@
"""Utilities for formatting strings."""
from string import Formatter
from typing import Any, Mapping, Sequence, Union
from typing import Any, List, Mapping, Sequence, Union
class StrictFormatter(Formatter):
@ -28,5 +28,11 @@ class StrictFormatter(Formatter):
)
return super().vformat(format_string, args, kwargs)
def validate_input_variables(
self, format_string: str, input_variables: List[str]
) -> None:
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
super().format(format_string, **dummy_inputs)
formatter = StrictFormatter()

@ -4,7 +4,7 @@ from __future__ import annotations
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
import yaml
from pydantic import BaseModel, Extra, Field, root_validator
@ -26,11 +26,47 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
return Template(template).render(**kwargs)
def validate_jinja2(template: str, input_variables: List[str]) -> None:
input_variables_set = set(input_variables)
valid_variables = _get_jinja2_variables_from_template(template)
missing_variables = valid_variables - input_variables_set
extra_variables = input_variables_set - valid_variables
error_message = ""
if missing_variables:
error_message += f"Missing variables: {missing_variables} "
if extra_variables:
error_message += f"Extra variables: {extra_variables}"
if error_message:
raise KeyError(error_message.strip())
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format,
"jinja2": jinja2_formatter,
}
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
"f-string": formatter.validate_input_variables,
"jinja2": validate_jinja2,
}
def check_valid_template(
template: str, template_format: str, input_variables: List[str]
@ -42,10 +78,9 @@ def check_valid_template(
f"Invalid template format. Got `{template_format}`;"
f" should be one of {valid_formats}"
)
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
try:
formatter_func = DEFAULT_FORMATTER_MAPPING[template_format]
formatter_func(template, **dummy_inputs)
validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
validator_func(template, input_variables)
except KeyError as e:
raise ValueError(
"Invalid prompt schema; check for mismatched or missing input parameters. "

@ -3,31 +3,18 @@ from __future__ import annotations
from pathlib import Path
from string import Formatter
from typing import Any, Dict, List, Set, Union
from typing import Any, Dict, List, Union
from pydantic import Extra, root_validator
from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
_get_jinja2_variables_from_template,
check_valid_template,
)
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
class PromptTemplate(StringPromptTemplate):
"""Schema to represent a prompt for an LLM.

@ -1,4 +1,6 @@
"""Test few shot prompt template."""
from typing import Dict, List, Tuple
import pytest
from langchain.prompts.few_shot import FewShotPromptTemplate
@ -9,6 +11,25 @@ EXAMPLE_PROMPT = PromptTemplate(
)
@pytest.fixture()
def example_jinja2_prompt() -> Tuple[PromptTemplate, List[Dict[str, str]]]:
example_template = "{{ word }}: {{ antonym }}"
examples = [
{"word": "happy", "antonym": "sad"},
{"word": "tall", "antonym": "short"},
]
return (
PromptTemplate(
input_variables=["word", "antonym"],
template=example_template,
template_format="jinja2",
),
examples,
)
def test_suffix_only() -> None:
"""Test prompt works with just a suffix."""
suffix = "This is a {foo} test."
@ -174,3 +195,71 @@ def test_partial() -> None:
"Now you try to talk about party."
)
assert output == expected_output
def test_prompt_jinja2_functionality(
example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
prefix = "Starting with {{ foo }}"
suffix = "Ending with {{ bar }}"
prompt = FewShotPromptTemplate(
input_variables=["foo", "bar"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
)
output = prompt.format(foo="hello", bar="bye")
expected_output = (
"Starting with hello\n\n" "happy: sad\n\n" "tall: short\n\n" "Ending with bye"
)
assert output == expected_output
def test_prompt_jinja2_missing_input_variables(
example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
"""Test error is raised when input variables are not provided."""
prefix = "Starting with {{ foo }}"
suffix = "Ending with {{ bar }}"
# Test when missing in suffix
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=[],
suffix=suffix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
)
# Test when missing in prefix
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=["bar"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
)
def test_prompt_jinja2_extra_input_variables(
example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
"""Test error is raised when there are too many input variables."""
prefix = "Starting with {{ foo }}"
suffix = "Ending with {{ bar }}"
with pytest.raises(ValueError):
FewShotPromptTemplate(
input_variables=["bar", "foo", "extra", "thing"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
)

@ -212,3 +212,33 @@ Your variable again: {{ foo }}
template_format="jinja2",
)
assert prompt == expected_prompt
def test_prompt_jinja2_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
template = "This is a {{ foo }} test."
input_variables: list = []
with pytest.raises(ValueError):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
)
def test_prompt_jinja2_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables."""
template = "This is a {{ foo }} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
)
def test_prompt_jinja2_wrong_input_variables() -> None:
"""Test error is raised when name of input variable is wrong."""
template = "This is a {{ foo }} test."
input_variables = ["bar"]
with pytest.raises(ValueError):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
)

Loading…
Cancel
Save