mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Make prompt validation opt-in
By default replace input_variables with the correct value
This commit is contained in:
parent
202acce0c9
commit
b753bf3323
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from string import Formatter
|
||||
from typing import Any, Callable, Dict, List, Literal, Set
|
||||
|
||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||
@ -99,6 +100,20 @@ def check_valid_template(
|
||||
)
|
||||
|
||||
|
||||
def get_template_variables(template: str, template_format: str) -> List[str]:
|
||||
if template_format == "jinja2":
|
||||
# Get the variables for the template
|
||||
input_variables = _get_jinja2_variables_from_template(template)
|
||||
elif template_format == "f-string":
|
||||
input_variables = {
|
||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported template format: {template_format}")
|
||||
|
||||
return sorted(input_variables)
|
||||
|
||||
|
||||
class StringPromptValue(PromptValue):
|
||||
"""String prompt value."""
|
||||
|
||||
|
@ -382,6 +382,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
"""List of input variables in template messages. Used for validation."""
|
||||
messages: List[MessageLike]
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
"""Combine two prompt templates.
|
||||
@ -432,7 +434,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
||||
input_types[message.variable_name] = List[AnyMessage]
|
||||
if "partial_variables" in values:
|
||||
input_vars = input_vars - set(values["partial_variables"])
|
||||
if "input_variables" in values:
|
||||
if "input_variables" in values and values.get("validate_template"):
|
||||
if input_vars != set(values["input_variables"]):
|
||||
raise ValueError(
|
||||
"Got mismatched input_variables. "
|
||||
|
@ -2,12 +2,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from langchain.prompts.base import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
)
|
||||
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
|
||||
from langchain.prompts.example_selector.base import BaseExampleSelector
|
||||
@ -77,7 +78,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
"""Return whether or not the class is serializable."""
|
||||
return False
|
||||
|
||||
validate_template: bool = True
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
input_variables: List[str]
|
||||
@ -95,7 +96,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
prefix: str = ""
|
||||
"""A prompt template string to put before the examples."""
|
||||
|
||||
template_format: str = "f-string"
|
||||
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
@root_validator()
|
||||
@ -107,6 +108,14 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
values["template_format"],
|
||||
values["input_variables"] + list(values["partial_variables"]),
|
||||
)
|
||||
elif values.get("template_format"):
|
||||
values["input_variables"] = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
values["prefix"] + values["suffix"], values["template_format"]
|
||||
)
|
||||
if var not in values["partial_variables"]
|
||||
]
|
||||
return values
|
||||
|
||||
class Config:
|
||||
|
@ -37,7 +37,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
template_format: str = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
validate_template: bool = True
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
@ -72,6 +72,13 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
f"Got input_variables={input_variables}, but based on "
|
||||
f"prefix/suffix expected {expected_input_variables}"
|
||||
)
|
||||
else:
|
||||
values["input_variables"] = sorted(
|
||||
set(values["suffix"].input_variables)
|
||||
| set(values["prefix"].input_variables if values["prefix"] else [])
|
||||
- set(values["partial_variables"])
|
||||
)
|
||||
print(values["input_variables"])
|
||||
return values
|
||||
|
||||
class Config:
|
||||
|
@ -2,14 +2,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from string import Formatter
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from langchain.prompts.base import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
_get_jinja2_variables_from_template,
|
||||
check_valid_template,
|
||||
get_template_variables,
|
||||
)
|
||||
from langchain.pydantic_v1 import root_validator
|
||||
|
||||
@ -53,10 +52,10 @@ class PromptTemplate(StringPromptTemplate):
|
||||
template: str
|
||||
"""The prompt template."""
|
||||
|
||||
template_format: str = "f-string"
|
||||
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
validate_template: bool = True
|
||||
validate_template: bool = False
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
def __add__(self, other: Any) -> PromptTemplate:
|
||||
@ -127,6 +126,14 @@ class PromptTemplate(StringPromptTemplate):
|
||||
check_valid_template(
|
||||
values["template"], values["template_format"], all_inputs
|
||||
)
|
||||
elif values.get("template_format"):
|
||||
values["input_variables"] = [
|
||||
var
|
||||
for var in get_template_variables(
|
||||
values["template"], values["template_format"]
|
||||
)
|
||||
if var not in values["partial_variables"]
|
||||
]
|
||||
return values
|
||||
|
||||
@classmethod
|
||||
@ -202,25 +209,17 @@ class PromptTemplate(StringPromptTemplate):
|
||||
Returns:
|
||||
The prompt template loaded from the template.
|
||||
"""
|
||||
if template_format == "jinja2":
|
||||
# Get the variables for the template
|
||||
input_variables = _get_jinja2_variables_from_template(template)
|
||||
elif template_format == "f-string":
|
||||
input_variables = {
|
||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported template format: {template_format}")
|
||||
|
||||
input_variables = get_template_variables(template, template_format)
|
||||
_partial_variables = partial_variables or {}
|
||||
|
||||
if _partial_variables:
|
||||
input_variables = {
|
||||
input_variables = [
|
||||
var for var in input_variables if var not in _partial_variables
|
||||
}
|
||||
]
|
||||
|
||||
return cls(
|
||||
input_variables=sorted(input_variables),
|
||||
input_variables=input_variables,
|
||||
template=template,
|
||||
template_format=template_format,
|
||||
partial_variables=_partial_variables,
|
||||
|
@ -15,6 +15,5 @@
|
||||
"partial_variables": {},
|
||||
"template": "Given the following question and student answer, provide a correct answer and score the student answer.\nQuestion: {question}\nStudent Answer: {student_answer}\nCorrect Answer:",
|
||||
"template_format": "f-string",
|
||||
"validate_template": true,
|
||||
"_type": "prompt"
|
||||
}
|
@ -186,13 +186,24 @@ def test_chat_prompt_template_with_messages() -> None:
|
||||
def test_chat_invalid_input_variables_extra() -> None:
|
||||
messages = [HumanMessage(content="foo")]
|
||||
with pytest.raises(ValueError):
|
||||
ChatPromptTemplate(messages=messages, input_variables=["foo"])
|
||||
ChatPromptTemplate(
|
||||
messages=messages, input_variables=["foo"], validate_template=True
|
||||
)
|
||||
assert (
|
||||
ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables
|
||||
== []
|
||||
)
|
||||
|
||||
|
||||
def test_chat_invalid_input_variables_missing() -> None:
|
||||
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
|
||||
with pytest.raises(ValueError):
|
||||
ChatPromptTemplate(messages=messages, input_variables=[])
|
||||
ChatPromptTemplate(
|
||||
messages=messages, input_variables=[], validate_template=True
|
||||
)
|
||||
assert ChatPromptTemplate(
|
||||
messages=messages, input_variables=[]
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
def test_infer_variables() -> None:
|
||||
|
@ -67,7 +67,14 @@ def test_prompt_missing_input_variables() -> None:
|
||||
suffix=template,
|
||||
examples=[],
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
validate_template=True,
|
||||
)
|
||||
assert FewShotPromptTemplate(
|
||||
input_variables=[],
|
||||
suffix=template,
|
||||
examples=[],
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
).input_variables == ["foo"]
|
||||
|
||||
# Test when missing in prefix
|
||||
template = "This is a {foo} test."
|
||||
@ -78,7 +85,15 @@ def test_prompt_missing_input_variables() -> None:
|
||||
examples=[],
|
||||
prefix=template,
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
validate_template=True,
|
||||
)
|
||||
assert FewShotPromptTemplate(
|
||||
input_variables=[],
|
||||
suffix="foo",
|
||||
examples=[],
|
||||
prefix=template,
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
def test_prompt_extra_input_variables() -> None:
|
||||
@ -91,7 +106,14 @@ def test_prompt_extra_input_variables() -> None:
|
||||
suffix=template,
|
||||
examples=[],
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
validate_template=True,
|
||||
)
|
||||
assert FewShotPromptTemplate(
|
||||
input_variables=input_variables,
|
||||
suffix=template,
|
||||
examples=[],
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
def test_few_shot_functionality() -> None:
|
||||
@ -248,7 +270,15 @@ def test_prompt_jinja2_missing_input_variables(
|
||||
examples=example_jinja2_prompt[1],
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
validate_template=True,
|
||||
)
|
||||
assert FewShotPromptTemplate(
|
||||
input_variables=[],
|
||||
suffix=suffix,
|
||||
examples=example_jinja2_prompt[1],
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
).input_variables == ["bar"]
|
||||
|
||||
# Test when missing in prefix
|
||||
with pytest.warns(UserWarning):
|
||||
@ -259,7 +289,16 @@ def test_prompt_jinja2_missing_input_variables(
|
||||
examples=example_jinja2_prompt[1],
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
validate_template=True,
|
||||
)
|
||||
assert FewShotPromptTemplate(
|
||||
input_variables=["bar"],
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
examples=example_jinja2_prompt[1],
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
).input_variables == ["bar", "foo"]
|
||||
|
||||
|
||||
@pytest.mark.requires("jinja2")
|
||||
@ -277,7 +316,16 @@ def test_prompt_jinja2_extra_input_variables(
|
||||
examples=example_jinja2_prompt[1],
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
validate_template=True,
|
||||
)
|
||||
assert FewShotPromptTemplate(
|
||||
input_variables=["bar", "foo", "extra", "thing"],
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
examples=example_jinja2_prompt[1],
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
).input_variables == ["bar", "foo"]
|
||||
|
||||
|
||||
def test_few_shot_chat_message_prompt_template() -> None:
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""Test few shot prompt template."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
@ -38,3 +40,37 @@ def test_prompttemplate_prefix_suffix() -> None:
|
||||
"Now you try to talk about party."
|
||||
)
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_prompttemplate_validation() -> None:
|
||||
"""Test that few shot works when prefix and suffix are PromptTemplates."""
|
||||
prefix = PromptTemplate(
|
||||
input_variables=["content"], template="This is a test about {content}."
|
||||
)
|
||||
suffix = PromptTemplate(
|
||||
input_variables=["new_content"],
|
||||
template="Now you try to talk about {new_content}.",
|
||||
)
|
||||
|
||||
examples = [
|
||||
{"question": "foo", "answer": "bar"},
|
||||
{"question": "baz", "answer": "foo"},
|
||||
]
|
||||
with pytest.raises(ValueError):
|
||||
FewShotPromptWithTemplates(
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
input_variables=[],
|
||||
examples=examples,
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
example_separator="\n",
|
||||
validate_template=True,
|
||||
)
|
||||
assert FewShotPromptWithTemplates(
|
||||
suffix=suffix,
|
||||
prefix=prefix,
|
||||
input_variables=[],
|
||||
examples=examples,
|
||||
example_prompt=EXAMPLE_PROMPT,
|
||||
example_separator="\n",
|
||||
).input_variables == ["content", "new_content"]
|
||||
|
@ -39,7 +39,12 @@ def test_prompt_missing_input_variables() -> None:
|
||||
template = "This is a {foo} test."
|
||||
input_variables: list = []
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(input_variables=input_variables, template=template)
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, validate_template=True
|
||||
)
|
||||
assert PromptTemplate(
|
||||
input_variables=input_variables, template=template
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
def test_prompt_extra_input_variables() -> None:
|
||||
@ -47,7 +52,12 @@ def test_prompt_extra_input_variables() -> None:
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["foo", "bar"]
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(input_variables=input_variables, template=template)
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, validate_template=True
|
||||
)
|
||||
assert PromptTemplate(
|
||||
input_variables=input_variables, template=template
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
def test_prompt_wrong_input_variables() -> None:
|
||||
@ -55,7 +65,12 @@ def test_prompt_wrong_input_variables() -> None:
|
||||
template = "This is a {foo} test."
|
||||
input_variables = ["bar"]
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(input_variables=input_variables, template=template)
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, validate_template=True
|
||||
)
|
||||
assert PromptTemplate(
|
||||
input_variables=input_variables, template=template
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
def test_prompt_from_examples_valid() -> None:
|
||||
@ -229,8 +244,14 @@ def test_prompt_jinja2_missing_input_variables() -> None:
|
||||
input_variables: list = []
|
||||
with pytest.warns(UserWarning):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
input_variables=input_variables,
|
||||
template=template,
|
||||
template_format="jinja2",
|
||||
validate_template=True,
|
||||
)
|
||||
assert PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
@pytest.mark.requires("jinja2")
|
||||
@ -240,8 +261,14 @@ def test_prompt_jinja2_extra_input_variables() -> None:
|
||||
input_variables = ["foo", "bar"]
|
||||
with pytest.warns(UserWarning):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
input_variables=input_variables,
|
||||
template=template,
|
||||
template_format="jinja2",
|
||||
validate_template=True,
|
||||
)
|
||||
assert PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
).input_variables == ["foo"]
|
||||
|
||||
|
||||
@pytest.mark.requires("jinja2")
|
||||
@ -251,5 +278,11 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
|
||||
input_variables = ["bar"]
|
||||
with pytest.warns(UserWarning):
|
||||
PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
input_variables=input_variables,
|
||||
template=template,
|
||||
template_format="jinja2",
|
||||
validate_template=True,
|
||||
)
|
||||
assert PromptTemplate(
|
||||
input_variables=input_variables, template=template, template_format="jinja2"
|
||||
).input_variables == ["foo"]
|
||||
|
Loading…
Reference in New Issue
Block a user