mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Jinja2 validation changed to issue warnings rather than issuing exceptions. (#7161)
- Description: If their are missing or extra variables when validating Jinja 2 template then a warning is issued rather than raising an exception. This allows for better flexibility for the developer as described in #7044. Also changed the relevant test so pytest is checking for raised warnings rather than exceptions. - Issue: #7044 - Tag maintainer: @hwchase17, @baskaryan --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
e288410e72
commit
7d92e9407b
@ -1,6 +1,7 @@
|
|||||||
"""BasePrompt schema definition."""
|
"""BasePrompt schema definition."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import warnings
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Any, Callable, Dict, List, Set
|
from typing import Any, Callable, Dict, List, Set
|
||||||
|
|
||||||
@ -26,7 +27,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
|
|||||||
def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
||||||
"""
|
"""
|
||||||
Validate that the input variables are valid for the template.
|
Validate that the input variables are valid for the template.
|
||||||
Raise an exception if missing or extra variables are found.
|
Issues an warning if missing or extra variables are found.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
template: The template string.
|
template: The template string.
|
||||||
@ -37,15 +38,15 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None:
|
|||||||
missing_variables = valid_variables - input_variables_set
|
missing_variables = valid_variables - input_variables_set
|
||||||
extra_variables = input_variables_set - valid_variables
|
extra_variables = input_variables_set - valid_variables
|
||||||
|
|
||||||
error_message = ""
|
warning_message = ""
|
||||||
if missing_variables:
|
if missing_variables:
|
||||||
error_message += f"Missing variables: {missing_variables} "
|
warning_message += f"Missing variables: {missing_variables} "
|
||||||
|
|
||||||
if extra_variables:
|
if extra_variables:
|
||||||
error_message += f"Extra variables: {extra_variables}"
|
warning_message += f"Extra variables: {extra_variables}"
|
||||||
|
|
||||||
if error_message:
|
if warning_message:
|
||||||
raise KeyError(error_message.strip())
|
warnings.warn(warning_message.strip())
|
||||||
|
|
||||||
|
|
||||||
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
def _get_jinja2_variables_from_template(template: str) -> Set[str]:
|
||||||
|
@ -227,7 +227,7 @@ def test_prompt_jinja2_missing_input_variables(
|
|||||||
suffix = "Ending with {{ bar }}"
|
suffix = "Ending with {{ bar }}"
|
||||||
|
|
||||||
# Test when missing in suffix
|
# Test when missing in suffix
|
||||||
with pytest.raises(ValueError):
|
with pytest.warns(UserWarning):
|
||||||
FewShotPromptTemplate(
|
FewShotPromptTemplate(
|
||||||
input_variables=[],
|
input_variables=[],
|
||||||
suffix=suffix,
|
suffix=suffix,
|
||||||
@ -237,7 +237,7 @@ def test_prompt_jinja2_missing_input_variables(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Test when missing in prefix
|
# Test when missing in prefix
|
||||||
with pytest.raises(ValueError):
|
with pytest.warns(UserWarning):
|
||||||
FewShotPromptTemplate(
|
FewShotPromptTemplate(
|
||||||
input_variables=["bar"],
|
input_variables=["bar"],
|
||||||
suffix=suffix,
|
suffix=suffix,
|
||||||
@ -254,7 +254,7 @@ def test_prompt_jinja2_extra_input_variables(
|
|||||||
"""Test error is raised when there are too many input variables."""
|
"""Test error is raised when there are too many input variables."""
|
||||||
prefix = "Starting with {{ foo }}"
|
prefix = "Starting with {{ foo }}"
|
||||||
suffix = "Ending with {{ bar }}"
|
suffix = "Ending with {{ bar }}"
|
||||||
with pytest.raises(ValueError):
|
with pytest.warns(UserWarning):
|
||||||
FewShotPromptTemplate(
|
FewShotPromptTemplate(
|
||||||
input_variables=["bar", "foo", "extra", "thing"],
|
input_variables=["bar", "foo", "extra", "thing"],
|
||||||
suffix=suffix,
|
suffix=suffix,
|
||||||
|
@ -218,7 +218,7 @@ def test_prompt_jinja2_missing_input_variables() -> None:
|
|||||||
"""Test error is raised when input variables are not provided."""
|
"""Test error is raised when input variables are not provided."""
|
||||||
template = "This is a {{ foo }} test."
|
template = "This is a {{ foo }} test."
|
||||||
input_variables: list = []
|
input_variables: list = []
|
||||||
with pytest.raises(ValueError):
|
with pytest.warns(UserWarning):
|
||||||
PromptTemplate(
|
PromptTemplate(
|
||||||
input_variables=input_variables, template=template, template_format="jinja2"
|
input_variables=input_variables, template=template, template_format="jinja2"
|
||||||
)
|
)
|
||||||
@ -228,7 +228,7 @@ def test_prompt_jinja2_extra_input_variables() -> None:
|
|||||||
"""Test error is raised when there are too many input variables."""
|
"""Test error is raised when there are too many input variables."""
|
||||||
template = "This is a {{ foo }} test."
|
template = "This is a {{ foo }} test."
|
||||||
input_variables = ["foo", "bar"]
|
input_variables = ["foo", "bar"]
|
||||||
with pytest.raises(ValueError):
|
with pytest.warns(UserWarning):
|
||||||
PromptTemplate(
|
PromptTemplate(
|
||||||
input_variables=input_variables, template=template, template_format="jinja2"
|
input_variables=input_variables, template=template, template_format="jinja2"
|
||||||
)
|
)
|
||||||
@ -238,7 +238,7 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
|
|||||||
"""Test error is raised when name of input variable is wrong."""
|
"""Test error is raised when name of input variable is wrong."""
|
||||||
template = "This is a {{ foo }} test."
|
template = "This is a {{ foo }} test."
|
||||||
input_variables = ["bar"]
|
input_variables = ["bar"]
|
||||||
with pytest.raises(ValueError):
|
with pytest.warns(UserWarning):
|
||||||
PromptTemplate(
|
PromptTemplate(
|
||||||
input_variables=input_variables, template=template, template_format="jinja2"
|
input_variables=input_variables, template=template, template_format="jinja2"
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user