diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index e35050bc64..4316dc5bdd 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -8,6 +8,7 @@ from typing import ( Any, Dict, List, + Literal, Optional, Sequence, Set, @@ -929,6 +930,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): def from_messages( cls, messages: Sequence[MessageLikeRepresentation], + template_format: Literal["f-string", "mustache"] = "f-string", ) -> ChatPromptTemplate: """Create a chat prompt template from a variety of message formats. @@ -964,7 +966,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate): Returns: a chat prompt template """ - _messages = [_convert_to_message(message) for message in messages] + _messages = [ + _convert_to_message(message, template_format) for message in messages + ] # Automatically infer input variables from messages input_vars: Set[str] = set() @@ -1121,7 +1125,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate): def _create_template_from_message_type( - message_type: str, template: Union[str, list] + message_type: str, + template: Union[str, list], + template_format: Literal["f-string", "mustache"] = "f-string", ) -> BaseMessagePromptTemplate: """Create a message prompt template from a message type and template string. @@ -1134,12 +1140,16 @@ def _create_template_from_message_type( """ if message_type in ("human", "user"): message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template( - template + template, template_format=template_format ) elif message_type in ("ai", "assistant"): - message = AIMessagePromptTemplate.from_template(cast(str, template)) + message = AIMessagePromptTemplate.from_template( + cast(str, template), template_format=template_format + ) elif message_type == "system": - message = SystemMessagePromptTemplate.from_template(cast(str, template)) + message = SystemMessagePromptTemplate.from_template( + cast(str, template), template_format=template_format + ) elif message_type == "placeholder": if isinstance(template, str): if template[0] != "{" or template[-1] != "}": @@ -1180,6 +1190,7 @@ def _create_template_from_message_type( def _convert_to_message( message: MessageLikeRepresentation, + template_format: Literal["f-string", "mustache"] = "f-string", ) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]: """Instantiate a message from a variety of message formats. @@ -1204,16 +1215,22 @@ def _convert_to_message( elif isinstance(message, BaseMessage): _message = message elif isinstance(message, str): - _message = _create_template_from_message_type("human", message) + _message = _create_template_from_message_type( + "human", message, template_format=template_format + ) elif isinstance(message, tuple): if len(message) != 2: raise ValueError(f"Expected 2-tuple of (role, template), got {message}") message_type_str, template = message if isinstance(message_type_str, str): - _message = _create_template_from_message_type(message_type_str, template) + _message = _create_template_from_message_type( + message_type_str, template, template_format=template_format + ) else: _message = message_type_str( - prompt=PromptTemplate.from_template(cast(str, template)) + prompt=PromptTemplate.from_template( + cast(str, template), template_format=template_format + ) ) else: raise NotImplementedError(f"Unsupported message type: {type(message)}") diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index f3c53a0e95..e909ee9088 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -10,8 +10,10 @@ from langchain_core.prompts.string import ( StringPromptTemplate, check_valid_template, get_template_variables, + mustache_schema, ) -from langchain_core.pydantic_v1 import root_validator +from langchain_core.pydantic_v1 import BaseModel, root_validator +from langchain_core.runnables.config import RunnableConfig class PromptTemplate(StringPromptTemplate): @@ -65,12 +67,19 @@ class PromptTemplate(StringPromptTemplate): template: str """The prompt template.""" - template_format: Literal["f-string", "jinja2"] = "f-string" - """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + template_format: Literal["f-string", "mustache", "jinja2"] = "f-string" + """The format of the prompt template. + Options are: 'f-string', 'mustache', 'jinja2'.""" validate_template: bool = False """Whether or not to try validating the template.""" + def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: + if self.template_format != "mustache": + return super().get_input_schema(config) + + return mustache_schema(self.template) + def __add__(self, other: Any) -> PromptTemplate: """Override the + operator to allow for combining prompt templates.""" # Allow for easy combining @@ -121,6 +130,8 @@ class PromptTemplate(StringPromptTemplate): def template_is_valid(cls, values: Dict) -> Dict: """Check that template and input variables are consistent.""" if values["validate_template"]: + if values["template_format"] == "mustache": + raise ValueError("Mustache templates cannot be validated.") all_inputs = values["input_variables"] + list(values["partial_variables"]) check_valid_template( values["template"], values["template_format"], all_inputs diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index b324871da5..4abbd30111 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -5,10 +5,12 @@ from __future__ import annotations import warnings from abc import ABC from string import Formatter -from typing import Any, Callable, Dict, List, Set +from typing import Any, Callable, Dict, List, Set, Tuple, Type +import langchain_core.utils.mustache as mustache from langchain_core.prompt_values import PromptValue, StringPromptValue from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.pydantic_v1 import BaseModel, create_model from langchain_core.utils import get_colored_text from langchain_core.utils.formatting import formatter from langchain_core.utils.interactive_env import is_interactive_env @@ -85,8 +87,70 @@ def _get_jinja2_variables_from_template(template: str) -> Set[str]: return variables +def mustache_formatter(template: str, **kwargs: Any) -> str: + """Format a template using mustache.""" + return mustache.render(template, kwargs) + + +def mustache_template_vars( + template: str, +) -> Set[str]: + """Get the variables from a mustache template.""" + vars: Set[str] = set() + in_section = False + for type, key in mustache.tokenize(template): + if type == "end": + in_section = False + elif in_section: + continue + elif type in ("variable", "section") and key != ".": + vars.add(key.split(".")[0]) + if type == "section": + in_section = True + return vars + + +Defs = Dict[str, "Defs"] + + +def mustache_schema( + template: str, +) -> Type[BaseModel]: + """Get the variables from a mustache template.""" + fields = set() + prefix: Tuple[str, ...] = () + for type, key in mustache.tokenize(template): + if key == ".": + continue + if type == "end": + prefix = prefix[: -key.count(".")] + elif type == "section": + prefix = prefix + tuple(key.split(".")) + elif type == "variable": + fields.add(prefix + tuple(key.split("."))) + defs: Defs = {} # None means leaf node + while fields: + field = fields.pop() + current = defs + for part in field[:-1]: + current = current.setdefault(part, {}) + current[field[-1]] = {} + return _create_model_recursive("PromptInput", defs) + + +def _create_model_recursive(name: str, defs: Defs) -> Type: + return create_model( # type: ignore[call-overload] + name, + **{ + k: (_create_model_recursive(k, v), None) if v else (str, None) + for k, v in defs.items() + }, + ) + + DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { "f-string": formatter.format, + "mustache": mustache_formatter, "jinja2": jinja2_formatter, } @@ -145,6 +209,8 @@ def get_template_variables(template: str, template_format: str) -> List[str]: input_variables = { v for _, v, _, _ in Formatter().parse(template) if v is not None } + elif template_format == "mustache": + input_variables = mustache_template_vars(template) else: raise ValueError(f"Unsupported template format: {template_format}") diff --git a/libs/core/langchain_core/utils/mustache.py b/libs/core/langchain_core/utils/mustache.py new file mode 100644 index 0000000000..06ea9cd002 --- /dev/null +++ b/libs/core/langchain_core/utils/mustache.py @@ -0,0 +1,641 @@ +""" +Adapted from https://github.com/noahmorrison/chevron +MIT License +""" + +import logging +from typing import ( + Any, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from typing_extensions import TypeAlias + +logger = logging.getLogger(__name__) + + +Scopes: TypeAlias = List[Union[Literal[False, 0], Dict[str, Any]]] + + +# Globals +_CURRENT_LINE = 1 +_LAST_TAG_LINE = None + + +class ChevronError(SyntaxError): + pass + + +# +# Helper functions +# + + +def grab_literal(template: str, l_del: str) -> Tuple[str, str]: + """Parse a literal from the template""" + + global _CURRENT_LINE + + try: + # Look for the next tag and move the template to it + literal, template = template.split(l_del, 1) + _CURRENT_LINE += literal.count("\n") + return (literal, template) + + # There are no more tags in the template? + except ValueError: + # Then the rest of the template is a literal + return (template, "") + + +def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool: + """Do a preliminary check to see if a tag could be a standalone""" + + # If there is a newline, or the previous tag was a standalone + if literal.find("\n") != -1 or is_standalone: + padding = literal.split("\n")[-1] + + # If all the characters since the last newline are spaces + if padding.isspace() or padding == "": + # Then the next tag could be a standalone + return True + else: + # Otherwise it can't be + return False + else: + return False + + +def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool: + """Do a final checkto see if a tag could be a standalone""" + + # Check right side if we might be a standalone + if is_standalone and tag_type not in ["variable", "no escape"]: + on_newline = template.split("\n", 1) + + # If the stuff to the right of us are spaces we're a standalone + if on_newline[0].isspace() or not on_newline[0]: + return True + else: + return False + + # If we're a tag can't be a standalone + else: + return False + + +def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]: + """Parse a tag from a template""" + global _CURRENT_LINE + global _LAST_TAG_LINE + + tag_types = { + "!": "comment", + "#": "section", + "^": "inverted section", + "/": "end", + ">": "partial", + "=": "set delimiter?", + "{": "no escape?", + "&": "no escape", + } + + # Get the tag + try: + tag, template = template.split(r_del, 1) + except ValueError: + raise ChevronError("unclosed tag " "at line {0}".format(_CURRENT_LINE)) + + # Find the type meaning of the first character + tag_type = tag_types.get(tag[0], "variable") + + # If the type is not a variable + if tag_type != "variable": + # Then that first character is not needed + tag = tag[1:] + + # If we might be a set delimiter tag + if tag_type == "set delimiter?": + # Double check to make sure we are + if tag.endswith("="): + tag_type = "set delimiter" + # Remove the equal sign + tag = tag[:-1] + + # Otherwise we should complain + else: + raise ChevronError( + "unclosed set delimiter tag\n" "at line {0}".format(_CURRENT_LINE) + ) + + # If we might be a no html escape tag + elif tag_type == "no escape?": + # And we have a third curly brace + # (And are using curly braces as delimiters) + if l_del == "{{" and r_del == "}}" and template.startswith("}"): + # Then we are a no html escape tag + template = template[1:] + tag_type = "no escape" + + # Strip the whitespace off the key and return + return ((tag_type, tag.strip()), template) + + +# +# The main tokenizing function +# + + +def tokenize( + template: str, def_ldel: str = "{{", def_rdel: str = "}}" +) -> Iterator[Tuple[str, str]]: + """Tokenize a mustache template + + Tokenizes a mustache template in a generator fashion, + using file-like objects. It also accepts a string containing + the template. + + + Arguments: + + template -- a file-like object, or a string of a mustache template + + def_ldel -- The default left delimiter + ("{{" by default, as in spec compliant mustache) + + def_rdel -- The default right delimiter + ("}}" by default, as in spec compliant mustache) + + + Returns: + + A generator of mustache tags in the form of a tuple + + -- (tag_type, tag_key) + + Where tag_type is one of: + * literal + * section + * inverted section + * end + * partial + * no escape + + And tag_key is either the key or in the case of a literal tag, + the literal itself. + """ + + global _CURRENT_LINE, _LAST_TAG_LINE + _CURRENT_LINE = 1 + _LAST_TAG_LINE = None + + is_standalone = True + open_sections = [] + l_del = def_ldel + r_del = def_rdel + + while template: + literal, template = grab_literal(template, l_del) + + # If the template is completed + if not template: + # Then yield the literal and leave + yield ("literal", literal) + break + + # Do the first check to see if we could be a standalone + is_standalone = l_sa_check(template, literal, is_standalone) + + # Parse the tag + tag, template = parse_tag(template, l_del, r_del) + tag_type, tag_key = tag + + # Special tag logic + + # If we are a set delimiter tag + if tag_type == "set delimiter": + # Then get and set the delimiters + dels = tag_key.strip().split(" ") + l_del, r_del = dels[0], dels[-1] + + # If we are a section tag + elif tag_type in ["section", "inverted section"]: + # Then open a new section + open_sections.append(tag_key) + _LAST_TAG_LINE = _CURRENT_LINE + + # If we are an end tag + elif tag_type == "end": + # Then check to see if the last opened section + # is the same as us + try: + last_section = open_sections.pop() + except IndexError: + raise ChevronError( + 'Trying to close tag "{0}"\n' + "Looks like it was not opened.\n" + "line {1}".format(tag_key, _CURRENT_LINE + 1) + ) + if tag_key != last_section: + # Otherwise we need to complain + raise ChevronError( + 'Trying to close tag "{0}"\n' + 'last open tag is "{1}"\n' + "line {2}".format(tag_key, last_section, _CURRENT_LINE + 1) + ) + + # Do the second check to see if we're a standalone + is_standalone = r_sa_check(template, tag_type, is_standalone) + + # Which if we are + if is_standalone: + # Remove the stuff before the newline + template = template.split("\n", 1)[-1] + + # Partials need to keep the spaces on their left + if tag_type != "partial": + # But other tags don't + literal = literal.rstrip(" ") + + # Start yielding + # Ignore literals that are empty + if literal != "": + yield ("literal", literal) + + # Ignore comments and set delimiters + if tag_type not in ["comment", "set delimiter?"]: + yield (tag_type, tag_key) + + # If there are any open sections when we're done + if open_sections: + # Then we need to complain + raise ChevronError( + "Unexpected EOF\n" + 'the tag "{0}" was never closed\n' + "was opened at line {1}".format(open_sections[-1], _LAST_TAG_LINE) + ) + + +# +# Helper functions +# + + +def _html_escape(string: str) -> str: + """HTML escape all of these " & < >""" + + html_codes = { + '"': """, + "<": "<", + ">": ">", + } + + # & must be handled first + string = string.replace("&", "&") + for char in html_codes: + string = string.replace(char, html_codes[char]) + return string + + +def _get_key( + key: str, + scopes: Scopes, + warn: bool, + keep: bool, + def_ldel: str, + def_rdel: str, +) -> Any: + """Get a key from the current scope""" + + # If the key is a dot + if key == ".": + # Then just return the current scope + return scopes[0] + + # Loop through the scopes + for scope in scopes: + try: + # Return an empty string if falsy, with two exceptions + # 0 should return 0, and False should return False + if scope in (0, False): + return scope + + # For every dot separated key + for child in key.split("."): + # Return an empty string if falsy, with two exceptions + # 0 should return 0, and False should return False + if scope in (0, False): + return scope + # Move into the scope + try: + # Try subscripting (Normal dictionaries) + scope = cast(Dict[str, Any], scope)[child] + except (TypeError, AttributeError): + try: + scope = getattr(scope, child) + except (TypeError, AttributeError): + # Try as a list + scope = scope[int(child)] # type: ignore + + try: + # This allows for custom falsy data types + # https://github.com/noahmorrison/chevron/issues/35 + if scope._CHEVRON_return_scope_when_falsy: # type: ignore + return scope + except AttributeError: + return scope or "" + except (AttributeError, KeyError, IndexError, ValueError): + # We couldn't find the key in the current scope + # We'll try again on the next pass + pass + + # We couldn't find the key in any of the scopes + + if warn: + logger.warn("Could not find key '%s'" % (key)) + + if keep: + return "%s %s %s" % (def_ldel, key, def_rdel) + + return "" + + +def _get_partial(name: str, partials_dict: Dict[str, str]) -> str: + """Load a partial""" + try: + # Maybe the partial is in the dictionary + return partials_dict[name] + except KeyError: + return "" + + +# +# The main rendering function +# +g_token_cache: Dict[str, List[Tuple[str, str]]] = {} + + +def render( + template: Union[str, List[Tuple[str, str]]] = "", + data: Dict[str, Any] = {}, + partials_dict: Dict[str, str] = {}, + padding: str = "", + def_ldel: str = "{{", + def_rdel: str = "}}", + scopes: Optional[Scopes] = None, + warn: bool = False, + keep: bool = False, +) -> str: + """Render a mustache template. + + Renders a mustache template with a data scope and inline partial capability. + + Arguments: + + template -- A file-like object or a string containing the template + + data -- A python dictionary with your data scope + + partials_path -- The path to where your partials are stored + If set to None, then partials won't be loaded from the file system + (defaults to '.') + + partials_ext -- The extension that you want the parser to look for + (defaults to 'mustache') + + partials_dict -- A python dictionary which will be search for partials + before the filesystem is. {'include': 'foo'} is the same + as a file called include.mustache + (defaults to {}) + + padding -- This is for padding partials, and shouldn't be used + (but can be if you really want to) + + def_ldel -- The default left delimiter + ("{{" by default, as in spec compliant mustache) + + def_rdel -- The default right delimiter + ("}}" by default, as in spec compliant mustache) + + scopes -- The list of scopes that get_key will look through + + warn -- Log a warning when a template substitution isn't found in the data + + keep -- Keep unreplaced tags when a substitution isn't found in the data + + + Returns: + + A string containing the rendered template. + """ + + # If the template is a sequence but not derived from a string + if isinstance(template, Sequence) and not isinstance(template, str): + # Then we don't need to tokenize it + # But it does need to be a generator + tokens: Iterator[Tuple[str, str]] = (token for token in template) + else: + if template in g_token_cache: + tokens = (token for token in g_token_cache[template]) + else: + # Otherwise make a generator + tokens = tokenize(template, def_ldel, def_rdel) + + output = "" + + if scopes is None: + scopes = [data] + + # Run through the tokens + for tag, key in tokens: + # Set the current scope + current_scope = scopes[0] + + # If we're an end tag + if tag == "end": + # Pop out of the latest scope + del scopes[0] + + # If the current scope is falsy and not the only scope + elif not current_scope and len(scopes) != 1: + if tag in ["section", "inverted section"]: + # Set the most recent scope to a falsy value + scopes.insert(0, False) + + # If we're a literal tag + elif tag == "literal": + # Add padding to the key and add it to the output + output += key.replace("\n", "\n" + padding) + + # If we're a variable tag + elif tag == "variable": + # Add the html escaped key to the output + thing = _get_key( + key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel + ) + if thing is True and key == ".": + # if we've coerced into a boolean by accident + # (inverted tags do this) + # then get the un-coerced object (next in the stack) + thing = scopes[1] + if not isinstance(thing, str): + thing = str(thing) + output += _html_escape(thing) + + # If we're a no html escape tag + elif tag == "no escape": + # Just lookup the key and add it + thing = _get_key( + key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel + ) + if not isinstance(thing, str): + thing = str(thing) + output += thing + + # If we're a section tag + elif tag == "section": + # Get the sections scope + scope = _get_key( + key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel + ) + + # If the scope is a callable (as described in + # https://mustache.github.io/mustache.5.html) + if callable(scope): + # Generate template text from tags + text = "" + tags: List[Tuple[str, str]] = [] + for token in tokens: + if token == ("end", key): + break + + tags.append(token) + tag_type, tag_key = token + if tag_type == "literal": + text += tag_key + elif tag_type == "no escape": + text += "%s& %s %s" % (def_ldel, tag_key, def_rdel) + else: + text += "%s%s %s%s" % ( + def_ldel, + { + "comment": "!", + "section": "#", + "inverted section": "^", + "end": "/", + "partial": ">", + "set delimiter": "=", + "no escape": "&", + "variable": "", + }[tag_type], + tag_key, + def_rdel, + ) + + g_token_cache[text] = tags + + rend = scope( + text, + lambda template, data=None: render( + template, + data={}, + partials_dict=partials_dict, + padding=padding, + def_ldel=def_ldel, + def_rdel=def_rdel, + scopes=data and [data] + scopes or scopes, + warn=warn, + keep=keep, + ), + ) + + output += rend + + # If the scope is a sequence, an iterator or generator but not + # derived from a string + elif isinstance(scope, (Sequence, Iterator)) and not isinstance(scope, str): + # Then we need to do some looping + + # Gather up all the tags inside the section + # (And don't be tricked by nested end tags with the same key) + # TODO: This feels like it still has edge cases, no? + tags = [] + tags_with_same_key = 0 + for token in tokens: + if token == ("section", key): + tags_with_same_key += 1 + if token == ("end", key): + tags_with_same_key -= 1 + if tags_with_same_key < 0: + break + tags.append(token) + + # For every item in the scope + for thing in scope: + # Append it as the most recent scope and render + new_scope = [thing] + scopes + rend = render( + template=tags, + scopes=new_scope, + padding=padding, + partials_dict=partials_dict, + def_ldel=def_ldel, + def_rdel=def_rdel, + warn=warn, + keep=keep, + ) + + output += rend + + else: + # Otherwise we're just a scope section + scopes.insert(0, scope) + + # If we're an inverted section + elif tag == "inverted section": + # Add the flipped scope to the scopes + scope = _get_key( + key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel + ) + scopes.insert(0, cast(Literal[False], not scope)) + + # If we're a partial + elif tag == "partial": + # Load the partial + partial = _get_partial(key, partials_dict) + + # Find what to pad the partial with + left = output.rpartition("\n")[2] + part_padding = padding + if left.isspace(): + part_padding += left + + # Render the partial + part_out = render( + template=partial, + partials_dict=partials_dict, + def_ldel=def_ldel, + def_rdel=def_rdel, + padding=part_padding, + scopes=scopes, + warn=warn, + keep=keep, + ) + + # If the partial was indented + if left.isspace(): + # then remove the spaces from the end + part_out = part_out.rstrip(" \t") + + # Add the partials output to the output + output += part_out + + return output diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 67bd4209db..2cb19695e4 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -191,6 +191,34 @@ async def test_chat_prompt_template_from_messages_using_role_strings() -> None: assert messages == expected +def test_chat_prompt_template_from_messages_mustache() -> None: + """Test creating a chat prompt template from role string messages.""" + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are a helpful AI bot. Your name is {{name}}."), + ("human", "Hello, how are you doing?"), + ("ai", "I'm doing well, thanks!"), + ("human", "{{user_input}}"), + ], + "mustache", + ) + + messages = template.format_messages(name="Bob", user_input="What is your name?") + + assert messages == [ + SystemMessage( + content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={} + ), + HumanMessage( + content="Hello, how are you doing?", additional_kwargs={}, example=False + ), + AIMessage( + content="I'm doing well, thanks!", additional_kwargs={}, example=False + ), + HumanMessage(content="What is your name?", additional_kwargs={}, example=False), + ] + + def test_chat_prompt_template_with_messages( messages: List[BaseMessagePromptTemplate], ) -> None: diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index 2a62872446..b68316986b 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -38,6 +38,135 @@ def test_prompt_from_template() -> None: assert prompt == expected_prompt +def test_mustache_prompt_from_template() -> None: + """Test prompts can be constructed from a template.""" + # Single input variable. + template = "This is a {{foo}} test." + prompt = PromptTemplate.from_template(template, template_format="mustache") + assert prompt.format(foo="bar") == "This is a bar test." + assert prompt.input_variables == ["foo"] + assert prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {"foo": {"title": "Foo", "type": "string"}}, + } + + # Multiple input variables. + template = "This {{bar}} is a {{foo}} test." + prompt = PromptTemplate.from_template(template, template_format="mustache") + assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test." + assert prompt.input_variables == ["bar", "foo"] + assert prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": { + "bar": {"title": "Bar", "type": "string"}, + "foo": {"title": "Foo", "type": "string"}, + }, + } + + # Multiple input variables with repeats. + template = "This {{bar}} is a {{foo}} test {{foo}}." + prompt = PromptTemplate.from_template(template, template_format="mustache") + assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test bar." + assert prompt.input_variables == ["bar", "foo"] + assert prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": { + "bar": {"title": "Bar", "type": "string"}, + "foo": {"title": "Foo", "type": "string"}, + }, + } + + # Nested variables. + template = "This {{obj.bar}} is a {{obj.foo}} test {{foo}}." + prompt = PromptTemplate.from_template(template, template_format="mustache") + assert prompt.format(obj={"bar": "foo", "foo": "bar"}, foo="baz") == ( + "This foo is a bar test baz." + ) + assert prompt.input_variables == ["foo", "obj"] + assert prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": { + "foo": {"title": "Foo", "type": "string"}, + "obj": {"$ref": "#/definitions/obj"}, + }, + "definitions": { + "obj": { + "title": "obj", + "type": "object", + "properties": { + "foo": {"title": "Foo", "type": "string"}, + "bar": {"title": "Bar", "type": "string"}, + }, + } + }, + } + + # . variables + template = "This {{.}} is a test." + prompt = PromptTemplate.from_template(template, template_format="mustache") + assert prompt.format(foo="baz") == ("This {'foo': 'baz'} is a test.") + assert prompt.input_variables == [] + assert prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {}, + } + + # section/context variables + template = """This{{#foo}} + {{bar}} + {{/foo}}is a test.""" + prompt = PromptTemplate.from_template(template, template_format="mustache") + assert prompt.format(foo={"bar": "yo"}) == ( + """This + yo + is a test.""" + ) + assert prompt.input_variables == ["foo"] + assert prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {"foo": {"$ref": "#/definitions/foo"}}, + "definitions": { + "foo": { + "title": "foo", + "type": "object", + "properties": {"bar": {"title": "Bar", "type": "string"}}, + } + }, + } + + # section/context variables with repeats + template = """This{{#foo}} + {{bar}} + {{/foo}}is a test.""" + prompt = PromptTemplate.from_template(template, template_format="mustache") + assert prompt.format(foo=[{"bar": "yo"}, {"bar": "hello"}]) == ( + """This + yo + + hello + is a test.""" + ) + assert prompt.input_variables == ["foo"] + assert prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": {"foo": {"$ref": "#/definitions/foo"}}, + "definitions": { + "foo": { + "title": "foo", + "type": "object", + "properties": {"bar": {"title": "Bar", "type": "string"}}, + } + }, + } + + def test_prompt_from_template_with_partial_variables() -> None: """Test prompts can be constructed from a template with partial variables.""" # given