mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Jinja2 validation changed to issue warnings rather than issuing exceptions. (#7161)
- Description: If their are missing or extra variables when validating Jinja 2 template then a warning is issued rather than raising an exception. This allows for better flexibility for the developer as described in #7044. Also changed the relevant test so pytest is checking for raised warnings rather than exceptions. - Issue: #7044 - Tag maintainer: @hwchase17, @baskaryan --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
e288410e72
commit
7d92e9407b
@ -1,6 +1,7 @@
|
||||
"""BasePrompt schema definition."""
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from typing import Any, Callable, Dict, List, Set
|
||||
|
||||
@ -26,7 +27,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
||||
def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
||||
"""
|
||||
Validate that the input variables are valid for the template.
|
||||
Raise an exception if missing or extra variables are found.
|
||||
Issues an warning if missing or extra variables are found.
|
||||
|
||||
Args:
|
||||
template: The template string.
|
||||
@ -37,15 +38,15 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
||||
missing_variables = valid_variables - input_variables_set
|
||||
extra_variables = input_variables_set - valid_variables
|
||||
|
||||
error_message = ""
|
||||
warning_message = ""
|
||||
if missing_variables:
|
||||
error_message += f"Missing variables: {missing_variables} "
|
||||
warning_message += f"Missing variables: {missing_variables} "
|
||||
|
||||
if extra_variables:
|
||||
error_message += f"Extra variables: {extra_variables}"
|
||||
warning_message += f"Extra variables: {extra_variables}"
|
||||
|
||||
if error_message:
|
||||
raise KeyError(error_message.strip())
|
||||
if warning_message:
|
||||
warnings.warn(warning_message.strip())
|
||||
|
||||
|
||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||
|
@ -227,7 +227,7 @@ def test_prompt_jinja2_missing_input_variables(
|
||||
suffix = "Ending with {{ bar }}"
|
||||
|
||||
# Test when missing in suffix
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.warns(UserWarning):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=[],
|
||||
suffix=suffix,
|
||||
@ -237,7 +237,7 @@ def test_prompt_jinja2_missing_input_variables(
|
||||
)
|
||||
|
||||
# Test when missing in prefix
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.warns(UserWarning):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=["bar"],
|
||||
suffix=suffix,
|
||||
@ -254,7 +254,7 @@ def test_prompt_jinja2_extra_input_variables(
|
||||
"""Test error is raised when there are too many input variables."""
|
||||
prefix = "Starting with {{ foo }}"
|
||||
suffix = "Ending with {{ bar }}"
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.warns(UserWarning):
|
||||
FewShotPromptTemplate(
|
||||
input_variables=["bar", "foo", "extra", "thing"],
|
||||
suffix=suffix,
|
||||
|
@ -218,7 +218,7 @@ def test_prompt_jinja2_missing_input_variables() -> None:
|
||||
"""Test error is raised when input variables are not provided."""
|
||||
template = "This is a {{ foo }} test."
|
||||
input_variables: list = []
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.warns(UserWarning):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
)
|
||||
@ -228,7 +228,7 @@ def test_prompt_jinja2_extra_input_variables() -> None:
|
||||
"""Test error is raised when there are too many input variables."""
|
||||
template = "This is a {{ foo }} test."
|
||||
input_variables = ["foo", "bar"]
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.warns(UserWarning):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
)
|
||||
@ -238,7 +238,7 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
|
||||
"""Test error is raised when name of input variable is wrong."""
|
||||
template = "This is a {{ foo }} test."
|
||||
input_variables = ["bar"]
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.warns(UserWarning):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user