Jinja2 validation changed to issue warnings rather than issuing exceptions. (#7161)

- Description: If their are missing or extra variables when validating
Jinja 2 template then a warning is issued rather than raising an
exception. This allows for better flexibility for the developer as
described in #7044. Also changed the relevant test so pytest is checking
for raised warnings rather than exceptions.
  - Issue: #7044 
  - Tag maintainer: @hwchase17, @baskaryan

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Mohammad Mohtashim 2023-07-05 23:04:29 +05:00 committed by GitHub
parent e288410e72
commit 7d92e9407b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 12 deletions

View File

@ -1,6 +1,7 @@
"""BasePrompt schema definition."""
from __future__ import annotations
import warnings
from abc import ABC
from typing import Any, Callable, Dict, List, Set
@ -26,7 +27,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
def validate_jinja2(template: str, input_variables: List[str]) -> None:
"""
Validate that the input variables are valid for the template.
Raise an exception if missing or extra variables are found.
Issues an warning if missing or extra variables are found.
Args:
template: The template string.
@ -37,15 +38,15 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
missing_variables = valid_variables - input_variables_set
extra_variables = input_variables_set - valid_variables
error_message = ""
warning_message = ""
if missing_variables:
error_message += f"Missing variables: {missing_variables} "
warning_message += f"Missing variables: {missing_variables} "
if extra_variables:
error_message += f"Extra variables: {extra_variables}"
warning_message += f"Extra variables: {extra_variables}"
if error_message:
raise KeyError(error_message.strip())
if warning_message:
warnings.warn(warning_message.strip())
def _get_jinja2_variables_from_template(template: str) -> Set[str]:

View File

@ -227,7 +227,7 @@ def test_prompt_jinja2_missing_input_variables(
suffix = "Ending with {{ bar }}"
# Test when missing in suffix
with pytest.raises(ValueError):
with pytest.warns(UserWarning):
FewShotPromptTemplate(
input_variables=[],
suffix=suffix,
@ -237,7 +237,7 @@ def test_prompt_jinja2_missing_input_variables(
)
# Test when missing in prefix
with pytest.raises(ValueError):
with pytest.warns(UserWarning):
FewShotPromptTemplate(
input_variables=["bar"],
suffix=suffix,
@ -254,7 +254,7 @@ def test_prompt_jinja2_extra_input_variables(
"""Test error is raised when there are too many input variables."""
prefix = "Starting with {{ foo }}"
suffix = "Ending with {{ bar }}"
with pytest.raises(ValueError):
with pytest.warns(UserWarning):
FewShotPromptTemplate(
input_variables=["bar", "foo", "extra", "thing"],
suffix=suffix,

View File

@ -218,7 +218,7 @@ def test_prompt_jinja2_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
template = "This is a {{ foo }} test."
input_variables: list = []
with pytest.raises(ValueError):
with pytest.warns(UserWarning):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
)
@ -228,7 +228,7 @@ def test_prompt_jinja2_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables."""
template = "This is a {{ foo }} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
with pytest.warns(UserWarning):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
)
@ -238,7 +238,7 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
"""Test error is raised when name of input variable is wrong."""
template = "This is a {{ foo }} test."
input_variables = ["bar"]
with pytest.raises(ValueError):
with pytest.warns(UserWarning):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
)