PromptTemplate update documentation and expand kwargs (#8234)

# PromptTemplate

* Update documentation to highlight the classmethod for instantiating a
prompt template.
* Expand kwargs in the classmethod to make parameters easier to discover
pull/8395/head
Eugene Yurtsev 12 months ago committed by GitHub
parent a003a0baf6
commit 6dd18eee26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,7 @@ from __future__ import annotations
from pathlib import Path
from string import Formatter
from typing import Any, Dict, List, Union
from typing import Any, Dict, List, Optional, Union
from pydantic import root_validator
@ -16,12 +16,24 @@ from langchain.prompts.base import (
class PromptTemplate(StringPromptTemplate):
"""Schema to represent a prompt for an LLM.
"""A prompt template for a language model.
A prompt template consists of a string template. It accepts a set of parameters
from the user that can be used to generate a prompt for a language model.
The template can be formatted using either f-strings (default) or jinja2 syntax.
Example:
.. code-block:: python
from langchain import PromptTemplate
# Instantiation using from_template (recommended)
prompt = PromptTemplate.from_template("Say {foo}")
prompt.format(foo="bar")
# Instantiation using initializer
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
"""
@ -44,6 +56,7 @@ class PromptTemplate(StringPromptTemplate):
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> PromptTemplate:
"""Override the + operator to allow for combining prompt templates."""
# Allow for easy combining
if isinstance(other, PromptTemplate):
if self.template_format != "f-string":
@ -95,9 +108,9 @@ class PromptTemplate(StringPromptTemplate):
Example:
.. code-block:: python
.. code-block:: python
prompt.format(variable1="foo")
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
@ -153,6 +166,7 @@ class PromptTemplate(StringPromptTemplate):
template_file: The path to the file containing the prompt template.
input_variables: A list of variable names the final prompt template
will expect.
Returns:
The prompt loaded from the file.
"""
@ -161,25 +175,52 @@ class PromptTemplate(StringPromptTemplate):
return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
if "template_format" in kwargs and kwargs["template_format"] == "jinja2":
def from_template(
cls,
template: str,
*,
template_format: str = "f-string",
partial_variables: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> PromptTemplate:
"""Load a prompt template from a template.
Args:
template: The template to load.
template_format: The format of the template. Use `jinja2` for jinja2,
and `f-string` or None for f-strings.
partial_variables: A dictionary of variables that can be used to partially
fill in the template. For example, if the template is
`"{variable1} {variable2}"`, and `partial_variables` is
`{"variable1": "foo"}`, then the final prompt will be
`"foo {variable2}"`.
Returns:
The prompt template loaded from the template.
"""
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
else:
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")
_partial_variables = partial_variables or {}
if "partial_variables" in kwargs:
partial_variables = kwargs["partial_variables"]
if _partial_variables:
input_variables = {
var for var in input_variables if var not in partial_variables
var for var in input_variables if var not in _partial_variables
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
input_variables=sorted(input_variables),
template=template,
template_format=template_format,
partial_variables=_partial_variables,
**kwargs,
)

@ -98,7 +98,8 @@
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
"template_format": "f-string",
"partial_variables": {}
}
}
}
@ -176,7 +177,8 @@
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
"template_format": "f-string",
"partial_variables": {}
}
}
}
@ -245,7 +247,8 @@
"name"
],
"template": "hello {name}!",
"template_format": "f-string"
"template_format": "f-string",
"partial_variables": {}
}
}
}
@ -277,3 +280,25 @@
}
'''
# ---
# name: test_serialize_prompt
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string",
"partial_variables": {}
}
}
'''
# ---

@ -129,6 +129,12 @@ def test_serialize_llmchain_chat(snapshot: Any) -> None:
del os.environ["OPENAI_API_KEY"]
def test_serialize_prompt(snapshot: Any) -> None:
"""Test that prompt is serialized correctly"""
prompt = PromptTemplate.from_template("hello {name}!")
assert dumps(prompt, pretty=True) == snapshot
@pytest.mark.requires("openai")
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
llm = OpenAI(

@ -161,6 +161,10 @@ Will it get confused{ }?
)
assert prompt == expected_prompt
@pytest.mark.requires("jinja2")
def test_prompt_from_jinja2_template_multiple_inputs() -> None:
"""Test with multiple input variables."""
# Multiple input variables.
template = """\
Hello world
@ -186,7 +190,10 @@ You just set bar boolean variable to true
assert prompt == expected_prompt
# Multiple input variables with repeats.
@pytest.mark.requires("jinja2")
def test_prompt_from_jinja2_template_multiple_inputs_with_repeats() -> None:
"""Test with multiple input variables and repeats."""
template = """\
Hello world

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save