Validate input_variables when using jinja2 templates (#3140)

`langchain.prompts.PromptTemplate` and
`langchain.prompts.FewShotPromptTemplate` do not validate
`input_variables` when initialized with a `jinja2` template.

```python
# Using langchain v0.0.144
template = """\
Your variable: {{ foo }}
{% if bar %}
You just set bar boolean variable to true
{% endif %}
"""

# Missing variable, should raise ValueError
prompt_template = PromptTemplate(template=template, 
                                 input_variables=["bar"], 
                                 template_format="jinja2", 
                                 validate_template=True)

# Extra variable, should raise ValueError
prompt_template = PromptTemplate(template=template, 
                                 input_variables=["bar", "foo", "extra", "thing"], 
                                 template_format="jinja2", 
                                 validate_template=True)
```
This commit is contained in:
engkheng 2023-04-20 07:18:32 +08:00 committed by GitHub
parent 3e0c44bae8
commit dbbc340f25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 167 additions and 20 deletions

View File

@ -1,6 +1,6 @@
"""Utilities for formatting strings.""" """Utilities for formatting strings."""
from string import Formatter from string import Formatter
from typing import Any, Mapping, Sequence, Union from typing import Any, List, Mapping, Sequence, Union
class StrictFormatter(Formatter): class StrictFormatter(Formatter):
@ -28,5 +28,11 @@ class StrictFormatter(Formatter):
) )
return super().vformat(format_string, args, kwargs) return super().vformat(format_string, args, kwargs)
def validate_input_variables(
    self, format_string: str, input_variables: List[str]
) -> None:
    """Check that ``format_string`` can be formatted with ``input_variables``.

    Every declared variable is bound to a placeholder value and the template
    is formatted once. The formatted text is discarded, so the only outcome
    visible to the caller is an exception raised while formatting.

    Args:
        format_string: The format template to check.
        input_variables: Names that will be supplied when formatting.
    """
    placeholders = dict.fromkeys(input_variables, "foo")
    super().format(format_string, **placeholders)
