From dbbc340f2556352360cbc3f1251e3d6bde57848c Mon Sep 17 00:00:00 2001 From: engkheng <60956360+outday29@users.noreply.github.com> Date: Thu, 20 Apr 2023 07:18:32 +0800 Subject: [PATCH] Validate `input_variables` when using `jinja2` templates (#3140) `langchain.prompts.PromptTemplate` and `langchain.prompts.FewShotPromptTemplate` do not validate `input_variables` when initialized with `template_format="jinja2"`. ```python # Using langchain v0.0.144 template = """\ Your variable: {{ foo }} {% if bar %} You just set bar boolean variable to true {% endif %} """ # Missing variable, should raise ValueError prompt_template = PromptTemplate(template=template, input_variables=["bar"], template_format="jinja2", validate_template=True) # Extra variable, should raise ValueError prompt_template = PromptTemplate(template=template, input_variables=["bar", "foo", "extra", "thing"], template_format="jinja2", validate_template=True) ``` --- langchain/formatting.py | 8 +- langchain/prompts/base.py | 43 ++++++++++- langchain/prompts/prompt.py | 17 +---- tests/unit_tests/prompts/test_few_shot.py | 89 +++++++++++++++++++++++ tests/unit_tests/prompts/test_prompt.py | 30 ++++++++ 5 files changed, 167 insertions(+), 20 deletions(-) diff --git a/langchain/formatting.py b/langchain/formatting.py index 61c7c116..3b3b597b 100644 --- a/langchain/formatting.py +++ b/langchain/formatting.py @@ -1,6 +1,6 @@ """Utilities for formatting strings.""" from string import Formatter -from typing import Any, Mapping, Sequence, Union +from typing import Any, List, Mapping, Sequence, Union class StrictFormatter(Formatter): @@ -28,5 +28,11 @@ class StrictFormatter(Formatter): ) return super().vformat(format_string, args, kwargs) + def validate_input_variables( + self, format_string: str, input_variables: List[str] + ) -> None: + dummy_inputs = {input_variable: "foo" for input_variable in input_variables} + super().format(format_string, **dummy_inputs) + formatter = StrictFormatter() diff --git 
a/langchain/prompts/base.py b/langchain/prompts/base.py index c0c747ae..8d31b10e 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -4,7 +4,7 @@ from __future__ import annotations import json from abc import ABC, abstractmethod from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union import yaml from pydantic import BaseModel, Extra, Field, root_validator @@ -26,11 +26,47 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str: return Template(template).render(**kwargs) +def validate_jinja2(template: str, input_variables: List[str]) -> None: + input_variables_set = set(input_variables) + valid_variables = _get_jinja2_variables_from_template(template) + missing_variables = valid_variables - input_variables_set + extra_variables = input_variables_set - valid_variables + + error_message = "" + if missing_variables: + error_message += f"Missing variables: {missing_variables} " + + if extra_variables: + error_message += f"Extra variables: {extra_variables}" + + if error_message: + raise KeyError(error_message.strip()) + + +def _get_jinja2_variables_from_template(template: str) -> Set[str]: + try: + from jinja2 import Environment, meta + except ImportError: + raise ImportError( + "jinja2 not installed, which is needed to use the jinja2_formatter. " + "Please install it with `pip install jinja2`." + ) + env = Environment() + ast = env.parse(template) + variables = meta.find_undeclared_variables(ast) + return variables + + DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { "f-string": formatter.format, "jinja2": jinja2_formatter, } +DEFAULT_VALIDATOR_MAPPING: Dict[str, Callable] = { + "f-string": formatter.validate_input_variables, + "jinja2": validate_jinja2, +} + def check_valid_template( template: str, template_format: str, input_variables: List[str] @@ -42,10 +78,9 @@ def check_valid_template( f"Invalid template format. 
Got `{template_format}`;" f" should be one of {valid_formats}" ) - dummy_inputs = {input_variable: "foo" for input_variable in input_variables} try: - formatter_func = DEFAULT_FORMATTER_MAPPING[template_format] - formatter_func(template, **dummy_inputs) + validator_func = DEFAULT_VALIDATOR_MAPPING[template_format] + validator_func(template, input_variables) except KeyError as e: raise ValueError( "Invalid prompt schema; check for mismatched or missing input parameters. " diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index af6ff29e..c61cf69f 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -3,31 +3,18 @@ from __future__ import annotations from pathlib import Path from string import Formatter -from typing import Any, Dict, List, Set, Union +from typing import Any, Dict, List, Union from pydantic import Extra, root_validator from langchain.prompts.base import ( DEFAULT_FORMATTER_MAPPING, StringPromptTemplate, + _get_jinja2_variables_from_template, check_valid_template, ) -def _get_jinja2_variables_from_template(template: str) -> Set[str]: - try: - from jinja2 import Environment, meta - except ImportError: - raise ImportError( - "jinja2 not installed, which is needed to use the jinja2_formatter. " - "Please install it with `pip install jinja2`." - ) - env = Environment() - ast = env.parse(template) - variables = meta.find_undeclared_variables(ast) - return variables - - class PromptTemplate(StringPromptTemplate): """Schema to represent a prompt for an LLM. 
diff --git a/tests/unit_tests/prompts/test_few_shot.py b/tests/unit_tests/prompts/test_few_shot.py index 22ba8e51..eb73c4c1 100644 --- a/tests/unit_tests/prompts/test_few_shot.py +++ b/tests/unit_tests/prompts/test_few_shot.py @@ -1,4 +1,6 @@ """Test few shot prompt template.""" +from typing import Dict, List, Tuple + import pytest from langchain.prompts.few_shot import FewShotPromptTemplate @@ -9,6 +11,25 @@ EXAMPLE_PROMPT = PromptTemplate( ) +@pytest.fixture() +def example_jinja2_prompt() -> Tuple[PromptTemplate, List[Dict[str, str]]]: + example_template = "{{ word }}: {{ antonym }}" + + examples = [ + {"word": "happy", "antonym": "sad"}, + {"word": "tall", "antonym": "short"}, + ] + + return ( + PromptTemplate( + input_variables=["word", "antonym"], + template=example_template, + template_format="jinja2", + ), + examples, + ) + + def test_suffix_only() -> None: """Test prompt works with just a suffix.""" suffix = "This is a {foo} test." @@ -174,3 +195,71 @@ def test_partial() -> None: "Now you try to talk about party." 
) assert output == expected_output + + +def test_prompt_jinja2_functionality( + example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]] +) -> None: + prefix = "Starting with {{ foo }}" + suffix = "Ending with {{ bar }}" + + prompt = FewShotPromptTemplate( + input_variables=["foo", "bar"], + suffix=suffix, + prefix=prefix, + examples=example_jinja2_prompt[1], + example_prompt=example_jinja2_prompt[0], + template_format="jinja2", + ) + output = prompt.format(foo="hello", bar="bye") + expected_output = ( + "Starting with hello\n\n" "happy: sad\n\n" "tall: short\n\n" "Ending with bye" + ) + + assert output == expected_output + + +def test_prompt_jinja2_missing_input_variables( + example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]] +) -> None: + """Test error is raised when input variables are not provided.""" + prefix = "Starting with {{ foo }}" + suffix = "Ending with {{ bar }}" + + # Test when missing in suffix + with pytest.raises(ValueError): + FewShotPromptTemplate( + input_variables=[], + suffix=suffix, + examples=example_jinja2_prompt[1], + example_prompt=example_jinja2_prompt[0], + template_format="jinja2", + ) + + # Test when missing in prefix + with pytest.raises(ValueError): + FewShotPromptTemplate( + input_variables=["bar"], + suffix=suffix, + prefix=prefix, + examples=example_jinja2_prompt[1], + example_prompt=example_jinja2_prompt[0], + template_format="jinja2", + ) + + +def test_prompt_jinja2_extra_input_variables( + example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]] +) -> None: + """Test error is raised when there are too many input variables.""" + prefix = "Starting with {{ foo }}" + suffix = "Ending with {{ bar }}" + with pytest.raises(ValueError): + FewShotPromptTemplate( + input_variables=["bar", "foo", "extra", "thing"], + suffix=suffix, + prefix=prefix, + examples=example_jinja2_prompt[1], + example_prompt=example_jinja2_prompt[0], + template_format="jinja2", + ) diff --git 
a/tests/unit_tests/prompts/test_prompt.py b/tests/unit_tests/prompts/test_prompt.py index db55fb11..49027815 100644 --- a/tests/unit_tests/prompts/test_prompt.py +++ b/tests/unit_tests/prompts/test_prompt.py @@ -212,3 +212,33 @@ Your variable again: {{ foo }} template_format="jinja2", ) assert prompt == expected_prompt + + +def test_prompt_jinja2_missing_input_variables() -> None: + """Test error is raised when input variables are not provided.""" + template = "This is a {{ foo }} test." + input_variables: list = [] + with pytest.raises(ValueError): + PromptTemplate( + input_variables=input_variables, template=template, template_format="jinja2" + ) + + +def test_prompt_jinja2_extra_input_variables() -> None: + """Test error is raised when there are too many input variables.""" + template = "This is a {{ foo }} test." + input_variables = ["foo", "bar"] + with pytest.raises(ValueError): + PromptTemplate( + input_variables=input_variables, template=template, template_format="jinja2" + ) + + +def test_prompt_jinja2_wrong_input_variables() -> None: + """Test error is raised when name of input variable is wrong.""" + template = "This is a {{ foo }} test." + input_variables = ["bar"] + with pytest.raises(ValueError): + PromptTemplate( + input_variables=input_variables, template=template, template_format="jinja2" + )