forked from Archives/langchain
Validate input_variables when using jinja2 templates (#3140)
`langchain.prompts.PromptTemplate` and `langchain.prompts.FewShotPromptTemplate` do not validate `input_variables` when initialized with a `jinja2` template.

```python
# Using langchain v0.0.144
template = """\
Your variable: {{ foo }}
{% if bar %}
You just set bar boolean variable to true
{% endif %}
"""

# Missing variable, should raise ValueError
prompt_template = PromptTemplate(template=template, input_variables=["bar"], template_format="jinja2", validate_template=True)

# Extra variable, should raise ValueError
prompt_template = PromptTemplate(template=template, input_variables=["bar", "foo", "extra", "thing"], template_format="jinja2", validate_template=True)
```
This commit is contained in:
parent
3e0c44bae8
commit
dbbc340f25
@ -1,6 +1,6 @@
|
|||||||
"""Utilities for formatting strings."""
|
"""Utilities for formatting strings."""
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, Mapping, Sequence, Union
|
from typing import Any, List, Mapping, Sequence, Union
|
||||||
|
|
||||||
|
|
||||||
class StrictFormatter(Formatter):
|
class StrictFormatter(Formatter):
|
||||||
@ -28,5 +28,11 @@ class StrictFormatter(Formatter):
|
|||||||
)
|
)
|
||||||
return super().vformat(format_string, args, kwargs)
|
return super().vformat(format_string, args, kwargs)
|
||||||
|
|
||||||
|
def validate_input_variables(
    self, format_string: str, input_variables: List[str]
) -> None:
    """Check that ``format_string`` can be formatted with ``input_variables``.

    The string is formatted once with the placeholder value ``"foo"``
    supplied for every declared variable; the formatted result is discarded.
    Any exception raised while formatting propagates to the caller.

    Args:
        format_string: The format string to validate.
        input_variables: Names of the variables the template declares.
    """
    # Only the variable names matter here, so every one gets the same dummy value.
    placeholders = dict.fromkeys(input_variables, "foo")
    super().format(format_string, **placeholders)
