forked from Archives/langchain
Validate input_variables
when using jinja2
templates (#3140)
`langchain.prompts.PromptTemplate` and `langchain.prompts.FewShotPromptTemplate` do not validate `input_variables` when initialized with a `jinja2` template.

```python
# Using langchain v0.0.144
template = """\
Your variable: {{ foo }}
{% if bar %}
You just set bar boolean variable to true
{% endif %}
"""

# Missing variable, should raise ValueError
prompt_template = PromptTemplate(template=template, input_variables=["bar"], template_format="jinja2", validate_template=True)

# Extra variable, should raise ValueError
prompt_template = PromptTemplate(template=template, input_variables=["bar", "foo", "extra", "thing"], template_format="jinja2", validate_template=True)
```
This commit is contained in:
parent
3e0c44bae8
commit
dbbc340f25
@ -1,6 +1,6 @@
|
||||
"""Utilities for formatting strings."""
|
||||
from string import Formatter
|
||||
from typing import Any, Mapping, Sequence, Union
|
||||
from typing import Any, List, Mapping, Sequence, Union
|
||||
|
||||
|
||||
class StrictFormatter(Formatter):
|
||||
@ -28,5 +28,11 @@ class StrictFormatter(Formatter):
|
||||
)
|
||||
return super().vformat(format_string, args, kwargs)
|
||||
|
||||
def validate_input_variables(
    self, format_string: str, input_variables: List[str]
) -> None:
    """Try formatting ``format_string`` with a placeholder for each input variable.

    Every name in ``input_variables`` is bound to the dummy value ``"foo"`` and
    the result of formatting is discarded; the call exists only so that any
    exception raised while formatting reaches the caller.

    Args:
        format_string: The format string to check.
        input_variables: Names of the variables declared for the template.
    """
    # NOTE(review): rejecting *unused* variables presumably relies on the
    # enclosing class's formatter overrides, which are not shown here.
    placeholders = dict.fromkeys(input_variables, "foo")
    super().format(format_string, **placeholders)
|
||||
|
||||
|
||||
formatter = StrictFormatter()
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
@ -26,11 +26,47 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
return Template(template).render(**kwargs)
|
||||
|
||||
|
||||
def validate_jinja2(template: str, input_variables: List[str]) -> None:
    """Check that ``input_variables`` matches the variables a jinja2 template uses.

    Args:
        template: The jinja2 template source.
        input_variables: Names of the variables declared for the template.

    Raises:
        KeyError: If the template uses variables that are not declared
            ("Missing variables"), or if declared variables are not used by
            the template ("Extra variables"). Both problems are reported in
            a single message when they occur together.
    """
    declared_variables = set(input_variables)
    template_variables = _get_jinja2_variables_from_template(template)

    problems = []

    missing_variables = template_variables - declared_variables
    if missing_variables:
        # Sort so the message is identical from run to run; raw set ordering
        # of strings changes with hash randomization.
        problems.append(f"Missing variables: {sorted(missing_variables)}")

    extra_variables = declared_variables - template_variables
    if extra_variables:
        problems.append(f"Extra variables: {sorted(extra_variables)}")

    if problems:
        raise KeyError(" ".join(problems))
