From 7d92e9407bfe24d45d29abebde0734040871b109 Mon Sep 17 00:00:00 2001 From: Mohammad Mohtashim <45242107+keenborder786@users.noreply.github.com> Date: Wed, 5 Jul 2023 23:04:29 +0500 Subject: [PATCH] Jinja2 validation changed to issue warnings rather than issuing exceptions. (#7161) - Description: If their are missing or extra variables when validating Jinja 2 template then a warning is issued rather than raising an exception. This allows for better flexibility for the developer as described in #7044. Also changed the relevant test so pytest is checking for raised warnings rather than exceptions. - Issue: #7044 - Tag maintainer: @hwchase17, @baskaryan --------- Co-authored-by: Harrison Chase --- langchain/prompts/base.py | 13 +++++++------ tests/unit_tests/prompts/test_few_shot.py | 6 +++--- tests/unit_tests/prompts/test_prompt.py | 6 +++--- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 3d1a13e769..e527cd8276 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -1,6 +1,7 @@ """BasePrompt schema definition.""" from __future__ import annotations +import warnings from abc import ABC from typing import Any, Callable, Dict, List, Set @@ -26,7 +27,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str: def validate_jinja2(template: str, input_variables: List[str]) -> None: """ Validate that the input variables are valid for the template. - Raise an exception if missing or extra variables are found. + Issues an warning if missing or extra variables are found. Args: template: The template string. @@ -37,15 +38,15 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None: missing_variables = valid_variables - input_variables_set extra_variables = input_variables_set - valid_variables - error_message = "" + warning_message = "" if missing_variables: - error_message += f"Missing variables: {missing_variables} " + warning_message += f"Missing variables: {missing_variables} " if extra_variables: - error_message += f"Extra variables: {extra_variables}" + warning_message += f"Extra variables: {extra_variables}" - if error_message: - raise KeyError(error_message.strip()) + if warning_message: + warnings.warn(warning_message.strip()) def _get_jinja2_variables_from_template(template: str) -> Set[str]: diff --git a/tests/unit_tests/prompts/test_few_shot.py b/tests/unit_tests/prompts/test_few_shot.py index eb73c4c16e..fae1e884c5 100644 --- a/tests/unit_tests/prompts/test_few_shot.py +++ b/tests/unit_tests/prompts/test_few_shot.py @@ -227,7 +227,7 @@ def test_prompt_jinja2_missing_input_variables( suffix = "Ending with {{ bar }}" # Test when missing in suffix - with pytest.raises(ValueError): + with pytest.warns(UserWarning): FewShotPromptTemplate( input_variables=[], suffix=suffix, @@ -237,7 +237,7 @@ def test_prompt_jinja2_missing_input_variables( ) # Test when missing in prefix - with pytest.raises(ValueError): + with pytest.warns(UserWarning): FewShotPromptTemplate( input_variables=["bar"], suffix=suffix, @@ -254,7 +254,7 @@ def test_prompt_jinja2_extra_input_variables( """Test error is raised when there are too many input variables.""" prefix = "Starting with {{ foo }}" suffix = "Ending with {{ bar }}" - with pytest.raises(ValueError): + with pytest.warns(UserWarning): FewShotPromptTemplate( input_variables=["bar", "foo", "extra", "thing"], suffix=suffix, diff --git a/tests/unit_tests/prompts/test_prompt.py b/tests/unit_tests/prompts/test_prompt.py index 4902781539..5a0e9d2432 100644 --- a/tests/unit_tests/prompts/test_prompt.py +++ b/tests/unit_tests/prompts/test_prompt.py @@ -218,7 +218,7 @@ def test_prompt_jinja2_missing_input_variables() -> None: """Test error is raised when input variables are not provided.""" template = "This is a {{ foo }} test." input_variables: list = [] - with pytest.raises(ValueError): + with pytest.warns(UserWarning): PromptTemplate( input_variables=input_variables, template=template, template_format="jinja2" ) @@ -228,7 +228,7 @@ def test_prompt_jinja2_extra_input_variables() -> None: """Test error is raised when there are too many input variables.""" template = "This is a {{ foo }} test." input_variables = ["foo", "bar"] - with pytest.raises(ValueError): + with pytest.warns(UserWarning): PromptTemplate( input_variables=input_variables, template=template, template_format="jinja2" ) @@ -238,7 +238,7 @@ def test_prompt_jinja2_wrong_input_variables() -> None: """Test error is raised when name of input variable is wrong.""" template = "This is a {{ foo }} test." input_variables = ["bar"] - with pytest.raises(ValueError): + with pytest.warns(UserWarning): PromptTemplate( input_variables=input_variables, template=template, template_format="jinja2" )