prompt template from string (#884)

Harrison Chase 2023-02-04 17:04:58 -08:00 committed by GitHub
parent 7cc44b3bdb
commit a2b699dcd2
4 changed files with 72 additions and 7 deletions


@@ -151,6 +151,47 @@
"multiple_input_prompt.format(adjective=\"funny\", content=\"chickens\")"
]
},
{
"cell_type": "markdown",
"id": "72f32ff2",
"metadata": {},
"source": [
"## From Template\n",
"You can also easily load a prompt template by just specifying the template, and not worrying about the input variables."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2a81f2f8",
"metadata": {},
"outputs": [],
"source": [
"template = \"Tell me a {adjective} joke about {content}.\"\n",
"multiple_input_prompt = PromptTemplate.from_template(template)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d365b144",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PromptTemplate(input_variables=['adjective', 'content'], output_parser=None, template='Tell me a {adjective} joke about {content}.', template_format='f-string', validate_template=True)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"multiple_input_prompt"
]
},
{
"cell_type": "markdown",
"id": "b2dd6154",


@@ -1,5 +1,4 @@
 """Chain that just formats a prompt and calls an LLM."""
-from string import Formatter
 from typing import Any, Dict, List, Sequence, Union
 
 from pydantic import BaseModel, Extra
@@ -132,10 +131,5 @@ class LLMChain(Chain, BaseModel):
     @classmethod
     def from_string(cls, llm: BaseLLM, template: str) -> Chain:
         """Create LLMChain from LLM and template."""
-        input_variables = {
-            v for _, v, _, _ in Formatter().parse(template) if v is not None
-        }
-        prompt_template = PromptTemplate(
-            input_variables=list(input_variables), template=template
-        )
+        prompt_template = PromptTemplate.from_template(template)
         return cls(llm=llm, prompt=prompt_template)
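With this change, from_string no longer infers input variables itself; it delegates to the new PromptTemplate.from_template classmethod. A minimal usage sketch (the OpenAI LLM below is only an illustration; any BaseLLM implementation works):

    from langchain.chains import LLMChain
    from langchain.llms import OpenAI  # illustrative choice; any BaseLLM works

    # Input variables ("adjective", "content") are inferred from the template,
    # so they no longer need to be listed explicitly.
    chain = LLMChain.from_string(
        llm=OpenAI(),
        template="Tell me a {adjective} joke about {content}.",
    )
    print(chain.prompt.input_variables)  # ['adjective', 'content'] (set-derived, order may vary)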

View File

@@ -1,6 +1,7 @@
 """Prompt schema definition."""
 from __future__ import annotations
 
+from string import Formatter
 from typing import Any, Dict, List
 
 from pydantic import BaseModel, Extra, root_validator
@@ -117,6 +118,14 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
             template = f.read()
         return cls(input_variables=input_variables, template=template)
 
+    @classmethod
+    def from_template(cls, template: str) -> PromptTemplate:
+        """Load a prompt template from a template string."""
+        input_variables = {
+            v for _, v, _, _ in Formatter().parse(template) if v is not None
+        }
+        return cls(input_variables=list(input_variables), template=template)
+
 
 # For backwards compatibility.
 Prompt = PromptTemplate
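The inference in from_template relies on the standard library's string.Formatter, which walks an f-string-style template and yields one tuple per parsed segment. Field names are collected into a set, so repeated variables are deduplicated and a template with no replacement fields yields an empty list. A small sketch of the mechanism (the variable names here are illustrative only):

    from string import Formatter

    template = "Tell me a {adjective} joke about {content}."

    # Formatter().parse yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for plain literal text, so those entries are skipped.
    fields = {v for _, v, _, _ in Formatter().parse(template) if v is not None}
    print(sorted(fields))  # ['adjective', 'content']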


@@ -13,6 +13,27 @@ def test_prompt_valid() -> None:
     assert prompt.input_variables == input_variables
 
 
+def test_prompt_from_template() -> None:
+    """Test prompts can be constructed from a template."""
+    # Single input variable.
+    template = "This is a {foo} test."
+    prompt = PromptTemplate.from_template(template)
+    expected_prompt = PromptTemplate(template=template, input_variables=["foo"])
+    assert prompt == expected_prompt
+
+    # Multiple input variables.
+    template = "This {bar} is a {foo} test."
+    prompt = PromptTemplate.from_template(template)
+    expected_prompt = PromptTemplate(template=template, input_variables=["bar", "foo"])
+    assert prompt == expected_prompt
+
+    # Multiple input variables with repeats.
+    template = "This {bar} is a {foo} test {foo}."
+    prompt = PromptTemplate.from_template(template)
+    expected_prompt = PromptTemplate(template=template, input_variables=["bar", "foo"])
+    assert prompt == expected_prompt
+
+
 def test_prompt_missing_input_variables() -> None:
     """Test error is raised when input variables are not provided."""
     template = "This is a {foo} test."