PromptTemplate update documentation and expand kwargs (#8234)

# PromptTemplate

* Update documentation to highlight the classmethod for instantiating a
prompt template.
* Expand kwargs in the classmethod to make parameters easier to discover
pull/8395/head
Eugene Yurtsev 12 months ago committed by GitHub
parent a003a0baf6
commit 6dd18eee26
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,7 @@ from __future__ import annotations
from pathlib import Path from pathlib import Path
from string import Formatter from string import Formatter
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Optional, Union
from pydantic import root_validator from pydantic import root_validator
@ -16,12 +16,24 @@ from langchain.prompts.base import (
class PromptTemplate(StringPromptTemplate): class PromptTemplate(StringPromptTemplate):
"""Schema to represent a prompt for an LLM. """A prompt template for a language model.
A prompt template consists of a string template. It accepts a set of parameters
from the user that can be used to generate a prompt for a language model.
The template can be formatted using either f-strings (default) or jinja2 syntax.
Example: Example:
.. code-block:: python .. code-block:: python
from langchain import PromptTemplate from langchain import PromptTemplate
# Instantiation using from_template (recommended)
prompt = PromptTemplate.from_template("Say {foo}")
prompt.format(foo="bar")
# Instantiation using initializer
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}") prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
""" """
@ -44,6 +56,7 @@ class PromptTemplate(StringPromptTemplate):
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""
def __add__(self, other: Any) -> PromptTemplate: def __add__(self, other: Any) -> PromptTemplate:
"""Override the + operator to allow for combining prompt templates."""
# Allow for easy combining # Allow for easy combining
if isinstance(other, PromptTemplate): if isinstance(other, PromptTemplate):
if self.template_format != "f-string": if self.template_format != "f-string":
@ -95,9 +108,9 @@ class PromptTemplate(StringPromptTemplate):
Example: Example:
.. code-block:: python .. code-block:: python
prompt.format(variable1="foo") prompt.format(variable1="foo")
""" """
kwargs = self._merge_partial_and_user_variables(**kwargs) kwargs = self._merge_partial_and_user_variables(**kwargs)
return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs) return DEFAULT_FORMATTER_MAPPING[self.template_format](self.template, **kwargs)
@ -153,6 +166,7 @@ class PromptTemplate(StringPromptTemplate):
template_file: The path to the file containing the prompt template. template_file: The path to the file containing the prompt template.
input_variables: A list of variable names the final prompt template input_variables: A list of variable names the final prompt template
will expect. will expect.
Returns: Returns:
The prompt loaded from the file. The prompt loaded from the file.
""" """
@ -161,25 +175,52 @@ class PromptTemplate(StringPromptTemplate):
return cls(input_variables=input_variables, template=template, **kwargs) return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod @classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: def from_template(
"""Load a prompt template from a template.""" cls,
if "template_format" in kwargs and kwargs["template_format"] == "jinja2": template: str,
*,
template_format: str = "f-string",
partial_variables: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> PromptTemplate:
"""Load a prompt template from a template.
Args:
template: The template to load.
template_format: The format of the template. Use `jinja2` for jinja2,
and `f-string` or None for f-strings.
partial_variables: A dictionary of variables that can be used to partially
fill in the template. For example, if the template is
`"{variable1} {variable2}"`, and `partial_variables` is
`{"variable1": "foo"}`, then the final prompt will be
`"foo {variable2}"`.
Returns:
The prompt template loaded from the template.
"""
if template_format == "jinja2":
# Get the variables for the template # Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template) input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
else:
input_variables = { input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None v for _, v, _, _ in Formatter().parse(template) if v is not None
} }
else:
raise ValueError(f"Unsupported template format: {template_format}")
_partial_variables = partial_variables or {}
if "partial_variables" in kwargs: if _partial_variables:
partial_variables = kwargs["partial_variables"]
input_variables = { input_variables = {
var for var in input_variables if var not in partial_variables var for var in input_variables if var not in _partial_variables
} }
return cls( return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs input_variables=sorted(input_variables),
template=template,
template_format=template_format,
partial_variables=_partial_variables,
**kwargs,
) )

@ -98,7 +98,8 @@
"name" "name"
], ],
"template": "hello {name}!", "template": "hello {name}!",
"template_format": "f-string" "template_format": "f-string",
"partial_variables": {}
} }
} }
} }
@ -176,7 +177,8 @@
"name" "name"
], ],
"template": "hello {name}!", "template": "hello {name}!",
"template_format": "f-string" "template_format": "f-string",
"partial_variables": {}
} }
} }
} }
@ -245,7 +247,8 @@
"name" "name"
], ],
"template": "hello {name}!", "template": "hello {name}!",
"template_format": "f-string" "template_format": "f-string",
"partial_variables": {}
} }
} }
} }
@ -277,3 +280,25 @@
} }
''' '''
# --- # ---
# name: test_serialize_prompt
'''
{
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"prompts",
"prompt",
"PromptTemplate"
],
"kwargs": {
"input_variables": [
"name"
],
"template": "hello {name}!",
"template_format": "f-string",
"partial_variables": {}
}
}
'''
# ---

@ -129,6 +129,12 @@ def test_serialize_llmchain_chat(snapshot: Any) -> None:
del os.environ["OPENAI_API_KEY"] del os.environ["OPENAI_API_KEY"]
def test_serialize_prompt(snapshot: Any) -> None:
"""Test that prompt is serialized correctly"""
prompt = PromptTemplate.from_template("hello {name}!")
assert dumps(prompt, pretty=True) == snapshot
@pytest.mark.requires("openai") @pytest.mark.requires("openai")
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None: def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
llm = OpenAI( llm = OpenAI(

@ -161,6 +161,10 @@ Will it get confused{ }?
) )
assert prompt == expected_prompt assert prompt == expected_prompt
@pytest.mark.requires("jinja2")
def test_prompt_from_jinja2_template_multiple_inputs() -> None:
"""Test with multiple input variables."""
# Multiple input variables. # Multiple input variables.
template = """\ template = """\
Hello world Hello world
@ -186,7 +190,10 @@ You just set bar boolean variable to true
assert prompt == expected_prompt assert prompt == expected_prompt
# Multiple input variables with repeats.
@pytest.mark.requires("jinja2")
def test_prompt_from_jinja2_template_multiple_inputs_with_repeats() -> None:
"""Test with multiple input variables and repeats."""
template = """\ template = """\
Hello world Hello world

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save