Include "no escape" and "inverted section" mustache vars in Prompt.input_variables and Prompt.input_schema (#22981)

pull/23063/head
Nuno Campos 3 weeks ago committed by GitHub
parent 7a0b36501f
commit f01f12ce1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -103,9 +103,12 @@ def mustache_template_vars(
in_section = False
elif in_section:
continue
elif type in ("variable", "section") and key != ".":
elif (
type in ("variable", "section", "inverted section", "no escape")
and key != "."
):
vars.add(key.split(".")[0])
if type == "section":
if type in ("section", "inverted section"):
in_section = True
return vars
@ -117,24 +120,25 @@ def mustache_schema(
template: str,
) -> Type[BaseModel]:
"""Get the variables from a mustache template."""
fields = set()
fields = {}
prefix: Tuple[str, ...] = ()
for type, key in mustache.tokenize(template):
if key == ".":
continue
if type == "end":
prefix = prefix[: -key.count(".")]
elif type == "section":
elif type in ("section", "inverted section"):
prefix = prefix + tuple(key.split("."))
elif type == "variable":
fields.add(prefix + tuple(key.split(".")))
fields[prefix] = False
elif type in ("variable", "no escape"):
fields[prefix + tuple(key.split("."))] = True
defs: Defs = {} # None means leaf node
while fields:
field = fields.pop()
field, is_leaf = fields.popitem()
current = defs
for part in field[:-1]:
current = current.setdefault(part, {})
current[field[-1]] = {}
current.setdefault(field[-1], "" if is_leaf else {}) # type: ignore[arg-type]
return _create_model_recursive("PromptInput", defs)
@ -142,7 +146,7 @@ def _create_model_recursive(name: str, defs: Defs) -> Type:
return create_model( # type: ignore[call-overload]
name,
**{
k: (_create_model_recursive(k, v), None) if v else (str, None)
k: (_create_model_recursive(k, v), None) if v else (type(v), None)
for k, v in defs.items()
},
)

@ -67,7 +67,7 @@ def test_mustache_prompt_from_template() -> None:
}
# Multiple input variables with repeats.
template = "This {{bar}} is a {{foo}} test {{foo}}."
template = "This {{bar}} is a {{foo}} test {{&foo}}."
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test bar."
assert prompt.input_variables == ["bar", "foo"]
@ -81,7 +81,7 @@ def test_mustache_prompt_from_template() -> None:
}
# Nested variables.
template = "This {{obj.bar}} is a {{obj.foo}} test {{foo}}."
template = "This {{obj.bar}} is a {{obj.foo}} test {{{foo}}}."
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format(obj={"bar": "foo", "foo": "bar"}, foo="baz") == (
"This foo is a bar test baz."
@ -167,6 +167,22 @@ def test_mustache_prompt_from_template() -> None:
},
}
template = """This{{^foo}}
no foos
{{/foo}}is a test."""
prompt = PromptTemplate.from_template(template, template_format="mustache")
assert prompt.format() == (
"""This
no foos
is a test."""
)
assert prompt.input_variables == ["foo"]
assert prompt.input_schema.schema() == {
"title": "PromptInput",
"type": "object",
"properties": {"foo": {"title": "Foo", "type": "object"}},
}
def test_prompt_from_template_with_partial_variables() -> None:
"""Test prompts can be constructed from a template with partial variables."""

Loading…
Cancel
Save