|
|
|
@ -47,7 +47,7 @@ class PromptTemplate(StringPromptTemplate):
|
|
|
|
|
prompt.format(foo="bar")
|
|
|
|
|
|
|
|
|
|
# Instantiation using initializer
|
|
|
|
|
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
|
|
|
|
|
prompt = PromptTemplate(template="Say {foo}")
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@ -74,6 +74,43 @@ class PromptTemplate(StringPromptTemplate):
|
|
|
|
|
validate_template: bool = False
|
|
|
|
|
"""Whether or not to try validating the template."""
|
|
|
|
|
|
|
|
|
|
@root_validator(pre=True)
|
|
|
|
|
def pre_init_validation(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Check that template and input variables are consistent."""
|
|
|
|
|
if values.get("template") is None:
|
|
|
|
|
# Will let pydantic fail with a ValidationError if template
|
|
|
|
|
# is not provided.
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
# Set some default values based on the field defaults
|
|
|
|
|
values.setdefault("template_format", "f-string")
|
|
|
|
|
values.setdefault("partial_variables", {})
|
|
|
|
|
|
|
|
|
|
if values.get("validate_template"):
|
|
|
|
|
if values["template_format"] == "mustache":
|
|
|
|
|
raise ValueError("Mustache templates cannot be validated.")
|
|
|
|
|
|
|
|
|
|
if "input_variables" not in values:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Input variables must be provided to validate the template."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
all_inputs = values["input_variables"] + list(values["partial_variables"])
|
|
|
|
|
check_valid_template(
|
|
|
|
|
values["template"], values["template_format"], all_inputs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if values["template_format"]:
|
|
|
|
|
values["input_variables"] = [
|
|
|
|
|
var
|
|
|
|
|
for var in get_template_variables(
|
|
|
|
|
values["template"], values["template_format"]
|
|
|
|
|
)
|
|
|
|
|
if var not in values["partial_variables"]
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
|
|
|
|
|
if self.template_format != "mustache":
|
|
|
|
|
return super().get_input_schema(config)
|
|
|
|
@ -126,26 +163,6 @@ class PromptTemplate(StringPromptTemplate):
|
|
|
|
|
kwargs = self._merge_partial_and_user_variables(**kwargs)
|
|
|
|
|
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def template_is_valid(cls, values: Dict) -> Dict:
|
|
|
|
|
"""Check that template and input variables are consistent."""
|
|
|
|
|
if values["validate_template"]:
|
|
|
|
|
if values["template_format"] == "mustache":
|
|
|
|
|
raise ValueError("Mustache templates cannot be validated.")
|
|
|
|
|
all_inputs = values["input_variables"] + list(values["partial_variables"])
|
|
|
|
|
check_valid_template(
|
|
|
|
|
values["template"], values["template_format"], all_inputs
|
|
|
|
|
)
|
|
|
|
|
elif values.get("template_format"):
|
|
|
|
|
values["input_variables"] = [
|
|
|
|
|
var
|
|
|
|
|
for var in get_template_variables(
|
|
|
|
|
values["template"], values["template_format"]
|
|
|
|
|
)
|
|
|
|
|
if var not in values["partial_variables"]
|
|
|
|
|
]
|
|
|
|
|
return values
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_examples(
|
|
|
|
|
cls,
|
|
|
|
|