formatter = StrictFormatter() formatter = StrictFormatter()

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import json import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Union from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
import yaml import yaml
from pydantic import BaseModel, Extra, Field, root_validator from pydantic import BaseModel, Extra, Field, root_validator
@ -26,11 +26,47 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
return Template(template).render(**kwargs) return Template(template).render(**kwargs)
def validate_jinja2(template: str, input_variables: List[str]) -> None:
    """Check that ``input_variables`` matches the variables used by ``template``.

    Args:
        template: Jinja2 template source.
        input_variables: Variable names declared for the template.

    Raises:
        KeyError: If the template references variables that were not declared,
            or if declared variables are not referenced by the template.
    """
    declared = set(input_variables)
    referenced = _get_jinja2_variables_from_template(template)
    # Sort both differences: iteration order of a set of strings can change
    # between interpreter runs, which would make the error text unstable.
    missing_variables = sorted(referenced - declared)
    extra_variables = sorted(declared - referenced)

    problems = []
    if missing_variables:
        problems.append(f"Missing variables: {missing_variables}")
    if extra_variables:
        problems.append(f"Extra variables: {extra_variables}")
    if problems:
        raise KeyError(" ".join(problems))
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format, "f-string": formatter.format,
"jinja2": jinja2_formatter, "jinja2": jinja2_formatter,
} }
# Maps each template format name to the callable that validates a template
# string against its declared input variables.
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
"f-string": formatter.validate_input_variables,
"jinja2": validate_jinja2,
}
def check_valid_template( def check_valid_template(
template: str, template_format: str, input_variables: List[str] template: str, template_format: str, input_variables: List[str]
@ -42,10 +78,9 @@ def check_valid_template(
f"Invalid template format. Got `{template_format}`;" f"Invalid template format. Got `{template_format}`;"
f" should be one of {valid_formats}" f" should be one of {valid_formats}"
) )
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
try: try:
formatter_func = DEFAULT_FORMATTER_MAPPING[template_format] validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
formatter_func(template, **dummy_inputs) validator_func(template, input_variables)
except KeyError as e: except KeyError as e:
raise ValueError( raise ValueError(
"Invalid prompt schema; check for mismatched or missing input parameters. " "Invalid prompt schema; check for mismatched or missing input parameters. "

View File

@ -3,31 +3,18 @@ from __future__ import annotations
from pathlib import Path from pathlib import Path
from string import Formatter from string import Formatter
from typing import Any, Dict, List, Set, Union from typing import Any, Dict, List, Union
from pydantic import Extra, root_validator from pydantic import Extra, root_validator
from langchain.prompts.base import ( from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING, DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate, StringPromptTemplate,
_get_jinja2_variables_from_template,
check_valid_template, check_valid_template,
) )
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
try:
from jinja2 import Environment, meta
except ImportError:
raise ImportError(
"jinja2 not installed, which is needed to use the jinja2_formatter. "
"Please install it with `pip install jinja2`."
)
env = Environment()
ast = env.parse(template)
variables = meta.find_undeclared_variables(ast)
return variables
class PromptTemplate(StringPromptTemplate): class PromptTemplate(StringPromptTemplate):
"""Schema to represent a prompt for an LLM. """Schema to represent a prompt for an LLM.

View File

@ -1,4 +1,6 @@
"""Test few shot prompt template.""" """Test few shot prompt template."""
from typing import Dict, List, Tuple
import pytest import pytest
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
@ -9,6 +11,25 @@ EXAMPLE_PROMPT = PromptTemplate(
) )
@pytest.fixture()
def example_jinja2_prompt() -> Tuple[PromptTemplate, List[Dict[str, str]]]:
    """Provide a jinja2 example prompt together with its example records."""
    antonym_pairs = [
        {"word": "happy", "antonym": "sad"},
        {"word": "tall", "antonym": "short"},
    ]
    jinja2_prompt = PromptTemplate(
        input_variables=["word", "antonym"],
        template="{{ word }}: {{ antonym }}",
        template_format="jinja2",
    )
    return jinja2_prompt, antonym_pairs
def test_suffix_only() -> None: def test_suffix_only() -> None:
"""Test prompt works with just a suffix.""" """Test prompt works with just a suffix."""
suffix = "This is a {foo} test." suffix = "This is a {foo} test."
@ -174,3 +195,71 @@ def test_partial() -> None:
"Now you try to talk about party." "Now you try to talk about party."
) )
assert output == expected_output assert output == expected_output
def test_prompt_jinja2_functionality(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Test a jinja2 few-shot prompt renders prefix, examples and suffix."""
    example_prompt, examples = example_jinja2_prompt
    few_shot_prompt = FewShotPromptTemplate(
        input_variables=["foo", "bar"],
        suffix="Ending with {{ bar }}",
        prefix="Starting with {{ foo }}",
        examples=examples,
        example_prompt=example_prompt,
        template_format="jinja2",
    )
    rendered = few_shot_prompt.format(foo="hello", bar="bye")
    assert rendered == (
        "Starting with hello\n\n" "happy: sad\n\n" "tall: short\n\n" "Ending with bye"
    )
def test_prompt_jinja2_missing_input_variables(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Test error is raised when input variables are not provided."""
    example_prompt, examples = example_jinja2_prompt

    # The suffix references `bar`, but no input variable is declared.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=[],
            suffix="Ending with {{ bar }}",
            examples=examples,
            example_prompt=example_prompt,
            template_format="jinja2",
        )

    # The prefix references `foo`, but only `bar` is declared.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=["bar"],
            suffix="Ending with {{ bar }}",
            prefix="Starting with {{ foo }}",
            examples=examples,
            example_prompt=example_prompt,
            template_format="jinja2",
        )
def test_prompt_jinja2_extra_input_variables(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Test error is raised when there are too many input variables."""
    example_prompt, examples = example_jinja2_prompt
    # `extra` and `thing` are declared but used by neither prefix nor suffix.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=["bar", "foo", "extra", "thing"],
            suffix="Ending with {{ bar }}",
            prefix="Starting with {{ foo }}",
            examples=examples,
            example_prompt=example_prompt,
            template_format="jinja2",
        )

View File

@ -212,3 +212,33 @@ Your variable again: {{ foo }}
template_format="jinja2", template_format="jinja2",
) )
assert prompt == expected_prompt assert prompt == expected_prompt
def test_prompt_jinja2_missing_input_variables() -> None:
    """Test error is raised when input variables are not provided."""
    declared: list = []
    # The template references `foo`, but nothing is declared.
    with pytest.raises(ValueError):
        PromptTemplate(
            input_variables=declared,
            template="This is a {{ foo }} test.",
            template_format="jinja2",
        )
def test_prompt_jinja2_extra_input_variables() -> None:
    """Test error is raised when there are too many input variables."""
    # `bar` is declared but never referenced by the template.
    with pytest.raises(ValueError):
        PromptTemplate(
            input_variables=["foo", "bar"],
            template="This is a {{ foo }} test.",
            template_format="jinja2",
        )
def test_prompt_jinja2_wrong_input_variables() -> None:
    """Test error is raised when name of input variable is wrong."""
    # The template references `foo`, while only `bar` is declared.
    with pytest.raises(ValueError):
        PromptTemplate(
            input_variables=["bar"],
            template="This is a {{ foo }} test.",
            template_format="jinja2",
        )