Fixing empty input variable crashing PromptTemplate validations (#14314)

- Fixes `input_variables=[""]` crashing validations with a template
`"{}"`
- Uses `__cause__` for proper `Exception` chaining in
`check_valid_template`
This commit is contained in:
James Braza 2023-12-05 16:13:08 -05:00 committed by GitHub
parent 0f02e94565
commit 8b0060184d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 11 deletions

View File

@ -106,20 +106,20 @@ def check_valid_template(
Raises: Raises:
ValueError: If the template format is not supported. ValueError: If the template format is not supported.
""" """
if template_format not in DEFAULT_FORMATTER_MAPPING:
valid_formats = list(DEFAULT_FORMATTER_MAPPING)
raise ValueError(
f"Invalid template format. Got `{template_format}`;"
f" should be one of {valid_formats}"
)
try: try:
validator_func = DEFAULT_VALIDATOR_MAPPING[template_format] validator_func = DEFAULT_VALIDATOR_MAPPING[template_format]
validator_func(template, input_variables) except KeyError as exc:
except KeyError as e:
raise ValueError( raise ValueError(
"Invalid prompt schema; check for mismatched or missing input parameters. " f"Invalid template format {template_format!r}, should be one of"
+ str(e) f" {list(DEFAULT_FORMATTER_MAPPING)}."
) ) from exc
try:
validator_func(template, input_variables)
except (KeyError, IndexError) as exc:
raise ValueError(
"Invalid prompt schema; check for mismatched or missing input parameters"
f" from {input_variables}."
) from exc
def get_template_variables(template: str, template_format: str) -> List[str]: def get_template_variables(template: str, template_format: str) -> List[str]:

View File

@ -47,6 +47,12 @@ def test_prompt_missing_input_variables() -> None:
).input_variables == ["foo"] ).input_variables == ["foo"]
def test_prompt_empty_input_variable() -> None:
"""Test error is raised when empty string input variable."""
with pytest.raises(ValueError):
PromptTemplate(input_variables=[""], template="{}", validate_template=True)
def test_prompt_extra_input_variables() -> None: def test_prompt_extra_input_variables() -> None:
"""Test error is raised when there are too many input variables.""" """Test error is raised when there are too many input variables."""
template = "This is a {foo} test." template = "This is a {foo} test."