mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
prompt template from string (#884)
This commit is contained in:
parent
7cc44b3bdb
commit
a2b699dcd2
@ -151,6 +151,47 @@
|
||||
"multiple_input_prompt.format(adjective=\"funny\", content=\"chickens\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "72f32ff2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## From Template\n",
|
||||
"You can also easily load a prompt template by just specifying the template, and not worrying about the input variables."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "2a81f2f8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"Tell me a {adjective} joke about {content}.\"\n",
|
||||
"multiple_input_prompt = PromptTemplate.from_template(template)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "d365b144",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"PromptTemplate(input_variables=['adjective', 'content'], output_parser=None, template='Tell me a {adjective} joke about {content}.', template_format='f-string', validate_template=True)"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"multiple_input_prompt"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b2dd6154",
|
||||
|
@ -1,5 +1,4 @@
|
||||
"""Chain that just formats a prompt and calls an LLM."""
|
||||
from string import Formatter
|
||||
from typing import Any, Dict, List, Sequence, Union
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
@ -132,10 +131,5 @@ class LLMChain(Chain, BaseModel):
|
||||
@classmethod
|
||||
def from_string(cls, llm: BaseLLM, template: str) -> Chain:
|
||||
"""Create LLMChain from LLM and template."""
|
||||
input_variables = {
|
||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||
}
|
||||
prompt_template = PromptTemplate(
|
||||
input_variables=list(input_variables), template=template
|
||||
)
|
||||
prompt_template = PromptTemplate.from_template(template)
|
||||
return cls(llm=llm, prompt=prompt_template)
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Prompt schema definition."""
|
||||
from __future__ import annotations
|
||||
|
||||
from string import Formatter
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
@ -117,6 +118,14 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
|
||||
template = f.read()
|
||||
return cls(input_variables=input_variables, template=template)
|
||||
|
||||
@classmethod
|
||||
def from_template(cls, template: str) -> PromptTemplate:
|
||||
"""Load a prompt template from a template."""
|
||||
input_variables = {
|
||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||
}
|
||||
return cls(input_variables=list(input_variables), template=template)
|
||||
|
||||
|
||||
# For backwards compatibility.
|
||||
Prompt = PromptTemplate
|
||||
|
@ -13,6 +13,27 @@ def test_prompt_valid() -> None:
|
||||
assert prompt.input_variables == input_variables
|
||||
|
||||
|
||||
def test_prompt_from_template() -> None:
|
||||
"""Test prompts can be constructed from a template."""
|
||||
# Single input variable.
|
||||
template = "This is a {foo} test."
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
expected_prompt = PromptTemplate(template=template, input_variables=["foo"])
|
||||
assert prompt == expected_prompt
|
||||
|
||||
# Multiple input variables.
|
||||
template = "This {bar} is a {foo} test."
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
expected_prompt = PromptTemplate(template=template, input_variables=["bar", "foo"])
|
||||
assert prompt == expected_prompt
|
||||
|
||||
# Multiple input variables with repeats.
|
||||
template = "This {bar} is a {foo} test {foo}."
|
||||
prompt = PromptTemplate.from_template(template)
|
||||
expected_prompt = PromptTemplate(template=template, input_variables=["bar", "foo"])
|
||||
assert prompt == expected_prompt
|
||||
|
||||
|
||||
def test_prompt_missing_input_variables() -> None:
|
||||
"""Test error is raised when input variables are not provided."""
|
||||
template = "This is a {foo} test."
|
||||
|
Loading…
Reference in New Issue
Block a user