diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 3d1a13e769..e527cd8276 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -1,6 +1,7 @@ """BasePrompt schema definition.""" from __future__ import annotations +import warnings from abc import ABC from typing import Any, Callable, Dict, List, Set @@ -26,7 +27,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str: def validate_jinja2(template: str, input_variables: List[str]) -> None: """ Validate that the input variables are valid for the template. - Raise an exception if missing or extra variables are found. + Issues an warning if missing or extra variables are found. Args: template: The template string. @@ -37,15 +38,15 @@ def validate_jinja2(template: str, input_variables: List[str]) -> None: missing_variables = valid_variables - input_variables_set extra_variables = input_variables_set - valid_variables - error_message = "" + warning_message = "" if missing_variables: - error_message += f"Missing variables: {missing_variables} " + warning_message += f"Missing variables: {missing_variables} " if extra_variables: - error_message += f"Extra variables: {extra_variables}" + warning_message += f"Extra variables: {extra_variables}" - if error_message: - raise KeyError(error_message.strip()) + if warning_message: + warnings.warn(warning_message.strip()) def _get_jinja2_variables_from_template(template: str) -> Set[str]: diff --git a/tests/unit_tests/prompts/test_few_shot.py b/tests/unit_tests/prompts/test_few_shot.py index eb73c4c16e..fae1e884c5 100644 --- a/tests/unit_tests/prompts/test_few_shot.py +++ b/tests/unit_tests/prompts/test_few_shot.py @@ -227,7 +227,7 @@ def test_prompt_jinja2_missing_input_variables( suffix = "Ending with {{ bar }}" # Test when missing in suffix - with pytest.raises(ValueError): + with pytest.warns(UserWarning): FewShotPromptTemplate( input_variables=[], suffix=suffix, @@ -237,7 +237,7 @@ def test_prompt_jinja2_missing_input_variables( ) # Test when missing in prefix - with pytest.raises(ValueError): + with pytest.warns(UserWarning): FewShotPromptTemplate( input_variables=["bar"], suffix=suffix, @@ -254,7 +254,7 @@ def test_prompt_jinja2_extra_input_variables( """Test error is raised when there are too many input variables.""" prefix = "Starting with {{ foo }}" suffix = "Ending with {{ bar }}" - with pytest.raises(ValueError): + with pytest.warns(UserWarning): FewShotPromptTemplate( input_variables=["bar", "foo", "extra", "thing"], suffix=suffix, diff --git a/tests/unit_tests/prompts/test_prompt.py b/tests/unit_tests/prompts/test_prompt.py index 4902781539..5a0e9d2432 100644 --- a/tests/unit_tests/prompts/test_prompt.py +++ b/tests/unit_tests/prompts/test_prompt.py @@ -218,7 +218,7 @@ def test_prompt_jinja2_missing_input_variables() -> None: """Test error is raised when input variables are not provided.""" template = "This is a {{ foo }} test." input_variables: list = [] - with pytest.raises(ValueError): + with pytest.warns(UserWarning): PromptTemplate( input_variables=input_variables, template=template, template_format="jinja2" ) @@ -228,7 +228,7 @@ def test_prompt_jinja2_extra_input_variables() -> None: """Test error is raised when there are too many input variables.""" template = "This is a {{ foo }} test." input_variables = ["foo", "bar"] - with pytest.raises(ValueError): + with pytest.warns(UserWarning): PromptTemplate( input_variables=input_variables, template=template, template_format="jinja2" ) @@ -238,7 +238,7 @@ def test_prompt_jinja2_wrong_input_variables() -> None: """Test error is raised when name of input variable is wrong.""" template = "This is a {{ foo }} test." input_variables = ["bar"] - with pytest.raises(ValueError): + with pytest.warns(UserWarning): PromptTemplate( input_variables=input_variables, template=template, template_format="jinja2" )