|
||||
|
||||
|
||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||
try:
|
||||
from jinja2 import Environment, meta
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||
"Please install it with `pip install jinja2`."
|
||||
)
|
||||
env = Environment()
|
||||
ast = env.parse(template)
|
||||
variables = meta.find_undeclared_variables(ast)
|
||||
return variables
|
||||
|
||||
|
||||
# Maps each supported template format name to the callable that renders a
# template string of that format.
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
|
||||
"f-string": formatter.format,
|
||||
"jinja2": jinja2_formatter,
|
||||
}
|
||||
|
||||
# Maps each supported template format name to the callable that validates a
# template against its declared input variables; check_valid_template calls
# the selected entry as validator_func(template, input_variables).
DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = {
|
||||
"f-string": formatter.validate_input_variables,
|
||||
"jinja2": validate_jinja2,
|
||||
}
|
||||
|
||||
|
||||
def check_valid_template(
|
||||
template: str, template_format: str, input_variables: List[str]
|
||||
@ -42,10 +78,9 @@ def check_valid_template(
|
||||
f"Invalid template format. Got `{template_format}`;"
|
||||
f" should be one of {valid_formats}"
|
||||
)
|
||||
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
|
||||
try:
|
||||
formatter_func = DEFAULT_FORMATTER_MAPPING[template_format]
|
||||
formatter_func(template, **dummy_inputs)
|
||||
validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
|
||||
validator_func(template, input_variables)
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"Invalid prompt schema; check for mismatched or missing input parameters. "
|
||||
|
@ -3,31 +3,18 @@ from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from string import Formatter
|
||||
from typing import Any, Dict, List, Set, Union
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.prompts.base import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
_get_jinja2_variables_from_template,
|
||||
check_valid_template,
|
||||
)
|
||||
|
||||
|
||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||
try:
|
||||
from jinja2 import Environment, meta
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"jinja2 not installed, which is needed to use the jinja2_formatter. "
|
||||
"Please install it with `pip install jinja2`."
|
||||
)
|
||||
env = Environment()
|
||||
ast = env.parse(template)
|
||||
variables = meta.find_undeclared_variables(ast)
|
||||
return variables
|
||||
|
||||
|
||||
class PromptTemplate(StringPromptTemplate):
|
||||
"""Schema to represent a prompt for an LLM.
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
"""Test few shot prompt template."""
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
@ -9,6 +11,25 @@ EXAMPLE_PROMPT = PromptTemplate(
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
def example_jinja2_prompt() -> Tuple[PromptTemplate, List[Dict[str, str]]]:
    """Provide a jinja2 example prompt and the word/antonym examples it formats."""
    antonym_examples = [
        {"word": "happy", "antonym": "sad"},
        {"word": "tall", "antonym": "short"},
    ]
    jinja2_example_prompt = PromptTemplate(
        input_variables=["word", "antonym"],
        template="{{ word }}: {{ antonym }}",
        template_format="jinja2",
    )
    return jinja2_example_prompt, antonym_examples
|
||||
|
||||
|
||||
def test_suffix_only() -> None:
|
||||
"""Test prompt works with just a suffix."""
|
||||
suffix = "This is a {foo} test."
|
||||
@ -174,3 +195,71 @@ def test_partial() -> None:
|
||||
"Now you try to talk about party."
|
||||
)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_prompt_jinja2_functionality(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Test a jinja2 few-shot prompt renders its prefix, examples and suffix."""
    example_prompt, examples = example_jinja2_prompt

    few_shot_prompt = FewShotPromptTemplate(
        input_variables=["foo", "bar"],
        suffix="Ending with {{ bar }}",
        prefix="Starting with {{ foo }}",
        examples=examples,
        example_prompt=example_prompt,
        template_format="jinja2",
    )

    expected_output = (
        "Starting with hello\n\n"
        "happy: sad\n\n"
        "tall: short\n\n"
        "Ending with bye"
    )
    assert few_shot_prompt.format(foo="hello", bar="bye") == expected_output
|
||||
|
||||
|
||||
def test_prompt_jinja2_missing_input_variables(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Test error is raised when input variables are not provided."""
    example_prompt, examples = example_jinja2_prompt

    # The suffix uses `bar`, but no input variable is declared.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=[],
            suffix="Ending with {{ bar }}",
            examples=examples,
            example_prompt=example_prompt,
            template_format="jinja2",
        )

    # The prefix uses `foo`, but only `bar` is declared.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=["bar"],
            suffix="Ending with {{ bar }}",
            prefix="Starting with {{ foo }}",
            examples=examples,
            example_prompt=example_prompt,
            template_format="jinja2",
        )
|
||||
|
||||
|
||||
def test_prompt_jinja2_extra_input_variables(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Test error is raised when there are too many input variables."""
    example_prompt, examples = example_jinja2_prompt

    # `extra` and `thing` are declared but used by neither prefix nor suffix.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=["bar", "foo", "extra", "thing"],
            suffix="Ending with {{ bar }}",
            prefix="Starting with {{ foo }}",
            examples=examples,
            example_prompt=example_prompt,
            template_format="jinja2",
        )
|
||||
|
@ -212,3 +212,33 @@ Your variable again: {{ foo }}
|
||||
template_format="jinja2",
|
||||
)
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_prompt_jinja2_missing_input_variables() -> None:
    """Test error is raised when input variables are not provided."""
    # The template uses `foo`, but nothing is declared.
    declared: list = []
    with pytest.raises(ValueError):
        PromptTemplate(
            input_variables=declared,
            template="This is a {{ foo }} test.",
            template_format="jinja2",
        )
|
||||
|
||||
|
||||
def test_prompt_jinja2_extra_input_variables() -> None:
    """Test error is raised when there are too many input variables."""
    # `bar` is declared but never used by the template.
    with pytest.raises(ValueError):
        PromptTemplate(
            input_variables=["foo", "bar"],
            template="This is a {{ foo }} test.",
            template_format="jinja2",
        )
|
||||
|
||||
|
||||
def test_prompt_jinja2_wrong_input_variables() -> None:
    """Test error is raised when name of input variable is wrong."""
    # The template uses `foo`, while only `bar` is declared.
    with pytest.raises(ValueError):
        PromptTemplate(
            input_variables=["bar"],
            template="This is a {{ foo }} test.",
            template_format="jinja2",
        )
|
||||
|
Loading…
Reference in New Issue
Block a user