Make prompt validation opt-in

By default, validation is skipped and input_variables is replaced with the variables inferred from the template.
This commit is contained in:
Nuno Campos 2023-10-18 10:46:22 +01:00
parent 202acce0c9
commit b753bf3323
10 changed files with 191 additions and 32 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Literal, Set
from langchain.schema.messages import BaseMessage, HumanMessage
@ -99,6 +100,20 @@ def check_valid_template(
)
def get_template_variables(template: str, template_format: str) -> List[str]:
    """Return the sorted variable names referenced by a template string.

    Args:
        template: The template text to inspect.
        template_format: Either "f-string" or "jinja2".

    Returns:
        The variable names found in the template, in sorted order.

    Raises:
        ValueError: If ``template_format`` is not one of the supported formats.
    """
    if template_format == "f-string":
        # Formatter().parse yields (literal, field_name, spec, conversion)
        # tuples; field_name is None for chunks with no replacement field.
        names = set()
        for _literal, field_name, _spec, _conversion in Formatter().parse(template):
            if field_name is not None:
                names.add(field_name)
        return sorted(names)
    if template_format == "jinja2":
        # Delegate variable discovery for jinja2 templates to the helper.
        return sorted(_get_jinja2_variables_from_template(template))
    raise ValueError(f"Unsupported template format: {template_format}")
class StringPromptValue(PromptValue):
"""String prompt value."""

View File

@ -382,6 +382,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
"""List of input variables in template messages. Used for validation."""
messages: List[MessageLike]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
@ -432,7 +434,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
input_types[message.variable_name] = List[AnyMessage]
if "partial_variables" in values:
input_vars = input_vars - set(values["partial_variables"])
if "input_variables" in values:
if "input_variables" in values and values.get("validate_template"):
if input_vars != set(values["input_variables"]):
raise ValueError(
"Got mismatched input_variables. "

View File

@ -2,12 +2,13 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,
get_template_variables,
)
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector
@ -77,7 +78,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Return whether or not the class is serializable."""
return False
validate_template: bool = True
validate_template: bool = False
"""Whether or not to try validating the template."""
input_variables: List[str]
@ -95,7 +96,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
prefix: str = ""
"""A prompt template string to put before the examples."""
template_format: str = "f-string"
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@root_validator()
@ -107,6 +108,14 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
values["template_format"],
values["input_variables"] + list(values["partial_variables"]),
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["prefix"] + values["suffix"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
class Config:

View File

@ -37,7 +37,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = True
validate_template: bool = False
"""Whether or not to try validating the template."""
@root_validator(pre=True)
@ -72,6 +72,13 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
f"Got input_variables={input_variables}, but based on "
f"prefix/suffix expected {expected_input_variables}"
)
else:
values["input_variables"] = sorted(
set(values["suffix"].input_variables)
| set(values["prefix"].input_variables if values["prefix"] else [])
- set(values["partial_variables"])
)
print(values["input_variables"])
return values
class Config:

View File

@ -2,14 +2,13 @@
from __future__ import annotations
from pathlib import Path
from string import Formatter
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
_get_jinja2_variables_from_template,
check_valid_template,
get_template_variables,
)
from langchain.pydantic_v1 import root_validator
@ -53,10 +52,10 @@ class PromptTemplate(StringPromptTemplate):
template: str
"""The prompt template."""
template_format: str = "f-string"
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = True
validate_template: bool = False
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> PromptTemplate:
@ -127,6 +126,14 @@ class PromptTemplate(StringPromptTemplate):
check_valid_template(
values["template"], values["template_format"], all_inputs
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
@classmethod
@ -202,25 +209,17 @@ class PromptTemplate(StringPromptTemplate):
Returns:
The prompt template loaded from the template.
"""
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")
input_variables = get_template_variables(template, template_format)
_partial_variables = partial_variables or {}
if _partial_variables:
input_variables = {
input_variables = [
var for var in input_variables if var not in _partial_variables
}
]
return cls(
input_variables=sorted(input_variables),
input_variables=input_variables,
template=template,
template_format=template_format,
partial_variables=_partial_variables,

View File

@ -15,6 +15,5 @@
"partial_variables": {},
"template": "Given the following question and student answer, provide a correct answer and score the student answer.\nQuestion: {question}\nStudent Answer: {student_answer}\nCorrect Answer:",
"template_format": "f-string",
"validate_template": true,
"_type": "prompt"
}

View File

@ -186,13 +186,24 @@ def test_chat_prompt_template_with_messages() -> None:
def test_chat_invalid_input_variables_extra() -> None:
messages = [HumanMessage(content="foo")]
with pytest.raises(ValueError):
ChatPromptTemplate(messages=messages, input_variables=["foo"])
ChatPromptTemplate(
messages=messages, input_variables=["foo"], validate_template=True
)
assert (
ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables
== []
)
def test_chat_invalid_input_variables_missing() -> None:
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
with pytest.raises(ValueError):
ChatPromptTemplate(messages=messages, input_variables=[])
ChatPromptTemplate(
messages=messages, input_variables=[], validate_template=True
)
assert ChatPromptTemplate(
messages=messages, input_variables=[]
).input_variables == ["foo"]
def test_infer_variables() -> None:

