mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Make prompt validation opt-in (#11973)
By default replace input_variables with the correct value <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
commit
6bd9c1d2b3
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from string import Formatter
|
||||||
from typing import Any, Callable, Dict, List, Literal, Set
|
from typing import Any, Callable, Dict, List, Literal, Set
|
||||||
|
|
||||||
from langchain.schema.messages import BaseMessage, HumanMessage
|
from langchain.schema.messages import BaseMessage, HumanMessage
|
||||||
@ -99,6 +100,20 @@ def check_valid_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_template_variables(template: str, template_format: str) -> List[str]:
|
||||||
|
if template_format == "jinja2":
|
||||||
|
# Get the variables for the template
|
||||||
|
input_variables = _get_jinja2_variables_from_template(template)
|
||||||
|
elif template_format == "f-string":
|
||||||
|
input_variables = {
|
||||||
|
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported template format: {template_format}")
|
||||||
|
|
||||||
|
return sorted(input_variables)
|
||||||
|
|
||||||
|
|
||||||
class StringPromptValue(PromptValue):
|
class StringPromptValue(PromptValue):
|
||||||
"""String prompt value."""
|
"""String prompt value."""
|
||||||
|
|
||||||
|
@ -382,6 +382,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
"""List of input variables in template messages. Used for validation."""
|
"""List of input variables in template messages. Used for validation."""
|
||||||
messages: List[MessageLike]
|
messages: List[MessageLike]
|
||||||
"""List of messages consisting of either message prompt templates or messages."""
|
"""List of messages consisting of either message prompt templates or messages."""
|
||||||
|
validate_template: bool = False
|
||||||
|
"""Whether or not to try validating the template."""
|
||||||
|
|
||||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||||
"""Combine two prompt templates.
|
"""Combine two prompt templates.
|
||||||
@ -432,7 +434,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
input_types[message.variable_name] = List[AnyMessage]
|
input_types[message.variable_name] = List[AnyMessage]
|
||||||
if "partial_variables" in values:
|
if "partial_variables" in values:
|
||||||
input_vars = input_vars - set(values["partial_variables"])
|
input_vars = input_vars - set(values["partial_variables"])
|
||||||
if "input_variables" in values:
|
if "input_variables" in values and values.get("validate_template"):
|
||||||
if input_vars != set(values["input_variables"]):
|
if input_vars != set(values["input_variables"]):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Got mismatched input_variables. "
|
"Got mismatched input_variables. "
|
||||||
|
@ -2,12 +2,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from langchain.prompts.base import (
|
from langchain.prompts.base import (
|
||||||
DEFAULT_FORMATTER_MAPPING,
|
DEFAULT_FORMATTER_MAPPING,
|
||||||
StringPromptTemplate,
|
StringPromptTemplate,
|
||||||
check_valid_template,
|
check_valid_template,
|
||||||
|
get_template_variables,
|
||||||
)
|
)
|
||||||
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
|
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
|
||||||
from langchain.prompts.example_selector.base import BaseExampleSelector
|
from langchain.prompts.example_selector.base import BaseExampleSelector
|
||||||
@ -77,7 +78,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
|||||||
"""Return whether or not the class is serializable."""
|
"""Return whether or not the class is serializable."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
validate_template: bool = True
|
validate_template: bool = False
|
||||||
"""Whether or not to try validating the template."""
|
"""Whether or not to try validating the template."""
|
||||||
|
|
||||||
input_variables: List[str]
|
input_variables: List[str]
|
||||||
@ -95,7 +96,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
|||||||
prefix: str = ""
|
prefix: str = ""
|
||||||
"""A prompt template string to put before the examples."""
|
"""A prompt template string to put before the examples."""
|
||||||
|
|
||||||
template_format: str = "f-string"
|
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
|
||||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
@ -107,6 +108,14 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
|||||||
values["template_format"],
|
values["template_format"],
|
||||||
values["input_variables"] + list(values["partial_variables"]),
|
values["input_variables"] + list(values["partial_variables"]),
|
||||||
)
|
)
|
||||||
|
elif values.get("template_format"):
|
||||||
|
values["input_variables"] = [
|
||||||
|
var
|
||||||
|
for var in get_template_variables(
|
||||||
|
values["prefix"] + values["suffix"], values["template_format"]
|
||||||
|
)
|
||||||
|
if var not in values["partial_variables"]
|
||||||
|
]
|
||||||
return values
|
return values
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -37,7 +37,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
template_format: str = "f-string"
|
template_format: str = "f-string"
|
||||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||||
|
|
||||||
validate_template: bool = True
|
validate_template: bool = False
|
||||||
"""Whether or not to try validating the template."""
|
"""Whether or not to try validating the template."""
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
@ -72,6 +72,12 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
|||||||
f"Got input_variables={input_variables}, but based on "
|
f"Got input_variables={input_variables}, but based on "
|
||||||
f"prefix/suffix expected {expected_input_variables}"
|
f"prefix/suffix expected {expected_input_variables}"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
values["input_variables"] = sorted(
|
||||||
|
set(values["suffix"].input_variables)
|
||||||
|
| set(values["prefix"].input_variables if values["prefix"] else [])
|
||||||
|
- set(values["partial_variables"])
|
||||||
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -2,14 +2,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from string import Formatter
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
from langchain.prompts.base import (
|
from langchain.prompts.base import (
|
||||||
DEFAULT_FORMATTER_MAPPING,
|
DEFAULT_FORMATTER_MAPPING,
|
||||||
StringPromptTemplate,
|
StringPromptTemplate,
|
||||||
_get_jinja2_variables_from_template,
|
|
||||||
check_valid_template,
|
check_valid_template,
|
||||||
|
get_template_variables,
|
||||||
)
|
)
|
||||||
from langchain.pydantic_v1 import root_validator
|
from langchain.pydantic_v1 import root_validator
|
||||||
|
|
||||||
@ -53,10 +52,10 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
template: str
|
template: str
|
||||||
"""The prompt template."""
|
"""The prompt template."""
|
||||||
|
|
||||||
template_format: str = "f-string"
|
template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string"
|
||||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||||
|
|
||||||
validate_template: bool = True
|
validate_template: bool = False
|
||||||
"""Whether or not to try validating the template."""
|
"""Whether or not to try validating the template."""
|
||||||
|
|
||||||
def __add__(self, other: Any) -> PromptTemplate:
|
def __add__(self, other: Any) -> PromptTemplate:
|
||||||
@ -127,6 +126,14 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
check_valid_template(
|
check_valid_template(
|
||||||
values["template"], values["template_format"], all_inputs
|
values["template"], values["template_format"], all_inputs
|
||||||
)
|
)
|
||||||
|
elif values.get("template_format"):
|
||||||
|
values["input_variables"] = [
|
||||||
|
var
|
||||||
|
for var in get_template_variables(
|
||||||
|
values["template"], values["template_format"]
|
||||||
|
)
|
||||||
|
if var not in values["partial_variables"]
|
||||||
|
]
|
||||||
return values
|
return values
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -202,25 +209,17 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
Returns:
|
Returns:
|
||||||
The prompt template loaded from the template.
|
The prompt template loaded from the template.
|
||||||
"""
|
"""
|
||||||
if template_format == "jinja2":
|
|
||||||
# Get the variables for the template
|
|
||||||
input_variables = _get_jinja2_variables_from_template(template)
|
|
||||||
elif template_format == "f-string":
|
|
||||||
input_variables = {
|
|
||||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported template format: {template_format}")
|
|
||||||
|
|
||||||
|
input_variables = get_template_variables(template, template_format)
|
||||||
_partial_variables = partial_variables or {}
|
_partial_variables = partial_variables or {}
|
||||||
|
|
||||||
if _partial_variables:
|
if _partial_variables:
|
||||||
input_variables = {
|
input_variables = [
|
||||||
var for var in input_variables if var not in _partial_variables
|
var for var in input_variables if var not in _partial_variables
|
||||||
}
|
]
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
input_variables=sorted(input_variables),
|
input_variables=input_variables,
|
||||||
template=template,
|
template=template,
|
||||||
template_format=template_format,
|
template_format=template_format,
|
||||||
partial_variables=_partial_variables,
|
partial_variables=_partial_variables,
|
||||||
|
@ -97,11 +97,11 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
|||||||
values["llm_chain"] = LLMChain(
|
values["llm_chain"] = LLMChain(
|
||||||
llm=values.get("llm"),
|
llm=values.get("llm"),
|
||||||
prompt=PromptTemplate(
|
prompt=PromptTemplate(
|
||||||
template=QUERY_CHECKER, input_variables=["query", "dialect"]
|
template=QUERY_CHECKER, input_variables=["dialect", "query"]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if values["llm_chain"].prompt.input_variables != ["query", "dialect"]:
|
if values["llm_chain"].prompt.input_variables != ["dialect", "query"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
"LLM chain for QueryCheckerTool must have input variables ['query', 'dialect']"
|
||||||
)
|
)
|
||||||
|
@ -15,6 +15,5 @@
|
|||||||
"partial_variables": {},
|
"partial_variables": {},
|
||||||
"template": "Given the following question and student answer, provide a correct answer and score the student answer.\nQuestion: {question}\nStudent Answer: {student_answer}\nCorrect Answer:",
|
"template": "Given the following question and student answer, provide a correct answer and score the student answer.\nQuestion: {question}\nStudent Answer: {student_answer}\nCorrect Answer:",
|
||||||
"template_format": "f-string",
|
"template_format": "f-string",
|
||||||
"validate_template": true,
|
|
||||||
"_type": "prompt"
|
"_type": "prompt"
|
||||||
}
|
}
|
@ -186,13 +186,24 @@ def test_chat_prompt_template_with_messages() -> None:
|
|||||||
def test_chat_invalid_input_variables_extra() -> None:
|
def test_chat_invalid_input_variables_extra() -> None:
|
||||||
messages = [HumanMessage(content="foo")]
|
messages = [HumanMessage(content="foo")]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
ChatPromptTemplate(messages=messages, input_variables=["foo"])
|
ChatPromptTemplate(
|
||||||
|
messages=messages, input_variables=["foo"], validate_template=True
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
ChatPromptTemplate(messages=messages, input_variables=["foo"]).input_variables
|
||||||
|
== []
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_chat_invalid_input_variables_missing() -> None:
|
def test_chat_invalid_input_variables_missing() -> None:
|
||||||
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
|
messages = [HumanMessagePromptTemplate.from_template("{foo}")]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
ChatPromptTemplate(messages=messages, input_variables=[])
|
ChatPromptTemplate(
|
||||||
|
messages=messages, input_variables=[], validate_template=True
|
||||||
|
)
|
||||||
|
assert ChatPromptTemplate(
|
||||||
|
messages=messages, input_variables=[]
|
||||||
|
).input_variables == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
def test_infer_variables() -> None:
|
def test_infer_variables() -> None:
|
||||||
|
@ -67,7 +67,14 @@ def test_prompt_missing_input_variables() -> None:
|
|||||||
suffix=template,
|
suffix=template,
|
||||||
examples=[],
|
examples=[],
|
||||||
example_prompt=EXAMPLE_PROMPT,
|
example_prompt=EXAMPLE_PROMPT,
|
||||||
|
validate_template=True,
|
||||||
)
|
)
|
||||||
|
assert FewShotPromptTemplate(
|
||||||
|
input_variables=[],
|
||||||
|
suffix=template,
|
||||||
|
examples=[],
|
||||||
|
example_prompt=EXAMPLE_PROMPT,
|
||||||
|
).input_variables == ["foo"]
|
||||||
|
|
||||||
# Test when missing in prefix
|
# Test when missing in prefix
|
||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
@ -78,7 +85,15 @@ def test_prompt_missing_input_variables() -> None:
|
|||||||
examples=[],
|
examples=[],
|
||||||
prefix=template,
|
prefix=template,
|
||||||
example_prompt=EXAMPLE_PROMPT,
|
example_prompt=EXAMPLE_PROMPT,
|
||||||
|
validate_template=True,
|
||||||
)
|
)
|
||||||
|
assert FewShotPromptTemplate(
|
||||||
|
input_variables=[],
|
||||||
|
suffix="foo",
|
||||||
|
examples=[],
|
||||||
|
prefix=template,
|
||||||
|
example_prompt=EXAMPLE_PROMPT,
|
||||||
|
).input_variables == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_extra_input_variables() -> None:
|
def test_prompt_extra_input_variables() -> None:
|
||||||
@ -91,7 +106,14 @@ def test_prompt_extra_input_variables() -> None:
|
|||||||
suffix=template,
|
suffix=template,
|
||||||
examples=[],
|
examples=[],
|
||||||
example_prompt=EXAMPLE_PROMPT,
|
example_prompt=EXAMPLE_PROMPT,
|
||||||
|
validate_template=True,
|
||||||
)
|
)
|
||||||
|
assert FewShotPromptTemplate(
|
||||||
|
input_variables=input_variables,
|
||||||
|
suffix=template,
|
||||||
|
examples=[],
|
||||||
|
example_prompt=EXAMPLE_PROMPT,
|
||||||
|
).input_variables == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
def test_few_shot_functionality() -> None:
|
def test_few_shot_functionality() -> None:
|
||||||
@ -248,7 +270,15 @@ def test_prompt_jinja2_missing_input_variables(
|
|||||||
examples=example_jinja2_prompt[1],
|
examples=example_jinja2_prompt[1],
|
||||||
example_prompt=example_jinja2_prompt[0],
|
example_prompt=example_jinja2_prompt[0],
|
||||||
template_format="jinja2",
|
template_format="jinja2",
|
||||||
|
validate_template=True,
|
||||||
)
|
)
|
||||||
|
assert FewShotPromptTemplate(
|
||||||
|
input_variables=[],
|
||||||
|
suffix=suffix,
|
||||||
|
examples=example_jinja2_prompt[1],
|
||||||
|
example_prompt=example_jinja2_prompt[0],
|
||||||
|
template_format="jinja2",
|
||||||
|
).input_variables == ["bar"]
|
||||||
|
|
||||||
# Test when missing in prefix
|
# Test when missing in prefix
|
||||||
with pytest.warns(UserWarning):
|
with pytest.warns(UserWarning):
|
||||||
@ -259,7 +289,16 @@ def test_prompt_jinja2_missing_input_variables(
|
|||||||
examples=example_jinja2_prompt[1],
|
examples=example_jinja2_prompt[1],
|
||||||
example_prompt=example_jinja2_prompt[0],
|
example_prompt=example_jinja2_prompt[0],
|
||||||
template_format="jinja2",
|
template_format="jinja2",
|
||||||
|
validate_template=True,
|
||||||
)
|
)
|
||||||
|
assert FewShotPromptTemplate(
|
||||||
|
input_variables=["bar"],
|
||||||
|
suffix=suffix,
|
||||||
|
prefix=prefix,
|
||||||
|
examples=example_jinja2_prompt[1],
|
||||||
|
example_prompt=example_jinja2_prompt[0],
|
||||||
|
template_format="jinja2",
|
||||||
|
).input_variables == ["bar", "foo"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("jinja2")
|
@pytest.mark.requires("jinja2")
|
||||||
@ -277,7 +316,16 @@ def test_prompt_jinja2_extra_input_variables(
|
|||||||
examples=example_jinja2_prompt[1],
|
examples=example_jinja2_prompt[1],
|
||||||
example_prompt=example_jinja2_prompt[0],
|
example_prompt=example_jinja2_prompt[0],
|
||||||
template_format="jinja2",
|
template_format="jinja2",
|
||||||
|
validate_template=True,
|
||||||
)
|
)
|
||||||
|
assert FewShotPromptTemplate(
|
||||||
|
input_variables=["bar", "foo", "extra", "thing"],
|
||||||
|
suffix=suffix,
|
||||||
|
prefix=prefix,
|
||||||
|
examples=example_jinja2_prompt[1],
|
||||||
|
example_prompt=example_jinja2_prompt[0],
|
||||||
|
template_format="jinja2",
|
||||||
|
).input_variables == ["bar", "foo"]
|
||||||
|
|
||||||
|
|
||||||
def test_few_shot_chat_message_prompt_template() -> None:
|
def test_few_shot_chat_message_prompt_template() -> None:
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Test few shot prompt template."""
|
"""Test few shot prompt template."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
|
from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
|
||||||
@ -38,3 +40,37 @@ def test_prompttemplate_prefix_suffix() -> None:
|
|||||||
"Now you try to talk about party."
|
"Now you try to talk about party."
|
||||||
)
|
)
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompttemplate_validation() -> None:
|
||||||
|
"""Test that few shot works when prefix and suffix are PromptTemplates."""
|
||||||
|
prefix = PromptTemplate(
|
||||||
|
input_variables=["content"], template="This is a test about {content}."
|
||||||
|
)
|
||||||
|
suffix = PromptTemplate(
|
||||||
|
input_variables=["new_content"],
|
||||||
|
template="Now you try to talk about {new_content}.",
|
||||||
|
)
|
||||||
|
|
||||||
|
examples = [
|
||||||
|
{"question": "foo", "answer": "bar"},
|
||||||
|
{"question": "baz", "answer": "foo"},
|
||||||
|
]
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
FewShotPromptWithTemplates(
|
||||||
|
suffix=suffix,
|
||||||
|
prefix=prefix,
|
||||||
|
input_variables=[],
|
||||||
|
examples=examples,
|
||||||
|
example_prompt=EXAMPLE_PROMPT,
|
||||||
|
example_separator="\n",
|
||||||
|
validate_template=True,
|
||||||
|
)
|
||||||
|
assert FewShotPromptWithTemplates(
|
||||||
|
suffix=suffix,
|
||||||
|
prefix=prefix,
|
||||||
|
input_variables=[],
|
||||||
|
examples=examples,
|
||||||
|
example_prompt=EXAMPLE_PROMPT,
|
||||||
|
example_separator="\n",
|
||||||
|
).input_variables == ["content", "new_content"]
|
||||||
|
@ -39,7 +39,12 @@ def test_prompt_missing_input_variables() -> None:
|
|||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
input_variables: list = []
|
input_variables: list = []
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
PromptTemplate(input_variables=input_variables, template=template)
|
PromptTemplate(
|
||||||
|
input_variables=input_variables, template=template, validate_template=True
|
||||||
|
)
|
||||||
|
assert PromptTemplate(
|
||||||
|
input_variables=input_variables, template=template
|
||||||
|
).input_variables == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_extra_input_variables() -> None:
|
def test_prompt_extra_input_variables() -> None:
|
||||||
@ -47,7 +52,12 @@ def test_prompt_extra_input_variables() -> None:
|
|||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
input_variables = ["foo", "bar"]
|
input_variables = ["foo", "bar"]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
PromptTemplate(input_variables=input_variables, template=template)
|
PromptTemplate(
|
||||||
|
input_variables=input_variables, template=template, validate_template=True
|
||||||
|
)
|
||||||
|
assert PromptTemplate(
|
||||||
|
input_variables=input_variables, template=template
|
||||||
|
).input_variables == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_wrong_input_variables() -> None:
|
def test_prompt_wrong_input_variables() -> None:
|
||||||
@ -55,7 +65,12 @@ def test_prompt_wrong_input_variables() -> None:
|
|||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
input_variables = ["bar"]
|
input_variables = ["bar"]
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
PromptTemplate(input_variables=input_variables, template=template)
|
PromptTemplate(
|
||||||
|
input_variables=input_variables, template=template, validate_template=True
|
||||||
|
)
|
||||||
|
assert PromptTemplate(
|
||||||
|
input_variables=input_variables, template=template
|
||||||
|
).input_variables == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_from_examples_valid() -> None:
|
def test_prompt_from_examples_valid() -> None:
|
||||||
@ -229,8 +244,14 @@ def test_prompt_jinja2_missing_input_variables() -> None:
|
|||||||
input_variables: list = []
|
input_variables: list = []
|
||||||
with pytest.warns(UserWarning):
|
with pytest.warns(UserWarning):
|
||||||
PromptTemplate(
|
PromptTemplate(
|
||||||
input_variables=input_variables, template=template, template_format="jinja2"
|
input_variables=input_variables,
|
||||||
|
template=template,
|
||||||
|
template_format="jinja2",
|
||||||
|
validate_template=True,
|
||||||
)
|
)
|
||||||
|
assert PromptTemplate(
|
||||||
|
input_variables=input_variables, template=template, template_format="jinja2"
|
||||||
|
).input_variables == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("jinja2")
|
@pytest.mark.requires("jinja2")
|
||||||
@ -240,8 +261,14 @@ def test_prompt_jinja2_extra_input_variables() -> None:
|
|||||||
input_variables = ["foo", "bar"]
|
input_variables = ["foo", "bar"]
|
||||||
with pytest.warns(UserWarning):
|
with pytest.warns(UserWarning):
|
||||||
PromptTemplate(
|
PromptTemplate(
|
||||||
input_variables=input_variables, template=template, template_format="jinja2"
|
input_variables=input_variables,
|
||||||
|
template=template,
|
||||||
|
template_format="jinja2",
|
||||||
|
validate_template=True,
|
||||||
)
|
)
|
||||||
|
assert PromptTemplate(
|
||||||
|
input_variables=input_variables, template=template, template_format="jinja2"
|
||||||
|
).input_variables == ["foo"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("jinja2")
|
@pytest.mark.requires("jinja2")
|
||||||
@ -251,5 +278,11 @@ def test_prompt_jinja2_wrong_input_variables() -> None:
|
|||||||
input_variables = ["bar"]
|
input_variables = ["bar"]
|
||||||
with pytest.warns(UserWarning):
|
with pytest.warns(UserWarning):
|
||||||
PromptTemplate(
|
PromptTemplate(
|
||||||
input_variables=input_variables, template=template, template_format="jinja2"
|
input_variables=input_variables,
|
||||||
|
template=template,
|
||||||
|
template_format="jinja2",
|
||||||
|
validate_template=True,
|
||||||
)
|
)
|
||||||
|
assert PromptTemplate(
|
||||||
|
input_variables=input_variables, template=template, template_format="jinja2"
|
||||||
|
).input_variables == ["foo"]
|
||||||
|
Loading…
Reference in New Issue
Block a user