From f01f12ce1e4b0585d31450776cad77a6277358b7 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 17 Jun 2024 19:24:13 -0700 Subject: [PATCH] Include "no escape" and "inverted section" mustache vars in Prompt.input_variables and Prompt.input_schema (#22981) --- libs/core/langchain_core/prompts/string.py | 22 +++++++++++-------- .../tests/unit_tests/prompts/test_prompt.py | 20 +++++++++++++++-- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index 4abbd30111..0d8c8ce117 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -103,9 +103,12 @@ def mustache_template_vars( in_section = False elif in_section: continue - elif type in ("variable", "section") and key != ".": + elif ( + type in ("variable", "section", "inverted section", "no escape") + and key != "." + ): vars.add(key.split(".")[0]) - if type == "section": + if type in ("section", "inverted section"): in_section = True return vars @@ -117,24 +120,25 @@ def mustache_schema( template: str, ) -> Type[BaseModel]: """Get the variables from a mustache template.""" - fields = set() + fields = {} prefix: Tuple[str, ...] = () for type, key in mustache.tokenize(template): if key == ".": continue if type == "end": prefix = prefix[: -key.count(".")] - elif type == "section": + elif type in ("section", "inverted section"): prefix = prefix + tuple(key.split(".")) - elif type == "variable": - fields.add(prefix + tuple(key.split("."))) + fields[prefix] = False + elif type in ("variable", "no escape"): + fields[prefix + tuple(key.split("."))] = True defs: Defs = {} # None means leaf node while fields: - field = fields.pop() + field, is_leaf = fields.popitem() current = defs for part in field[:-1]: current = current.setdefault(part, {}) - current[field[-1]] = {} + current.setdefault(field[-1], "" if is_leaf else {}) # type: ignore[arg-type] return _create_model_recursive("PromptInput", defs) @@ -142,7 +146,7 @@ def _create_model_recursive(name: str, defs: Defs) -> Type: return create_model( # type: ignore[call-overload] name, **{ - k: (_create_model_recursive(k, v), None) if v else (str, None) + k: (_create_model_recursive(k, v), None) if v else (type(v), None) for k, v in defs.items() }, ) diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index 3bb062d6ac..4c8423b7ee 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -67,7 +67,7 @@ def test_mustache_prompt_from_template() -> None: } # Multiple input variables with repeats. - template = "This {{bar}} is a {{foo}} test {{foo}}." + template = "This {{bar}} is a {{foo}} test {{&foo}}." prompt = PromptTemplate.from_template(template, template_format="mustache") assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test bar." assert prompt.input_variables == ["bar", "foo"] @@ -81,7 +81,7 @@ def test_mustache_prompt_from_template() -> None: } # Nested variables. - template = "This {{obj.bar}} is a {{obj.foo}} test {{foo}}." + template = "This {{obj.bar}} is a {{obj.foo}} test {{{foo}}}." prompt = PromptTemplate.from_template(template, template_format="mustache") assert prompt.format(obj={"bar": "foo", "foo": "bar"}, foo="baz") == ( "This foo is a bar test baz." @@ -167,6 +167,22 @@ def test_mustache_prompt_from_template() -> None: }, } + template = """This{{^foo}} + no foos + {{/foo}}is a test.""" + prompt = PromptTemplate.from_template(template, template_format="mustache") + assert prompt.format() == ( + """This + no foos + is a test.""" + ) + assert prompt.input_variables == ["foo"] + assert prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {"foo": {"title": "Foo", "type": "object"}}, + } + def test_prompt_from_template_with_partial_variables() -> None: """Test prompts can be constructed from a template with partial variables."""