|
||||||
|
|
||||||
|
|
||||||
formatter = StrictFormatter()
|
formatter = StrictFormatter()
|
||||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator
|
from pydantic import BaseModel, Extra, Field, root_validator
|
||||||
@ -26,11 +26,47 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
|||||||
return Template(template).render(**kwargs)
|
return Template(template).render(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_jinja2(template: str, input_variables: List[str]) -> None:
    """Check that ``input_variables`` matches the variables used by a jinja2 template.

    Args:
        template: The jinja2 template source.
        input_variables: Names of the variables declared for the template.

    Raises:
        KeyError: If the template uses variables that are not declared, or if
            variables are declared that the template does not use. The message
            lists the missing and/or extra names.
    """
    declared = set(input_variables)
    referenced = _get_jinja2_variables_from_template(template)

    missing = referenced - declared
    extra = declared - referenced
    if not missing and not extra:
        return

    # KeyError (not ValueError) so that check_valid_template's
    # ``except KeyError`` handler picks the failure up.
    problems = ""
    if missing:
        problems += f"Missing variables: {missing} "
    if extra:
        problems += f"Extra variables: {extra}"
    raise KeyError(problems.strip())
|
||||||
|
|
||||||
|
|
||||||
|
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||||
|
try:
|
||||||
|
from jinja2 import Environment, meta
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||||
|
"Please install it with `pip install jinja2`."
|
||||||
|
)
|
||||||
|
env = Environment()
|
||||||
|
ast = env.parse(template)
|
||||||
|
variables = meta.find_undeclared_variables(ast)
|
||||||
|
return variables
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
|
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
|
||||||
"f-string": formatter.format,
|
"f-string": formatter.format,
|
||||||
"jinja2": jinja2_formatter,
|
"jinja2": jinja2_formatter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Maps each template format name to the callable that validates a template
# against its declared input variables; called as ``validator(template, input_variables)``.
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
    "f-string": formatter.validate_input_variables,
    "jinja2": validate_jinja2,
}
|
||||||
|
|
||||||
|
|
||||||
def check_valid_template(
|
def check_valid_template(
|
||||||
template: str, template_format: str, input_variables: List[str]
|
template: str, template_format: str, input_variables: List[str]
|
||||||
@ -42,10 +78,9 @@ def check_valid_template(
|
|||||||
f"Invalid template format. Got `{template_format}`;"
|
f"Invalid template format. Got `{template_format}`;"
|
||||||
f" should be one of {valid_formats}"
|
f" should be one of {valid_formats}"
|
||||||
)
|
)
|
||||||
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
|
|
||||||
try:
|
try:
|
||||||
formatter_func = DEFAULT_FORMATTER_MAPPING[template_format]
|
validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
|
||||||
formatter_func(template, **dummy_inputs)
|
validator_func(template, input_variables)
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid prompt schema; check for mismatched or missing input parameters. "
|
"Invalid prompt schema; check for mismatched or missing input parameters. "
|
||||||
|
@ -3,31 +3,18 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, Dict, List, Set, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
from pydantic import Extra, root_validator
|
from pydantic import Extra, root_validator
|
||||||
|
|
||||||
from langchain.prompts.base import (
|
from langchain.prompts.base import (
|
||||||
DEFAULT_FORMATTER_MAPPING,
|
DEFAULT_FORMATTER_MAPPING,
|
||||||
StringPromptTemplate,
|
StringPromptTemplate,
|
||||||
|
_get_jinja2_variables_from_template,
|
||||||
check_valid_template,
|
check_valid_template,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
|
||||||
try:
|
|
||||||
from jinja2 import Environment, meta
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
|
||||||
"Please install it with `pip install jinja2`."
|
|
||||||
)
|
|
||||||
env = Environment()
|
|
||||||
ast = env.parse(template)
|
|
||||||
variables = meta.find_undeclared_variables(ast)
|
|
||||||
return variables
|
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplate(StringPromptTemplate):
|
class PromptTemplate(StringPromptTemplate):
|
||||||
"""Schema to represent a prompt for an LLM.
|
"""Schema to represent a prompt for an LLM.
|
||||||
|
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
"""Test few shot prompt template."""
|
"""Test few shot prompt template."""
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||||
@ -9,6 +11,25 @@ EXAMPLE_PROMPT = PromptTemplate(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def example_jinja2_prompt() -> Tuple[PromptTemplate, List[Dict[str, str]]]:
    """Provide a jinja2 example prompt together with a list of example inputs."""
    antonym_examples = [
        {"word": "happy", "antonym": "sad"},
        {"word": "tall", "antonym": "short"},
    ]
    jinja2_example_prompt = PromptTemplate(
        input_variables=["word", "antonym"],
        template="{{ word }}: {{ antonym }}",
        template_format="jinja2",
    )
    return jinja2_example_prompt, antonym_examples
|
||||||
|
|
||||||
|
|
||||||
def test_suffix_only() -> None:
|
def test_suffix_only() -> None:
|
||||||
"""Test prompt works with just a suffix."""
|
"""Test prompt works with just a suffix."""
|
||||||
suffix = "This is a {foo} test."
|
suffix = "This is a {foo} test."
|
||||||
@ -174,3 +195,71 @@ def test_partial() -> None:
|
|||||||
"Now you try to talk about party."
|
"Now you try to talk about party."
|
||||||
)
|
)
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_jinja2_functionality(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Test a jinja2 few shot prompt formats prefix, examples and suffix."""
    example_prompt, examples = example_jinja2_prompt

    few_shot_prompt = FewShotPromptTemplate(
        input_variables=["foo", "bar"],
        suffix="Ending with {{ bar }}",
        prefix="Starting with {{ foo }}",
        examples=examples,
        example_prompt=example_prompt,
        template_format="jinja2",
    )

    rendered = few_shot_prompt.format(foo="hello", bar="bye")
    assert rendered == (
        "Starting with hello\n\n" "happy: sad\n\n" "tall: short\n\n" "Ending with bye"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_jinja2_missing_input_variables(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Test error is raised when input variables are not provided."""
    example_prompt, examples = example_jinja2_prompt

    # The suffix uses `bar`, but no input variables are declared.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=[],
            suffix="Ending with {{ bar }}",
            examples=examples,
            example_prompt=example_prompt,
            template_format="jinja2",
        )

    # The prefix uses `foo`, but only `bar` is declared.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=["bar"],
            suffix="Ending with {{ bar }}",
            prefix="Starting with {{ foo }}",
            examples=examples,
            example_prompt=example_prompt,
            template_format="jinja2",
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_jinja2_extra_input_variables(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Test error is raised when there are too many input variables."""
    example_prompt, examples = example_jinja2_prompt

    # `extra` and `thing` are declared but appear in neither prefix nor suffix.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=["bar", "foo", "extra", "thing"],
            suffix="Ending with {{ bar }}",
            prefix="Starting with {{ foo }}",
            examples=examples,
            example_prompt=example_prompt,
            template_format="jinja2",
        )
|
||||||
|
@ -212,3 +212,33 @@ Your variable again: {{ foo }}
|
|||||||
template_format="jinja2",
|
template_format="jinja2",
|
||||||
)
|
)
|
||||||
assert prompt == expected_prompt
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_jinja2_missing_input_variables() -> None:
    """Test error is raised when input variables are not provided."""
    # The template uses `foo`, but nothing is declared.
    with pytest.raises(ValueError):
        PromptTemplate(
            input_variables=[],
            template="This is a {{ foo }} test.",
            template_format="jinja2",
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_jinja2_extra_input_variables() -> None:
    """Test error is raised when there are too many input variables."""
    # `bar` is declared but never used by the template.
    with pytest.raises(ValueError):
        PromptTemplate(
            input_variables=["foo", "bar"],
            template="This is a {{ foo }} test.",
            template_format="jinja2",
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_jinja2_wrong_input_variables() -> None:
    """Test error is raised when name of input variable is wrong."""
    # The template uses `foo`, while the only declared variable is `bar`.
    with pytest.raises(ValueError):
        PromptTemplate(
            input_variables=["bar"],
            template="This is a {{ foo }} test.",
            template_format="jinja2",
        )
|
||||||
|
Loading…
Reference in New Issue
Block a user