Make prompt validation opt-in (#11973)

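By default, replace `input_variables` with the variables inferred from the template; the supplied `input_variables` are only validated against the template when `validate_template=True`.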

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change,
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
  - **Tag maintainer:** for a quicker response, tag the relevant
    maintainer (see below),
  - **Twitter handle:** we announce bigger features on Twitter. If your PR
    gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-10-18 16:28:47 +01:00 committed by GitHub
commit 6bd9c1d2b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 192 additions and 34 deletions

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import warnings import warnings
from abc import ABC from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Literal, Set from typing import Any, Callable, Dict, List, Literal, Set
from langchain.schema.messages import BaseMessage, HumanMessage from langchain.schema.messages import BaseMessage, HumanMessage
@ -99,6 +100,20 @@ def check_valid_template(
) )
def get_template_variables(template: str, template_format: str) -> List[str]:
    """Extract the variable names referenced by a template string.

    Args:
        template: The template text to inspect.
        template_format: The template syntax, "f-string" or "jinja2".

    Returns:
        The distinct variable names found in the template, sorted.

    Raises:
        ValueError: If ``template_format`` is not one of the supported formats.
    """
    if template_format == "f-string":
        found = set()
        # Formatter.parse yields (literal_text, field_name, format_spec,
        # conversion); field_name is None for pure-literal segments.
        for _literal, field_name, _spec, _conversion in Formatter().parse(template):
            if field_name is not None:
                found.add(field_name)
        return sorted(found)
    if template_format == "jinja2":
        # Delegate to the jinja2-specific helper for the variable lookup.
        return sorted(_get_jinja2_variables_from_template(template))
    raise ValueError(f"Unsupported template format: {template_format}")
class StringPromptValue(PromptValue): class StringPromptValue(PromptValue):
"""String prompt value.""" """String prompt value."""

View File

@ -382,6 +382,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
"""List of input variables in template messages. Used for validation.""" """List of input variables in template messages. Used for validation."""
messages: List[MessageLike] messages: List[MessageLike]
"""List of messages consisting of either message prompt templates or messages.""" """List of messages consisting of either message prompt templates or messages."""
validate_template: bool = False
"""Whether or not to try validating the template."""
def __add__(self, other: Any) -> ChatPromptTemplate: def __add__(self, other: Any) -> ChatPromptTemplate:
"""Combine two prompt templates. """Combine two prompt templates.
@ -432,7 +434,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
input_types[message.variable_name] = List[AnyMessage] input_types[message.variable_name] = List[AnyMessage]
if "partial_variables" in values: if "partial_variables" in values:
input_vars = input_vars - set(values["partial_variables"]) input_vars = input_vars - set(values["partial_variables"])
if "input_variables" in values: if "input_variables" in values and values.get("validate_template"):
if input_vars != set(values["input_variables"]): if input_vars != set(values["input_variables"]):
raise ValueError( raise ValueError(
"Got mismatched input_variables. " "Got mismatched input_variables. "

View File

@ -2,12 +2,13 @@
from __future__ import annotations from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Literal, Optional, Union
from langchain.prompts.base import ( from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING, DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate, StringPromptTemplate,
check_valid_template, check_valid_template,
get_template_variables,
) )
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector from langchain.prompts.example_selector.base import BaseExampleSelector
@ -77,7 +78,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Return whether or not the class is serializable.""" """Return whether or not the class is serializable."""
return False return False
validate_template: bool = True validate_template: bool = False
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""
input_variables: List[str] input_variables: List[str]
@ -95,7 +96,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
prefix: str = "" prefix: str = ""
"""A prompt template string to put before the examples.""" """A prompt template string to put before the examples."""
template_format: str = "f-string" template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@root_validator() @root_validator()
@ -107,6 +108,14 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
values["template_format"], values["template_format"],
values["input_variables"] + list(values["partial_variables"]), values["input_variables"] + list(values["partial_variables"]),
) )
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["prefix"] + values["suffix"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values return values
class Config: class Config:

View File

@ -37,7 +37,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
template_format: str = "f-string" template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = True validate_template: bool = False
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""
@root_validator(pre=True) @root_validator(pre=True)
@ -72,6 +72,12 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
f"Got input_variables={input_variables}, but based on " f"Got input_variables={input_variables}, but based on "
f"prefix/suffix expected {expected_input_variables}" f"prefix/suffix expected {expected_input_variables}"
) )
else:
values["input_variables"] = sorted(
set(values["suffix"].input_variables)
| set(values["prefix"].input_variables if values["prefix"] else [])
- set(values["partial_variables"])
)
return values return values
class Config: class Config:

View File

@ -2,14 +2,13 @@
from __future__ import annotations from __future__ import annotations
from pathlib import Path from pathlib import Path
from string import Formatter from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Dict, List, Optional, Union
from langchain.prompts.base import ( from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING, DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate, StringPromptTemplate,
_get_jinja2_variables_from_template,
check_valid_template, check_valid_template,
get_template_variables,
) )
from langchain.pydantic_v1 import root_validator from langchain.pydantic_v1 import root_validator
@ -53,10 +52,10 @@ class PromptTemplate(StringPromptTemplate):
template: str template: str
"""The prompt template.""" """The prompt template."""
template_format: str = "f-string" template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = True validate_template: bool = False
"""Whether or not to try validating the template.""" """Whether or not to try validating the template."""
def __add__(self, other: Any) -> PromptTemplate: def __add__(self, other: Any) -> PromptTemplate:
@ -127,6 +126,14 @@ class PromptTemplate(StringPromptTemplate):
check_valid_template( check_valid_template(
values["template"], values["template_format"], all_inputs values["template"], values["template_format"], all_inputs
) )
elif values.get("template_format"):
values["input_variables"] = [
var
for var in get_template_variables(
values["template"], values["template_format"]
)
if var not in values["partial_variables"]
]
return values return values
@classmethod @classmethod
@ -202,25 +209,17 @@ class PromptTemplate(StringPromptTemplate):
Returns: Returns:
The prompt template loaded from the template. The prompt template loaded from the template.
""" """
if template_format == "jinja2":
# Get the variables for the template
input_variables = _get_jinja2_variables_from_template(template)
elif template_format == "f-string":
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
else:
raise ValueError(f"Unsupported template format: {template_format}")
input_variables = get_template_variables(template, template_format)
_partial_variables = partial_variables or {} _partial_variables = partial_variables or {}
if _partial_variables: if _partial_variables:
input_variables = { input_variables = [
var for var in input_variables if var not in _partial_variables var for var in input_variables if var not in _partial_variables
} ]
return cls( return cls(
input_variables=sorted(input_variables), input_variables=input_variables,
template=template, template=template,
template_format=template_format, template_format=template_format,
partial_variables=_partial_variables, partial_variables=_partial_variables,

View File

@ -97,11 +97,11 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
values["llm_chain"] = LLMChain( values["llm_chain"] = LLMChain(
llm=values.get("llm"), llm=values.get("llm"),
prompt=PromptTemplate( prompt=PromptTemplate(
template=QUERY_CHECKER, input_variables=["query", "dialect"] template=QUERY_CHECKER, input_variables=["dialect", "query"]
), ),
) )
if values["llm_chain"].prompt.input_variables != ["query", "dialect"]: if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
raise ValueError( raise ValueError(
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']" "LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
) )

View File

@ -15,6 +15,5 @@
"partial_variables": {}, "partial_variables": {},
"template": "Given the following question and student answer, provide a correct answer and score the student answer.\nQuestion: {question}\nStudent Answer: {student_answer}\nCorrect Answer:", "template": "Given the following question and student answer, provide a correct answer and score the student answer.\nQuestion: {question}\nStudent Answer: {student_answer}\nCorrect Answer:",
"template_format": "f-string", "template_format": "f-string",
"validate_template": true,
"_type": "prompt" "_type": "prompt"
} }

View File

@ -186,13 +186,24 @@ def test_chat_prompt_template_with_messages() -> None:
def test_chat_invalid_input_variables_extra() -> None: def test_chat_invalid_input_variables_extra() -> None:
messages = [HumanMessage(content="foo")] messages = [HumanMessage(content="foo")]
with pytest.raises(ValueError): with pytest.raises(ValueError):
ChatPromptTemplate(messages=messages, input_variables=["foo"]) ChatPromptTemplate(
messages=messages, input_variables=["foo"], validate_template=True
)
assert (
ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables
== []
)
def test_chat_invalid_input_variables_missing() -> None: def test_chat_invalid_input_variables_missing() -> None:
messages = [HumanMessagePromptTemplate.from_template("{foo}")] messages = [HumanMessagePromptTemplate.from_template("{foo}")]
with pytest.raises(ValueError): with pytest.raises(ValueError):
ChatPromptTemplate(messages=messages, input_variables=[]) ChatPromptTemplate(
messages=messages, input_variables=[], validate_template=True
)
assert ChatPromptTemplate(
messages=messages, input_variables=[]
).input_variables == ["foo"]
def test_infer_variables() -> None: def test_infer_variables() -> None:

View File

@ -67,7 +67,14 @@ def test_prompt_missing_input_variables() -> None:
suffix=template, suffix=template,
examples=[], examples=[],
example_prompt=EXAMPLE_PROMPT, example_prompt=EXAMPLE_PROMPT,
validate_template=True,
) )
assert FewShotPromptTemplate(
input_variables=[],
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
# Test when missing in prefix # Test when missing in prefix
template = "This is a {foo} test." template = "This is a {foo} test."
@ -78,7 +85,15 @@ def test_prompt_missing_input_variables() -> None:
examples=[], examples=[],
prefix=template, prefix=template,
example_prompt=EXAMPLE_PROMPT, example_prompt=EXAMPLE_PROMPT,
validate_template=True,
) )
assert FewShotPromptTemplate(
input_variables=[],
suffix="foo",
examples=[],
prefix=template,
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
def test_prompt_extra_input_variables() -> None: def test_prompt_extra_input_variables() -> None:
@ -91,7 +106,14 @@ def test_prompt_extra_input_variables() -> None:
suffix=template, suffix=template,
examples=[], examples=[],
example_prompt=EXAMPLE_PROMPT, example_prompt=EXAMPLE_PROMPT,
validate_template=True,
) )
assert FewShotPromptTemplate(
input_variables=input_variables,
suffix=template,
examples=[],
example_prompt=EXAMPLE_PROMPT,
).input_variables == ["foo"]
def test_few_shot_functionality() -> None: def test_few_shot_functionality() -> None:
@ -248,7 +270,15 @@ def test_prompt_jinja2_missing_input_variables(
examples=example_jinja2_prompt[1], examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0], example_prompt=example_jinja2_prompt[0],
template_format="jinja2", template_format="jinja2",
validate_template=True,
) )
assert FewShotPromptTemplate(
input_variables=[],
suffix=suffix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar"]
# Test when missing in prefix # Test when missing in prefix
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
@ -259,7 +289,16 @@ def test_prompt_jinja2_missing_input_variables(
examples=example_jinja2_prompt[1], examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0], example_prompt=example_jinja2_prompt[0],
template_format="jinja2", template_format="jinja2",
validate_template=True,
) )
assert FewShotPromptTemplate(
input_variables=["bar"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar", "foo"]
@pytest.mark.requires("jinja2") @pytest.mark.requires("jinja2")
@ -277,7 +316,16 @@ def test_prompt_jinja2_extra_input_variables(
examples=example_jinja2_prompt[1], examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0], example_prompt=example_jinja2_prompt[0],
template_format="jinja2", template_format="jinja2",
validate_template=True,
) )
assert FewShotPromptTemplate(
input_variables=["bar", "foo", "extra", "thing"],
suffix=suffix,
prefix=prefix,
examples=example_jinja2_prompt[1],
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
).input_variables == ["bar", "foo"]
def test_few_shot_chat_message_prompt_template() -> None: def test_few_shot_chat_message_prompt_template() -> None:

View File

@ -1,5 +1,7 @@
"""Test few shot prompt template.""" """Test few shot prompt template."""
import pytest
from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
@ -38,3 +40,37 @@ def test_prompttemplate_prefix_suffix() -> None:
"Now you try to talk about party." "Now you try to talk about party."
) )
assert output == expected_output assert output == expected_output
def test_prompttemplate_validation() -> None:
    """Test that template validation is opt-in and variables are inferred.

    With ``validate_template=True`` a mismatch between ``input_variables`` and
    the prefix/suffix variables raises ``ValueError``; without it, the
    ``input_variables`` are replaced by the ones found in prefix and suffix.
    """
    prefix = PromptTemplate(
        input_variables=["content"], template="This is a test about {content}."
    )
    suffix = PromptTemplate(
        input_variables=["new_content"],
        template="Now you try to talk about {new_content}.",
    )

    examples = [
        {"question": "foo", "answer": "bar"},
        {"question": "baz", "answer": "foo"},
    ]
    # Explicit opt-in: the empty input_variables list disagrees with the
    # variables used by prefix/suffix, so construction must fail.
    with pytest.raises(ValueError):
        FewShotPromptWithTemplates(
            suffix=suffix,
            prefix=prefix,
            input_variables=[],
            examples=examples,
            example_prompt=EXAMPLE_PROMPT,
            example_separator="\n",
            validate_template=True,
        )
    # Default (no validation): the same arguments succeed and the variables
    # are inferred from the prefix and suffix templates, in sorted order.
    assert FewShotPromptWithTemplates(
        suffix=suffix,
        prefix=prefix,
        input_variables=[],
        examples=examples,
        example_prompt=EXAMPLE_PROMPT,
        example_separator="\n",
    ).input_variables == ["content", "new_content"]

View File

@ -39,7 +39,12 @@ def test_prompt_missing_input_variables() -> None:
template = "This is a {foo} test." template = "This is a {foo} test."
input_variables: list = [] input_variables: list = []
with pytest.raises(ValueError): with pytest.raises(ValueError):
PromptTemplate(input_variables=input_variables, template=template) PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_extra_input_variables() -> None: def test_prompt_extra_input_variables() -> None:
@ -47,7 +52,12 @@ def test_prompt_extra_input_variables() -> None:
template = "This is a {foo} test." template = "This is a {foo} test."
input_variables = ["foo", "bar"] input_variables = ["foo", "bar"]
with pytest.raises(ValueError): with pytest.raises(ValueError):
PromptTemplate(input_variables=input_variables, template=template) PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_wrong_input_variables() -> None: def test_prompt_wrong_input_variables() -> None:
@ -55,7 +65,12 @@ def test_prompt_wrong_input_variables() -> None:
template = "This is a {foo} test." template = "This is a {foo} test."
input_variables = ["bar"] input_variables = ["bar"]
with pytest.raises(ValueError): with pytest.raises(ValueError):
PromptTemplate(input_variables=input_variables, template=template) PromptTemplate(
input_variables=input_variables, template=template, validate_template=True
)
assert PromptTemplate(
input_variables=input_variables, template=template
).input_variables == ["foo"]
def test_prompt_from_examples_valid() -> None: def test_prompt_from_examples_valid() -> None:
@ -229,8 +244,14 @@ def test_prompt_jinja2_missing_input_variables() -> None:
input_variables: list = [] input_variables: list = []
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
PromptTemplate( PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2" input_variables=input_variables,
template=template,
template_format="jinja2",
validate_template=True,
) )
assert PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
).input_variables == ["foo"]
@pytest.mark.requires("jinja2") @pytest.mark.requires("jinja2")
@ -240,8 +261,14 @@ def test_prompt_jinja2_extra_input_variables() -> None:
input_variables = ["foo", "bar"] input_variables = ["foo", "bar"]
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
PromptTemplate( PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2" input_variables=input_variables,
template=template,
template_format="jinja2",
validate_template=True,
) )
assert PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
).input_variables == ["foo"]
@pytest.mark.requires("jinja2") @pytest.mark.requires("jinja2")
@ -251,5 +278,11 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
input_variables = ["bar"] input_variables = ["bar"]
with pytest.warns(UserWarning): with pytest.warns(UserWarning):
PromptTemplate( PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2" input_variables=input_variables,
template=template,
template_format="jinja2",
validate_template=True,
) )
assert PromptTemplate(
input_variables=input_variables, template=template, template_format="jinja2"
).input_variables == ["foo"]