View File

@ -67,7 +67,14 @@ def test_prompt_missing_input_variables() -> None:
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=[],
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
# Test when missing in prefix
template = "This is a {foo} test."
@ -78,7 +85,15 @@ def test_prompt_missing_input_variables() -> None:
examples=[],
prefix=template,
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=[],
suffix="foo",
examples=[],
prefix=template,
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
def test_prompt_extra_input_variables() -> None:
@ -91,7 +106,14 @@ def test_prompt_extra_input_variables() -> None:
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=input_variables,
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
def test_few_shot_functionality() -> None:
@ -248,7 +270,15 @@ def test_prompt_jinja2_missing_input_variables(
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=[],
suffix=suffix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar"]
# Test when missing in prefix
with pytest.warns(UserWarning):
@ -259,7 +289,16 @@ def test_prompt_jinja2_missing_input_variables(
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=["bar"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar", "foo"]
@pytest.mark.requires("jinja2")
@ -277,7 +316,16 @@ def test_prompt_jinja2_extra_input_variables(
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=["bar", "foo", "extra", "thing"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar", "foo"]
def test_few_shot_chat_message_prompt_template() -> None:

View File

@ -1,5 +1,7 @@
"""Test few shot prompt template."""
import pytest
from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain.prompts.prompt import PromptTemplate
@ -38,3 +40,37 @@ def test_prompttemplate_prefix_suffix() -> None:
"Now you try to talk about party."
)
assert output == expected_output
def test_prompttemplate_validation() -> None:
    """Check opt-in validation and default inference of input_variables.

    With ``validate_template=True`` an empty ``input_variables`` list that does
    not match the prefix/suffix must raise; without it, ``input_variables`` is
    expected to be filled in from the prefix and suffix templates.
    """
    prefix = PromptTemplate(
        template="This is a test about {content}.", input_variables=["content"]
    )
    suffix = PromptTemplate(
        template="Now you try to talk about {new_content}.",
        input_variables=["new_content"],
    )
    # Arguments shared by both constructions below.
    common_kwargs = dict(
        suffix=suffix,
        prefix=prefix,
        input_variables=[],
        examples=[
            {"question": "foo", "answer": "bar"},
            {"question": "baz", "answer": "foo"},
        ],
        example_prompt=EXAMPLE_PROMPT,
        example_separator="\n",
    )

    # Validation explicitly enabled: the mismatch is an error.
    with pytest.raises(ValueError):
        FewShotPromptWithTemplates(validate_template=True, **common_kwargs)

    # Default: no error, variables are taken from prefix and suffix.
    prompt = FewShotPromptWithTemplates(**common_kwargs)
    assert prompt.input_variables == ["content", "new_content"]

View File

@ -39,7 +39,12 @@ def test_prompt_missing_input_variables() -> None:
template = "This is a {foo} test."
input_variables: list = []
with pytest.raises(ValueError):
PromptTemplate(input_variables=input_variables, template=template)
PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_extra_input_variables() -> None:
@ -47,7 +52,12 @@ def test_prompt_extra_input_variables() -> None:
template = "This is a {foo} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
PromptTemplate(input_variables=input_variables, template=template)
PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_wrong_input_variables() -> None:
@ -55,7 +65,12 @@ def test_prompt_wrong_input_variables() -> None:
template = "This is a {foo} test."
input_variables = ["bar"]
with pytest.raises(ValueError):
PromptTemplate(input_variables=input_variables, template=template)
PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_from_examples_valid() -> None:
@ -229,8 +244,14 @@ def test_prompt_jinja2_missing_input_variables() -> None:
input_variables: list = []
with pytest.warns(UserWarning):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
input_variables=input_variables,
template=template,
template_format="jinja2",
validate_template=True,
)
assert PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
).input_variables == ["foo"]
@pytest.mark.requires("jinja2")
@ -240,8 +261,14 @@ def test_prompt_jinja2_extra_input_variables() -> None:
input_variables = ["foo", "bar"]
with pytest.warns(UserWarning):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
input_variables=input_variables,
template=template,
template_format="jinja2",
validate_template=True,
)
assert PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
).input_variables == ["foo"]
@pytest.mark.requires("jinja2")
@ -251,5 +278,11 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
input_variables = ["bar"]
with pytest.warns(UserWarning):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
input_variables=input_variables,
template=template,
template_format="jinja2",
validate_template=True,
)
assert PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
).input_variables == ["foo"]