core: mustache prompt templates (#19980)

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Nuno Campos 2024-04-10 11:25:32 -07:00 committed by GitHub
parent 4cb5f4c353
commit 15271ac832
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 904 additions and 12 deletions

View File

@ -8,6 +8,7 @@ from typing import (
Any, Any,
Dict, Dict,
List, List,
Literal,
Optional, Optional,
Sequence, Sequence,
Set, Set,
@ -929,6 +930,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def from_messages( def from_messages(
cls, cls,
messages: Sequence[MessageLikeRepresentation], messages: Sequence[MessageLikeRepresentation],
template_format: Literal["f-string", "mustache"] = "f-string",
) -> ChatPromptTemplate: ) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats. """Create a chat prompt template from a variety of message formats.
@ -964,7 +966,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
Returns: Returns:
a chat prompt template a chat prompt template
""" """
_messages = [_convert_to_message(message) for message in messages] _messages = [
_convert_to_message(message, template_format) for message in messages
]
# Automatically infer input variables from messages # Automatically infer input variables from messages
input_vars: Set[str] = set() input_vars: Set[str] = set()
@ -1121,7 +1125,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def _create_template_from_message_type( def _create_template_from_message_type(
message_type: str, template: Union[str, list] message_type: str,
template: Union[str, list],
template_format: Literal["f-string", "mustache"] = "f-string",
) -> BaseMessagePromptTemplate: ) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string. """Create a message prompt template from a message type and template string.
@ -1134,12 +1140,16 @@ def _create_template_from_message_type(
""" """
if message_type in ("human", "user"): if message_type in ("human", "user"):
message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template( message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template(
template template, template_format=template_format
) )
elif message_type in ("ai", "assistant"): elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(cast(str, template)) message = AIMessagePromptTemplate.from_template(
cast(str, template), template_format=template_format
)
elif message_type == "system": elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(cast(str, template)) message = SystemMessagePromptTemplate.from_template(
cast(str, template), template_format=template_format
)
elif message_type == "placeholder": elif message_type == "placeholder":
if isinstance(template, str): if isinstance(template, str):
if template[0] != "{" or template[-1] != "}": if template[0] != "{" or template[-1] != "}":
@ -1180,6 +1190,7 @@ def _create_template_from_message_type(
def _convert_to_message( def _convert_to_message(
message: MessageLikeRepresentation, message: MessageLikeRepresentation,
template_format: Literal["f-string", "mustache"] = "f-string",
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]: ) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats. """Instantiate a message from a variety of message formats.
@ -1204,16 +1215,22 @@ def _convert_to_message(
elif isinstance(message, BaseMessage): elif isinstance(message, BaseMessage):
_message = message _message = message
elif isinstance(message, str): elif isinstance(message, str):
_message = _create_template_from_message_type("human", message) _message = _create_template_from_message_type(
"human", message, template_format=template_format
)
elif isinstance(message, tuple): elif isinstance(message, tuple):
if len(message) != 2: if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}") raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message message_type_str, template = message
if isinstance(message_type_str, str): if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template) _message = _create_template_from_message_type(
message_type_str, template, template_format=template_format
)
else: else:
_message = message_type_str( _message = message_type_str(
prompt=PromptTemplate.from_template(cast(str, template)) prompt=PromptTemplate.from_template(
cast(str, template), template_format=template_format
)
) )
else: else:
raise NotImplementedError(f"Unsupported message type: {type(message)}") raise NotImplementedError(f"Unsupported message type: {type(message)}")

View File

@ -10,8 +10,10 @@ from langchain_core.prompts.string import (
StringPromptTemplate, StringPromptTemplate,
check_valid_template, check_valid_template,
get_template_variables, get_template_variables,
mustache_schema,
) )
from langchain_core.pydantic_v1 import root_validator from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables.config import RunnableConfig
class PromptTemplate(StringPromptTemplate): class PromptTemplate(StringPromptTemplate):
@ -65,12 +67,19 @@ class PromptTemplate(StringPromptTemplate):
template: str template: str
"""The prompt template.""" """The prompt template."""
template_format: Literal["f-string", "jinja2"] = "f-string" template_format: Literal["f-string", "mustache", "jinja2"] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" """The format of the prompt template.
Options are: 'f-string', 'mustache', 'jinja2'."""
validate_template: bool = False validate_template: bool = False
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
    """Return the input schema for this prompt template.

    For mustache templates the schema is generated from the template text
    itself (nested models for dotted/section keys); every other template
    format falls back to the base-class behavior.

    Args:
        config: Optional runnable config (unused for mustache templates).

    Returns:
        A pydantic model class describing the expected input.
    """
    if self.template_format != "mustache":
        return super().get_input_schema(config)
    return mustache_schema(self.template)
def __add__(self, other: Any) -> PromptTemplate: def __add__(self, other: Any) -> PromptTemplate:
"""Override the + operator to allow for combining prompt templates.""" """Override the + operator to allow for combining prompt templates."""
# Allow for easy combining # Allow for easy combining
@ -121,6 +130,8 @@ class PromptTemplate(StringPromptTemplate):
def template_is_valid(cls, values: Dict) -> Dict: def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent.""" """Check that template and input variables are consistent."""
if values["validate_template"]: if values["validate_template"]:
if values["template_format"] == "mustache":
raise ValueError("Mustache templates cannot be validated.")
all_inputs = values["input_variables"] + list(values["partial_variables"]) all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template( check_valid_template(
values["template"], values["template_format"], all_inputs values["template"], values["template_format"], all_inputs

View File

@ -5,10 +5,12 @@ from __future__ import annotations
import warnings import warnings
from abc import ABC from abc import ABC
from string import Formatter from string import Formatter
from typing import Any, Callable, Dict, List, Set from typing import Any, Callable, Dict, List, Set, Tuple, Type
import langchain_core.utils.mustache as mustache
from langchain_core.prompt_values import PromptValue, StringPromptValue from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.utils import get_colored_text from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env from langchain_core.utils.interactive_env import is_interactive_env
@ -85,8 +87,70 @@ def _get_jinja2_variables_from_template(template: str) -> Set[str]:
return variables return variables
def mustache_formatter(template: str, **kwargs: Any) -> str:
    """Format a template using mustache.

    Args:
        template: A mustache template string.
        **kwargs: The variables to substitute into the template.

    Returns:
        The rendered string.
    """
    return mustache.render(template, kwargs)
def mustache_template_vars(
    template: str,
) -> Set[str]:
    """Get the top-level input variables of a mustache template.

    Only the first dotted component of each key is collected
    (``{{a.b}}`` contributes ``a``), and keys inside a section body are
    skipped because they resolve against the section's scope rather than
    the top-level input.

    Args:
        template: A mustache template string.

    Returns:
        The set of top-level variable names.
    """
    # Renamed from `vars`/`type` to avoid shadowing the builtins.
    top_level_vars: Set[str] = set()
    # True while we are inside a {{#section}}...{{/section}} body.
    in_section = False
    for token_type, key in mustache.tokenize(template):
        if token_type == "end":
            in_section = False
        elif in_section:
            continue
        elif token_type in ("variable", "section") and key != ".":
            top_level_vars.add(key.split(".")[0])
        if token_type == "section":
            in_section = True
    return top_level_vars
Defs = Dict[str, "Defs"]
def mustache_schema(
    template: str,
) -> Type[BaseModel]:
    """Build a pydantic input schema from a mustache template.

    Dotted keys and section scopes produce nested models: the template
    ``{{#obj}}{{foo}}{{/obj}}`` yields a model with an ``obj`` field whose
    own model has a string ``foo`` field.

    Args:
        template: A mustache template string.

    Returns:
        A pydantic model class describing the template's inputs.
    """
    fields = set()
    prefix: Tuple[str, ...] = ()
    for token_type, key in mustache.tokenize(template):
        if key == ".":
            continue
        if token_type == "end":
            # Pop every path component the matching section pushed.
            # (``prefix[: -key.count(".")]`` was wrong: for a dot-free key
            # the count is 0, so the slice cleared the whole prefix and
            # broke nested sections; for dotted keys it popped too few.)
            prefix = prefix[: len(prefix) - len(key.split("."))]
        elif token_type in ("section", "inverted section"):
            # Inverted sections also open a scope and are closed by an
            # "end" token, so they must push onto the prefix too.
            prefix = prefix + tuple(key.split("."))
        elif token_type == "variable":
            fields.add(prefix + tuple(key.split(".")))
    # Fold the flat field paths into a tree of nested definitions;
    # an empty dict marks a leaf (plain string field).
    defs: Defs = {}
    while fields:
        field = fields.pop()
        current = defs
        for part in field[:-1]:
            current = current.setdefault(part, {})
        current[field[-1]] = {}
    return _create_model_recursive("PromptInput", defs)
def _create_model_recursive(name: str, defs: Defs) -> Type:
    """Recursively build a pydantic model from a tree of field definitions.

    A non-empty subtree becomes a nested model; an empty subtree is a leaf
    and becomes an optional string field.
    """
    field_definitions: Dict[str, Any] = {}
    for field_name, subtree in defs.items():
        if subtree:
            field_definitions[field_name] = (
                _create_model_recursive(field_name, subtree),
                None,
            )
        else:
            field_definitions[field_name] = (str, None)
    return create_model(name, **field_definitions)  # type: ignore[call-overload]
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format, "f-string": formatter.format,
"mustache": mustache_formatter,
"jinja2": jinja2_formatter, "jinja2": jinja2_formatter,
} }
@ -145,6 +209,8 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
input_variables = { input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None v for _, v, _, _ in Formatter().parse(template) if v is not None
} }
elif template_format == "mustache":
input_variables = mustache_template_vars(template)
else: else:
raise ValueError(f"Unsupported template format: {template_format}") raise ValueError(f"Unsupported template format: {template_format}")

View File

@ -0,0 +1,641 @@
"""
Adapted from https://github.com/noahmorrison/chevron
MIT License
"""
import logging
from typing import (
Any,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
cast,
)
from typing_extensions import TypeAlias
logger = logging.getLogger(__name__)
Scopes: TypeAlias = List[Union[Literal[False, 0], Dict[str, Any]]]
# Globals
_CURRENT_LINE = 1
_LAST_TAG_LINE = None
class ChevronError(SyntaxError):
    """Raised when a mustache template contains malformed or unclosed tags."""

    pass
#
# Helper functions
#
def grab_literal(template: str, l_del: str) -> Tuple[str, str]:
    """Parse a literal from the template.

    Args:
        template: The template text to scan.
        l_del: The left delimiter that starts a tag (e.g. ``{{``).

    Returns:
        A ``(literal, rest)`` tuple: the text up to the next tag, and the
        remaining template after the left delimiter (``""`` when no tag
        remains).
    """
    global _CURRENT_LINE

    try:
        # Look for the next tag and move the template to it
        literal, template = template.split(l_del, 1)
        # Keep the module-level line counter in sync for error messages
        _CURRENT_LINE += literal.count("\n")
        return (literal, template)

    # There are no more tags in the template?
    except ValueError:
        # Then the rest of the template is a literal
        return (template, "")
def l_sa_check(template: str, literal: str, is_standalone: bool) -> bool:
    """Do a preliminary check to see if a tag could be a standalone."""
    # A tag can only be standalone if it starts a fresh line (or directly
    # follows another standalone tag)...
    if "\n" not in literal and not is_standalone:
        return False
    # ...and is preceded only by whitespace on its own line.
    padding = literal.rsplit("\n", 1)[-1]
    return not padding or padding.isspace()
def r_sa_check(template: str, tag_type: str, is_standalone: bool) -> bool:
    """Do a final check to see if a tag could be a standalone."""
    # Variable-style tags can never be standalone.
    if not is_standalone or tag_type in ("variable", "no escape"):
        return False
    # Standalone only if nothing but whitespace remains on this line.
    rest_of_line = template.split("\n", 1)[0]
    return not rest_of_line or rest_of_line.isspace()
def parse_tag(template: str, l_del: str, r_del: str) -> Tuple[Tuple[str, str], str]:
    """Parse a tag from a template.

    Args:
        template: Template text positioned just after a left delimiter.
        l_del: The current left delimiter.
        r_del: The current right delimiter.

    Returns:
        ``((tag_type, tag_key), rest_of_template)``.

    Raises:
        ChevronError: If the tag is unclosed, or a set-delimiter tag is
            malformed.
    """
    global _CURRENT_LINE
    global _LAST_TAG_LINE

    # Sigil -> tag type; a trailing "?" marks a tentative type that is
    # confirmed or rejected below.
    tag_types = {
        "!": "comment",
        "#": "section",
        "^": "inverted section",
        "/": "end",
        ">": "partial",
        "=": "set delimiter?",
        "{": "no escape?",
        "&": "no escape",
    }

    # Get the tag
    try:
        tag, template = template.split(r_del, 1)
    except ValueError:
        raise ChevronError("unclosed tag " "at line {0}".format(_CURRENT_LINE))

    # Find the type meaning of the first character
    tag_type = tag_types.get(tag[0], "variable")

    # If the type is not a variable
    if tag_type != "variable":
        # Then that first character is not needed
        tag = tag[1:]

    # If we might be a set delimiter tag
    if tag_type == "set delimiter?":
        # Double check to make sure we are
        if tag.endswith("="):
            tag_type = "set delimiter"
            # Remove the equal sign
            tag = tag[:-1]

        # Otherwise we should complain
        else:
            raise ChevronError(
                "unclosed set delimiter tag\n" "at line {0}".format(_CURRENT_LINE)
            )

    # If we might be a no html escape tag
    elif tag_type == "no escape?":
        # And we have a third curly brace
        # (And are using curly braces as delimiters)
        if l_del == "{{" and r_del == "}}" and template.startswith("}"):
            # Then we are a no html escape tag
            template = template[1:]
            tag_type = "no escape"

    # Strip the whitespace off the key and return
    return ((tag_type, tag.strip()), template)
#
# The main tokenizing function
#
def tokenize(
    template: str, def_ldel: str = "{{", def_rdel: str = "}}"
) -> Iterator[Tuple[str, str]]:
    """Tokenize a mustache template.

    Tokenizes a mustache template in a generator fashion,
    using file-like objects. It also accepts a string containing
    the template.

    Arguments:
        template -- a file-like object, or a string of a mustache template
        def_ldel -- The default left delimiter
            ("{{" by default, as in spec compliant mustache)
        def_rdel -- The default right delimiter
            ("}}" by default, as in spec compliant mustache)

    Returns:
        A generator of mustache tags in the form of a tuple
        -- (tag_type, tag_key)
        Where tag_type is one of:
         * literal
         * section
         * inverted section
         * end
         * partial
         * no escape
        And tag_key is either the key or in the case of a literal tag,
        the literal itself.

    Raises:
        ChevronError: On mismatched or unclosed section tags.
    """
    # Reset the module-level line trackers for this template run.
    global _CURRENT_LINE, _LAST_TAG_LINE
    _CURRENT_LINE = 1
    _LAST_TAG_LINE = None

    is_standalone = True
    # Stack of currently-open section keys, used to validate end tags.
    open_sections = []
    l_del = def_ldel
    r_del = def_rdel

    while template:
        literal, template = grab_literal(template, l_del)

        # If the template is completed
        if not template:
            # Then yield the literal and leave
            yield ("literal", literal)
            break

        # Do the first check to see if we could be a standalone
        is_standalone = l_sa_check(template, literal, is_standalone)

        # Parse the tag
        tag, template = parse_tag(template, l_del, r_del)
        tag_type, tag_key = tag

        # Special tag logic

        # If we are a set delimiter tag
        if tag_type == "set delimiter":
            # Then get and set the delimiters
            dels = tag_key.strip().split(" ")
            l_del, r_del = dels[0], dels[-1]

        # If we are a section tag
        elif tag_type in ["section", "inverted section"]:
            # Then open a new section
            open_sections.append(tag_key)
            _LAST_TAG_LINE = _CURRENT_LINE

        # If we are an end tag
        elif tag_type == "end":
            # Then check to see if the last opened section
            # is the same as us
            try:
                last_section = open_sections.pop()
            except IndexError:
                raise ChevronError(
                    'Trying to close tag "{0}"\n'
                    "Looks like it was not opened.\n"
                    "line {1}".format(tag_key, _CURRENT_LINE + 1)
                )
            if tag_key != last_section:
                # Otherwise we need to complain
                raise ChevronError(
                    'Trying to close tag "{0}"\n'
                    'last open tag is "{1}"\n'
                    "line {2}".format(tag_key, last_section, _CURRENT_LINE + 1)
                )

        # Do the second check to see if we're a standalone
        is_standalone = r_sa_check(template, tag_type, is_standalone)

        # Which if we are
        if is_standalone:
            # Remove the stuff before the newline
            template = template.split("\n", 1)[-1]

            # Partials need to keep the spaces on their left
            if tag_type != "partial":
                # But other tags don't
                literal = literal.rstrip(" ")

        # Start yielding
        # Ignore literals that are empty
        if literal != "":
            yield ("literal", literal)

        # Ignore comments and set delimiters
        if tag_type not in ["comment", "set delimiter?"]:
            yield (tag_type, tag_key)

    # If there are any open sections when we're done
    if open_sections:
        # Then we need to complain
        raise ChevronError(
            "Unexpected EOF\n"
            'the tag "{0}" was never closed\n'
            "was opened at line {1}".format(open_sections[-1], _LAST_TAG_LINE)
        )
#
# Helper functions
#
def _html_escape(string: str) -> str:
"""HTML escape all of these " & < >"""
html_codes = {
'"': "&quot;",
"<": "&lt;",
">": "&gt;",
}
# & must be handled first
string = string.replace("&", "&amp;")
for char in html_codes:
string = string.replace(char, html_codes[char])
return string
def _get_key(
    key: str,
    scopes: Scopes,
    warn: bool,
    keep: bool,
    def_ldel: str,
    def_rdel: str,
) -> Any:
    """Get a key from the current scope.

    Args:
        key: Dot-separated key to resolve (``.`` means the current scope).
        scopes: The stack of scopes, most recent first.
        warn: Log a warning when the key cannot be found.
        keep: Return the original tag text (instead of ``""``) when the
            key cannot be found.
        def_ldel: Left delimiter used to rebuild a kept tag.
        def_rdel: Right delimiter used to rebuild a kept tag.

    Returns:
        The resolved value; ``""`` when missing (or the rebuilt tag text
        if ``keep`` is set).
    """
    # If the key is a dot
    if key == ".":
        # Then just return the current scope
        return scopes[0]

    # Loop through the scopes
    for scope in scopes:
        try:
            # Return an empty string if falsy, with two exceptions
            # 0 should return 0, and False should return False
            if scope in (0, False):
                return scope

            # For every dot separated key
            for child in key.split("."):
                # Return an empty string if falsy, with two exceptions
                # 0 should return 0, and False should return False
                if scope in (0, False):
                    return scope
                # Move into the scope
                try:
                    # Try subscripting (Normal dictionaries)
                    scope = cast(Dict[str, Any], scope)[child]
                except (TypeError, AttributeError):
                    try:
                        scope = getattr(scope, child)
                    except (TypeError, AttributeError):
                        # Try as a list
                        scope = scope[int(child)]  # type: ignore

            try:
                # This allows for custom falsy data types
                # https://github.com/noahmorrison/chevron/issues/35
                if scope._CHEVRON_return_scope_when_falsy:  # type: ignore
                    return scope
            except AttributeError:
                return scope or ""
        except (AttributeError, KeyError, IndexError, ValueError):
            # We couldn't find the key in the current scope
            # We'll try again on the next pass
            pass

    # We couldn't find the key in any of the scopes

    if warn:
        # Fix: logger.warn is a deprecated alias of logger.warning, and
        # %-formatting should be deferred to the logging call itself.
        logger.warning("Could not find key '%s'", key)

    if keep:
        return "%s %s %s" % (def_ldel, key, def_rdel)

    return ""
def _get_partial(name: str, partials_dict: Dict[str, str]) -> str:
"""Load a partial"""
try:
# Maybe the partial is in the dictionary
return partials_dict[name]
except KeyError:
return ""
#
# The main rendering function
#
g_token_cache: Dict[str, List[Tuple[str, str]]] = {}
def render(
    template: Union[str, List[Tuple[str, str]]] = "",
    data: Dict[str, Any] = {},
    partials_dict: Dict[str, str] = {},
    padding: str = "",
    def_ldel: str = "{{",
    def_rdel: str = "}}",
    scopes: Optional[Scopes] = None,
    warn: bool = False,
    keep: bool = False,
) -> str:
    """Render a mustache template.

    Renders a mustache template with a data scope and inline partial
    capability.

    Arguments:
        template -- A file-like object or a string containing the template,
            or an already-tokenized list of (tag_type, tag_key) tuples
        data -- A python dictionary with your data scope
        partials_dict -- A python dictionary which will be searched for
            partials. {'include': 'foo'} is the same
            as a file called include.mustache
            (defaults to {})
        padding -- This is for padding partials, and shouldn't be used
            (but can be if you really want to)
        def_ldel -- The default left delimiter
            ("{{" by default, as in spec compliant mustache)
        def_rdel -- The default right delimiter
            ("}}" by default, as in spec compliant mustache)
        scopes -- The list of scopes that get_key will look through
        warn -- Log a warning when a template substitution isn't found in the data
        keep -- Keep unreplaced tags when a substitution isn't found in the data

    Returns:
        A string containing the rendered template.

    NOTE(review): ``data`` and ``partials_dict`` use mutable default
    arguments shared across calls; this appears safe because both are
    only read, never mutated — confirm before adding any mutation.
    """
    # If the template is a sequence but not derived from a string
    if isinstance(template, Sequence) and not isinstance(template, str):
        # Then we don't need to tokenize it
        # But it does need to be a generator
        tokens: Iterator[Tuple[str, str]] = (token for token in template)
    else:
        if template in g_token_cache:
            # Reuse tokens cached by a previous lambda-section render
            tokens = (token for token in g_token_cache[template])
        else:
            # Otherwise make a generator
            tokens = tokenize(template, def_ldel, def_rdel)

    output = ""

    if scopes is None:
        scopes = [data]

    # Run through the tokens
    for tag, key in tokens:
        # Set the current scope
        current_scope = scopes[0]

        # If we're an end tag
        if tag == "end":
            # Pop out of the latest scope
            del scopes[0]

        # If the current scope is falsy and not the only scope
        elif not current_scope and len(scopes) != 1:
            if tag in ["section", "inverted section"]:
                # Set the most recent scope to a falsy value
                scopes.insert(0, False)

        # If we're a literal tag
        elif tag == "literal":
            # Add padding to the key and add it to the output
            output += key.replace("\n", "\n" + padding)

        # If we're a variable tag
        elif tag == "variable":
            # Add the html escaped key to the output
            thing = _get_key(
                key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel
            )
            if thing is True and key == ".":
                # if we've coerced into a boolean by accident
                # (inverted tags do this)
                # then get the un-coerced object (next in the stack)
                thing = scopes[1]
            if not isinstance(thing, str):
                thing = str(thing)
            output += _html_escape(thing)

        # If we're a no html escape tag
        elif tag == "no escape":
            # Just lookup the key and add it
            thing = _get_key(
                key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel
            )
            if not isinstance(thing, str):
                thing = str(thing)
            output += thing

        # If we're a section tag
        elif tag == "section":
            # Get the sections scope
            scope = _get_key(
                key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel
            )

            # If the scope is a callable (as described in
            # https://mustache.github.io/mustache.5.html)
            if callable(scope):
                # Generate template text from tags
                # (re-serialize the section body so the lambda can render it)
                text = ""
                tags: List[Tuple[str, str]] = []
                for token in tokens:
                    if token == ("end", key):
                        break

                    tags.append(token)
                    tag_type, tag_key = token
                    if tag_type == "literal":
                        text += tag_key
                    elif tag_type == "no escape":
                        text += "%s& %s %s" % (def_ldel, tag_key, def_rdel)
                    else:
                        text += "%s%s %s%s" % (
                            def_ldel,
                            {
                                "comment": "!",
                                "section": "#",
                                "inverted section": "^",
                                "end": "/",
                                "partial": ">",
                                "set delimiter": "=",
                                "no escape": "&",
                                "variable": "",
                            }[tag_type],
                            tag_key,
                            def_rdel,
                        )

                # Cache the tokens so a re-render of this text skips tokenizing
                g_token_cache[text] = tags

                rend = scope(
                    text,
                    lambda template, data=None: render(
                        template,
                        data={},
                        partials_dict=partials_dict,
                        padding=padding,
                        def_ldel=def_ldel,
                        def_rdel=def_rdel,
                        scopes=data and [data] + scopes or scopes,
                        warn=warn,
                        keep=keep,
                    ),
                )

                output += rend

            # If the scope is a sequence, an iterator or generator but not
            # derived from a string
            elif isinstance(scope, (Sequence, Iterator)) and not isinstance(scope, str):
                # Then we need to do some looping

                # Gather up all the tags inside the section
                # (And don't be tricked by nested end tags with the same key)
                # TODO: This feels like it still has edge cases, no?
                tags = []
                tags_with_same_key = 0
                for token in tokens:
                    if token == ("section", key):
                        tags_with_same_key += 1
                    if token == ("end", key):
                        tags_with_same_key -= 1
                        if tags_with_same_key < 0:
                            break
                    tags.append(token)

                # For every item in the scope
                for thing in scope:
                    # Append it as the most recent scope and render
                    new_scope = [thing] + scopes
                    rend = render(
                        template=tags,
                        scopes=new_scope,
                        padding=padding,
                        partials_dict=partials_dict,
                        def_ldel=def_ldel,
                        def_rdel=def_rdel,
                        warn=warn,
                        keep=keep,
                    )

                    output += rend

            else:
                # Otherwise we're just a scope section
                scopes.insert(0, scope)

        # If we're an inverted section
        elif tag == "inverted section":
            # Add the flipped scope to the scopes
            scope = _get_key(
                key, scopes, warn=warn, keep=keep, def_ldel=def_ldel, def_rdel=def_rdel
            )
            scopes.insert(0, cast(Literal[False], not scope))

        # If we're a partial
        elif tag == "partial":
            # Load the partial
            partial = _get_partial(key, partials_dict)

            # Find what to pad the partial with
            left = output.rpartition("\n")[2]
            part_padding = padding
            if left.isspace():
                part_padding += left

            # Render the partial
            part_out = render(
                template=partial,
                partials_dict=partials_dict,
                def_ldel=def_ldel,
                def_rdel=def_rdel,
                padding=part_padding,
                scopes=scopes,
                warn=warn,
                keep=keep,
            )

            # If the partial was indented
            if left.isspace():
                # then remove the spaces from the end
                part_out = part_out.rstrip(" \t")

            # Add the partials output to the output
            output += part_out

    return output

View File

@ -191,6 +191,34 @@ async def test_chat_prompt_template_from_messages_using_role_strings() -> None:
assert messages == expected assert messages == expected
def test_chat_prompt_template_from_messages_mustache() -> None:
    """Test creating a chat prompt template from messages with the mustache
    template format."""
    template = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a helpful AI bot. Your name is {{name}}."),
            ("human", "Hello, how are you doing?"),
            ("ai", "I'm doing well, thanks!"),
            ("human", "{{user_input}}"),
        ],
        "mustache",
    )

    messages = template.format_messages(name="Bob", user_input="What is your name?")

    assert messages == [
        SystemMessage(
            content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={}
        ),
        HumanMessage(
            content="Hello, how are you doing?", additional_kwargs={}, example=False
        ),
        AIMessage(
            content="I'm doing well, thanks!", additional_kwargs={}, example=False
        ),
        HumanMessage(content="What is your name?", additional_kwargs={}, example=False),
    ]
def test_chat_prompt_template_with_messages( def test_chat_prompt_template_with_messages(
messages: List[BaseMessagePromptTemplate], messages: List[BaseMessagePromptTemplate],
) -> None: ) -> None:

View File

@ -38,6 +38,135 @@ def test_prompt_from_template() -> None:
assert prompt == expected_prompt assert prompt == expected_prompt
def test_mustache_prompt_from_template() -> None:
    """Test mustache prompts can be constructed from a template and that the
    generated input schemas match the template's variables."""
    # Single input variable.
    template = "This is a {{foo}} test."
    prompt = PromptTemplate.from_template(template, template_format="mustache")
    assert prompt.format(foo="bar") == "This is a bar test."
    assert prompt.input_variables == ["foo"]
    assert prompt.input_schema.schema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"foo": {"title": "Foo", "type": "string"}},
    }

    # Multiple input variables.
    template = "This {{bar}} is a {{foo}} test."
    prompt = PromptTemplate.from_template(template, template_format="mustache")
    assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test."
    assert prompt.input_variables == ["bar", "foo"]
    assert prompt.input_schema.schema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "bar": {"title": "Bar", "type": "string"},
            "foo": {"title": "Foo", "type": "string"},
        },
    }

    # Multiple input variables with repeats.
    template = "This {{bar}} is a {{foo}} test {{foo}}."
    prompt = PromptTemplate.from_template(template, template_format="mustache")
    assert prompt.format(bar="baz", foo="bar") == "This baz is a bar test bar."
    assert prompt.input_variables == ["bar", "foo"]
    assert prompt.input_schema.schema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "bar": {"title": "Bar", "type": "string"},
            "foo": {"title": "Foo", "type": "string"},
        },
    }

    # Nested variables: dotted keys produce a nested model definition.
    template = "This {{obj.bar}} is a {{obj.foo}} test {{foo}}."
    prompt = PromptTemplate.from_template(template, template_format="mustache")
    assert prompt.format(obj={"bar": "foo", "foo": "bar"}, foo="baz") == (
        "This foo is a bar test baz."
    )
    assert prompt.input_variables == ["foo", "obj"]
    assert prompt.input_schema.schema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {
            "foo": {"title": "Foo", "type": "string"},
            "obj": {"$ref": "#/definitions/obj"},
        },
        "definitions": {
            "obj": {
                "title": "obj",
                "type": "object",
                "properties": {
                    "foo": {"title": "Foo", "type": "string"},
                    "bar": {"title": "Bar", "type": "string"},
                },
            }
        },
    }

    # . variables: the whole data dict is interpolated; no named inputs.
    template = "This {{.}} is a test."
    prompt = PromptTemplate.from_template(template, template_format="mustache")
    assert prompt.format(foo="baz") == ("This {'foo': 'baz'} is a test.")
    assert prompt.input_variables == []
    assert prompt.input_schema.schema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {},
    }

    # section/context variables
    template = """This{{#foo}}
{{bar}}
{{/foo}}is a test."""
    prompt = PromptTemplate.from_template(template, template_format="mustache")
    assert prompt.format(foo={"bar": "yo"}) == (
        """This
yo
is a test."""
    )
    assert prompt.input_variables == ["foo"]
    assert prompt.input_schema.schema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"foo": {"$ref": "#/definitions/foo"}},
        "definitions": {
            "foo": {
                "title": "foo",
                "type": "object",
                "properties": {"bar": {"title": "Bar", "type": "string"}},
            }
        },
    }

    # section/context variables with repeats
    template = """This{{#foo}}
{{bar}}
{{/foo}}is a test."""
    prompt = PromptTemplate.from_template(template, template_format="mustache")
    assert prompt.format(foo=[{"bar": "yo"}, {"bar": "hello"}]) == (
        """This
yo
hello
is a test."""
    )
    assert prompt.input_variables == ["foo"]
    assert prompt.input_schema.schema() == {
        "title": "PromptInput",
        "type": "object",
        "properties": {"foo": {"$ref": "#/definitions/foo"}},
        "definitions": {
            "foo": {
                "title": "foo",
                "type": "object",
                "properties": {"bar": {"title": "Bar", "type": "string"}},
            }
        },
    }
def test_prompt_from_template_with_partial_variables() -> None: def test_prompt_from_template_with_partial_variables() -> None:
"""Test prompts can be constructed from a template with partial variables.""" """Test prompts can be constructed from a template with partial variables."""
# given # given