From a2b699dcd2019a77ae970847c0af9fb96b9eef70 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sat, 4 Feb 2023 17:04:58 -0800 Subject: [PATCH] prompt template from string (#884) --- .../prompts/examples/prompt_management.ipynb | 41 +++++++++++++++++++ langchain/chains/llm.py | 8 +--- langchain/prompts/prompt.py | 9 ++++ tests/unit_tests/prompts/test_prompt.py | 21 ++++++++++ 4 files changed, 72 insertions(+), 7 deletions(-) diff --git a/docs/modules/prompts/examples/prompt_management.ipynb b/docs/modules/prompts/examples/prompt_management.ipynb index e6c3c5b0..de65d747 100644 --- a/docs/modules/prompts/examples/prompt_management.ipynb +++ b/docs/modules/prompts/examples/prompt_management.ipynb @@ -151,6 +151,47 @@ "multiple_input_prompt.format(adjective=\"funny\", content=\"chickens\")" ] }, + { + "cell_type": "markdown", + "id": "72f32ff2", + "metadata": {}, + "source": [ + "## From Template\n", + "You can also easily load a prompt template by just specifying the template, and not worrying about the input variables." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2a81f2f8", + "metadata": {}, + "outputs": [], + "source": [ + "template = \"Tell me a {adjective} joke about {content}.\"\n", + "multiple_input_prompt = PromptTemplate.from_template(template)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d365b144", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "PromptTemplate(input_variables=['adjective', 'content'], output_parser=None, template='Tell me a {adjective} joke about {content}.', template_format='f-string', validate_template=True)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "multiple_input_prompt" + ] + }, { "cell_type": "markdown", "id": "b2dd6154", diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 33174bf4..5edf5fc1 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -1,5 +1,4 @@ """Chain that just formats a prompt and calls an LLM.""" -from string import Formatter from typing import Any, Dict, List, Sequence, Union from pydantic import BaseModel, Extra @@ -132,10 +131,5 @@ class LLMChain(Chain, BaseModel): @classmethod def from_string(cls, llm: BaseLLM, template: str) -> Chain: """Create LLMChain from LLM and template.""" - input_variables = { - v for _, v, _, _ in Formatter().parse(template) if v is not None - } - prompt_template = PromptTemplate( - input_variables=list(input_variables), template=template - ) + prompt_template = PromptTemplate.from_template(template) return cls(llm=llm, prompt=prompt_template) diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index eed015f5..98a145cf 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -1,6 +1,7 @@ """Prompt schema definition.""" from __future__ import annotations +from string import Formatter from typing import Any, Dict, List from pydantic import BaseModel, Extra, root_validator @@ -117,6 +118,14 @@ class PromptTemplate(BasePromptTemplate, BaseModel): template = f.read() return cls(input_variables=input_variables, template=template) + @classmethod + def from_template(cls, template: str) -> PromptTemplate: + """Load a prompt template from a template.""" + input_variables = { + v for _, v, _, _ in Formatter().parse(template) if v is not None + } + return cls(input_variables=list(input_variables), template=template) + # For backwards compatibility. Prompt = PromptTemplate diff --git a/tests/unit_tests/prompts/test_prompt.py b/tests/unit_tests/prompts/test_prompt.py index cec597e0..1789963e 100644 --- a/tests/unit_tests/prompts/test_prompt.py +++ b/tests/unit_tests/prompts/test_prompt.py @@ -13,6 +13,27 @@ def test_prompt_valid() -> None: assert prompt.input_variables == input_variables +def test_prompt_from_template() -> None: + """Test prompts can be constructed from a template.""" + # Single input variable. + template = "This is a {foo} test." + prompt = PromptTemplate.from_template(template) + expected_prompt = PromptTemplate(template=template, input_variables=["foo"]) + assert prompt == expected_prompt + + # Multiple input variables. + template = "This {bar} is a {foo} test." + prompt = PromptTemplate.from_template(template) + expected_prompt = PromptTemplate(template=template, input_variables=["bar", "foo"]) + assert prompt == expected_prompt + + # Multiple input variables with repeats. + template = "This {bar} is a {foo} test {foo}." + prompt = PromptTemplate.from_template(template) + expected_prompt = PromptTemplate(template=template, input_variables=["bar", "foo"]) + assert prompt == expected_prompt + + def test_prompt_missing_input_variables() -> None: """Test error is raised when input variables are not provided.""" template = "This is a {foo} test."