Make prompt validation opt-in (#11973)

By default replace input_variables with the correct value

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11984/head
Nuno Campos 10 months ago committed by GitHub
commit 6bd9c1d2b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,6 +3,7 @@ from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Literal, Set
from langchain.schema.messages import BaseMessage, HumanMessage
@ -99,6 +100,20 @@ def check_valid_template(
)
def get_template_variables(template: str, template_format: str) -> List[str]:
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")
return sorted(input_variables)
class StringPromptValue(PromptValue):
"""String prompt value."""

@ -382,6 +382,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
"""List of input variables in template messages. Used for validation."""
messages: List[MessageLike]
"""List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates.
@ -432,7 +434,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
input_types[message.variable_name] = List[AnyMessage]
if "partial_variables" in values:
input_vars = input_vars - set(values["partial_variables"])
if "input_variables" in values:
if "input_variables" in values and values.get("validate_template"):
if input_vars != set(values["input_variables"]):
raise ValueError(
"Got mismatched input_variables. "

@ -2,12 +2,13 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,
get_template_variables,
)
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector
@ -77,7 +78,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Return whether or not the class is serializable."""
return False
validate_template: bool = True
validate_template: bool = False
"""Whether or not to try validating the template."""
input_variables: List[str]
@ -95,7 +96,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
prefix: str = ""
"""A prompt template string to put before the examples."""
template_format: str = "f-string"
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@root_validator()
@ -107,6 +108,14 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
values["template_format"],
values["input_variables"] + list(values["partial_variables"]),
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["prefix"] + values["suffix"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
class Config:

@ -37,7 +37,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = True
validate_template: bool = False
"""Whether or not to try validating the template."""
@root_validator(pre=True)
@ -72,6 +72,12 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
f"Got input_variables={input_variables}, but based on "
f"prefix/suffix expected {expected_input_variables}"
)
else:
values["input_variables"] = sorted(
set(values["suffix"].input_variables)
| set(values["prefix"].input_variables if values["prefix"] else [])
- set(values["partial_variables"])
)
return values
class Config:

@ -2,14 +2,13 @@
from __future__ import annotations
from pathlib import Path
from string import Formatter
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Union
from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
_get_jinja2_variables_from_template,
check_valid_template,
get_template_variables,
)
from langchain.pydantic_v1 import root_validator
@ -53,10 +52,10 @@ class PromptTemplate(StringPromptTemplate):
template: str
"""The prompt template."""
template_format: str = "f-string"
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = True
validate_template: bool = False
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> PromptTemplate:
@ -127,6 +126,14 @@ class PromptTemplate(StringPromptTemplate):
check_valid_template(
values["template"], values["template_format"], all_inputs
)
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values
@classmethod
@ -202,25 +209,17 @@ class PromptTemplate(StringPromptTemplate):
Returns:
The prompt template loaded from the template.
"""
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")
input_variables = get_template_variables(template, template_format)
_partial_variables = partial_variables or {}
if _partial_variables:
input_variables = {
input_variables = [
var for var in input_variables if var not in _partial_variables
}
]
return cls(
input_variables=sorted(input_variables),
input_variables=input_variables,
template=template,
template_format=template_format,
partial_variables=_partial_variables,

@ -97,11 +97,11 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
values["llm_chain"] = LLMChain(
llm=values.get("llm"),
prompt=PromptTemplate(
template=QUERY_CHECKER, input_variables=["query", "dialect"]
template=QUERY_CHECKER, input_variables=["dialect", "query"]
),
)
if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
)

@ -15,6 +15,5 @@
"partial_variables": {},
"template": "Given the following question and student answer, provide a correct answer and score the student answer.\nQuestion: {question}\nStudent Answer: {student_answer}\nCorrect Answer:",
"template_format": "f-string",
"validate_template": true,
"_type": "prompt"
}
}

@ -186,13 +186,24 @@ def test_chat_prompt_template_with_messages() -> None:
def test_chat_invalid_input_variables_extra() -> None:
messages = [HumanMessage(content="foo")]
with pytest.raises(ValueError):
ChatPromptTemplate(messages=messages, input_variables=["foo"])
ChatPromptTemplate(
messages=messages, input_variables=["foo"], validate_template=True
)
assert (
ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables
== []
)
def test_chat_invalid_input_variables_missing() -> None:
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
with pytest.raises(ValueError):
ChatPromptTemplate(messages=messages, input_variables=[])
ChatPromptTemplate(
messages=messages, input_variables=[], validate_template=True
)
assert ChatPromptTemplate(
messages=messages, input_variables=[]
).input_variables == ["foo"]
def test_infer_variables() -> None:

@ -67,7 +67,14 @@ def test_prompt_missing_input_variables() -> None:
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=[],
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
# Test when missing in prefix
template = "This is a {foo} test."
@ -78,7 +85,15 @@ def test_prompt_missing_input_variables() -> None:
examples=[],
prefix=template,
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=[],
suffix="foo",
examples=[],
prefix=template,
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
def test_prompt_extra_input_variables() -> None:
@ -91,7 +106,14 @@ def test_prompt_extra_input_variables() -> None:
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=input_variables,
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
def test_few_shot_functionality() -> None:
@ -248,7 +270,15 @@ def test_prompt_jinja2_missing_input_variables(
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=[],
suffix=suffix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar"]
# Test when missing in prefix
with pytest.warns(UserWarning):
@ -259,7 +289,16 @@ def test_prompt_jinja2_missing_input_variables(
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=["bar"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar", "foo"]
@pytest.mark.requires("jinja2")
@ -277,7 +316,16 @@ def test_prompt_jinja2_extra_input_variables(
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
validate_template=True,
)
assert FewShotPromptTemplate(
input_variables=["bar", "foo", "extra", "thing"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar", "foo"]
def test_few_shot_chat_message_prompt_template() -> None:

@ -1,5 +1,7 @@
"""Test few shot prompt template."""
import pytest
from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain.prompts.prompt import PromptTemplate
@ -38,3 +40,37 @@ def test_prompttemplate_prefix_suffix() -> None:
"Now you try to talk about party."
)
assert output == expected_output
def test_prompttemplate_validation() -> None:
"""Test that few shot works when prefix and suffix are PromptTemplates."""
prefix = PromptTemplate(
input_variables=["content"], template="This is a test about {content}."
)
suffix = PromptTemplate(
input_variables=["new_content"],
template="Now you try to talk about {new_content}.",
)
examples = [
{"question": "foo", "answer": "bar"},
{"question": "baz", "answer": "foo"},
]
with pytest.raises(ValueError):
FewShotPromptWithTemplates(
suffix=suffix,
prefix=prefix,
input_variables=[],
examples=examples,
example_prompt=EXAMPLE_PROMPT,
example_separator="\n",
validate_template=True,
)
assert FewShotPromptWithTemplates(
suffix=suffix,
prefix=prefix,
input_variables=[],
examples=examples,
example_prompt=EXAMPLE_PROMPT,
example_separator="\n",
).input_variables == ["content", "new_content"]

@ -39,7 +39,12 @@ def test_prompt_missing_input_variables() -> None:
template = "This is a {foo} test."
input_variables: list = []
with pytest.raises(ValueError):
PromptTemplate(input_variables=input_variables, template=template)
PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_extra_input_variables() -> None:
@ -47,7 +52,12 @@ def test_prompt_extra_input_variables() -> None:
template = "This is a {foo} test."
input_variables = ["foo", "bar"]
with pytest.raises(ValueError):
PromptTemplate(input_variables=input_variables, template=template)
PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_wrong_input_variables() -> None:
@ -55,7 +65,12 @@ def test_prompt_wrong_input_variables() -> None:
template = "This is a {foo} test."
input_variables = ["bar"]
with pytest.raises(ValueError):
PromptTemplate(input_variables=input_variables, template=template)
PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_from_examples_valid() -> None:
@ -229,8 +244,14 @@ def test_prompt_jinja2_missing_input_variables() -> None:
input_variables: list = []
with pytest.warns(UserWarning):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
input_variables=input_variables,
template=template,
template_format="jinja2",
validate_template=True,
)
assert PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
).input_variables == ["foo"]
@pytest.mark.requires("jinja2")
@ -240,8 +261,14 @@ def test_prompt_jinja2_extra_input_variables() -> None:
input_variables = ["foo", "bar"]
with pytest.warns(UserWarning):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
input_variables=input_variables,
template=template,
template_format="jinja2",
validate_template=True,
)
assert PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
).input_variables == ["foo"]
@pytest.mark.requires("jinja2")
@ -251,5 +278,11 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
input_variables = ["bar"]
with pytest.warns(UserWarning):
PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
input_variables=input_variables,
template=template,
template_format="jinja2",
validate_template=True,
)
assert PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
).input_variables == ["foo"]

Loading…
Cancel
Save