forked from Archives/langchain
dbbc340f25
`langchain.prompts.PromptTemplate` and `langchain.prompts.FewShotPromptTemplate` do not validate `input_variables` when initialized with a `jinja2` template.

```python
# Using langchain v0.0.144
template = """\
Your variable: {{ foo }}
{% if bar %}
You just set bar boolean variable to true
{% endif %}
"""

# Missing variable, should raise ValueError
prompt_template = PromptTemplate(template=template, input_variables=["bar"], template_format="jinja2", validate_template=True)

# Extra variable, should raise ValueError
prompt_template = PromptTemplate(template=template, input_variables=["bar", "foo", "extra", "thing"], template_format="jinja2", validate_template=True)
```
39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
"""Utilities for formatting strings."""
|
|
from string import Formatter
|
|
from typing import Any, List, Mapping, Sequence, Union
|
|
|
|
|
|
class _FormatProbe:
    """Stand-in value that tolerates any replacement-field syntax.

    Used only while validating templates: it accepts every format spec and
    supports ``{name.attr}`` / ``{name[key]}`` lookups, so validation checks
    variable *names* without depending on the type of the eventual values.
    """

    def __format__(self, format_spec: str) -> str:
        # Ignore the spec entirely; a real value of the right type would
        # accept it, and the rendered text is discarded by the caller.
        return "foo"

    def __getattr__(self, name: str) -> "_FormatProbe":
        # Leave dunder lookups alone so protocol probing still sees a
        # plain object; only ordinary attribute access is absorbed.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return self

    def __getitem__(self, key: Any) -> "_FormatProbe":
        return self


class StrictFormatter(Formatter):
    """A subclass of formatter that checks for extra keys."""

    def check_unused_args(
        self,
        used_args: Sequence[Union[int, str]],
        args: Sequence,
        kwargs: Mapping[str, Any],
    ) -> None:
        """Check to see if extra parameters are passed.

        Raises:
            KeyError: If any keyword argument is never referenced by the
                format string. The exception carries the set of unused names.
        """
        extra = set(kwargs).difference(used_args)
        if extra:
            raise KeyError(extra)

    def vformat(
        self, format_string: str, args: Sequence, kwargs: Mapping[str, Any]
    ) -> str:
        """Format ``format_string`` using keyword arguments only.

        Raises:
            ValueError: If any positional arguments are supplied.
        """
        if args:
            raise ValueError(
                "No arguments should be provided, "
                "everything should be passed as keyword arguments."
            )
        return super().vformat(format_string, args, kwargs)

    def validate_input_variables(
        self, format_string: str, input_variables: List[str]
    ) -> None:
        """Check that ``format_string`` uses exactly ``input_variables``.

        The template is rendered once with placeholder values and the result
        is discarded; only the exceptions matter.

        Raises:
            KeyError: If the template references a name that is not in
                ``input_variables``, or if an input variable is never used.
        """
        # A probe object (rather than a plain string) keeps templates such as
        # "{price:.2f}" or "{user.name}" from failing validation for reasons
        # unrelated to the variable names themselves.
        probe = _FormatProbe()
        dummy_inputs = {input_variable: probe for input_variable in input_variables}
        super().format(format_string, **dummy_inputs)
|
|
|
|
|
|
# Module-level StrictFormatter instance, created once at import time.
formatter = StrictFormatter()
|