# Source: langchain/tests/unit_tests/prompts/test_few_shot.py
# (266 lines, 7.9 KiB, Python)
# Page-extraction residue, kept as a comment: "Raw Normal View History",
# 2022-11-20 04:32:45 +00:00

"""Test few shot prompt template."""
from typing import Dict, List, Tuple
# (page-extraction residue) 2022-11-20 04:32:45 +00:00
import pytest
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
# Shared "question: answer" formatter passed as `example_prompt` by the
# f-string based tests in this module.
EXAMPLE_PROMPT = PromptTemplate(
    template="{question}: {answer}",
    input_variables=["question", "answer"],
)
@pytest.fixture()
def example_jinja2_prompt() -> Tuple[PromptTemplate, List[Dict[str, str]]]:
    """Provide a jinja2 example prompt together with its example records.

    Returns:
        A ``(prompt, examples)`` pair: a jinja2-format ``PromptTemplate``
        over ``word``/``antonym`` and the list of example dicts to render.
    """
    antonym_pairs = [
        {"word": "happy", "antonym": "sad"},
        {"word": "tall", "antonym": "short"},
    ]
    jinja2_prompt = PromptTemplate(
        input_variables=["word", "antonym"],
        template="{{ word }}: {{ antonym }}",
        template_format="jinja2",
    )
    return jinja2_prompt, antonym_pairs
# (page-extraction residue) 2022-11-20 04:32:45 +00:00
def test_suffix_only() -> None:
    """Check that a prompt with only a suffix and no examples formats."""
    suffix_only_prompt = FewShotPromptTemplate(
        input_variables=["foo"],
        suffix="This is a {foo} test.",
        examples=[],
        example_prompt=EXAMPLE_PROMPT,
    )
    rendered = suffix_only_prompt.format(foo="bar")
    assert rendered == "This is a bar test."
def test_prompt_missing_input_variables() -> None:
    """Check that an undeclared template variable raises ``ValueError``."""
    template_with_var = "This is a {foo} test."
    undeclared_cases = (
        # `foo` is used in the suffix but not declared.
        {"suffix": template_with_var},
        # `foo` is used in the prefix but not declared.
        {"suffix": "foo", "prefix": template_with_var},
    )
    for template_kwargs in undeclared_cases:
        with pytest.raises(ValueError):
            FewShotPromptTemplate(
                input_variables=[],
                examples=[],
                example_prompt=EXAMPLE_PROMPT,
                **template_kwargs,
            )
def test_prompt_extra_input_variables() -> None:
    """Check that declaring a variable the template never uses raises."""
    # `bar` is declared but does not appear in the suffix template.
    with pytest.raises(ValueError):
        FewShotPromptTemplate(
            input_variables=["foo", "bar"],
            suffix="This is a {foo} test.",
            examples=[],
            example_prompt=EXAMPLE_PROMPT,
        )
def test_few_shot_functionality() -> None:
    """Check prefix, rendered examples and suffix are joined in order."""
    qa_examples = [
        {"question": "foo", "answer": "bar"},
        {"question": "baz", "answer": "foo"},
    ]
    few_shot_prompt = FewShotPromptTemplate(
        prefix="This is a test about {content}.",
        suffix="Now you try to talk about {new_content}.",
        input_variables=["content", "new_content"],
        examples=qa_examples,
        example_prompt=EXAMPLE_PROMPT,
        example_separator="\n",
    )
    rendered = few_shot_prompt.format(content="animals", new_content="party")
    assert rendered == (
        "This is a test about animals.\n"
        "foo: bar\n"
        "baz: foo\n"
        "Now you try to talk about party."
    )
# (page-extraction residue) 2023-02-28 16:40:35 +00:00
def test_partial_init_string() -> None:
    """Check a prompt can be built with a string-valued partial variable."""
    qa_examples = [
        {"question": "foo", "answer": "bar"},
        {"question": "baz", "answer": "foo"},
    ]
    few_shot_prompt = FewShotPromptTemplate(
        prefix="This is a test about {content}.",
        suffix="Now you try to talk about {new_content}.",
        input_variables=["new_content"],
        # `content` is pre-filled, so only `new_content` is given to format().
        partial_variables={"content": "animals"},
        examples=qa_examples,
        example_prompt=EXAMPLE_PROMPT,
        example_separator="\n",
    )
    rendered = few_shot_prompt.format(new_content="party")
    assert rendered == (
        "This is a test about animals.\n"
        "foo: bar\n"
        "baz: foo\n"
        "Now you try to talk about party."
    )
def test_partial_init_func() -> None:
    """Check a prompt can be built with a callable partial variable."""
    qa_examples = [
        {"question": "foo", "answer": "bar"},
        {"question": "baz", "answer": "foo"},
    ]
    few_shot_prompt = FewShotPromptTemplate(
        prefix="This is a test about {content}.",
        suffix="Now you try to talk about {new_content}.",
        input_variables=["new_content"],
        # `content` is supplied by a zero-argument callable, not a string.
        partial_variables={"content": lambda: "animals"},
        examples=qa_examples,
        example_prompt=EXAMPLE_PROMPT,
        example_separator="\n",
    )
    rendered = few_shot_prompt.format(new_content="party")
    assert rendered == (
        "This is a test about animals.\n"
        "foo: bar\n"
        "baz: foo\n"
        "Now you try to talk about party."
    )
def test_partial() -> None:
    """Check ``partial`` pre-fills a variable on a new prompt.

    The original prompt is formatted afterwards with both variables to
    confirm it still accepts `content` explicitly.
    """
    qa_examples = [
        {"question": "foo", "answer": "bar"},
        {"question": "baz", "answer": "foo"},
    ]
    base_prompt = FewShotPromptTemplate(
        prefix="This is a test about {content}.",
        suffix="Now you try to talk about {new_content}.",
        input_variables=["content", "new_content"],
        examples=qa_examples,
        example_prompt=EXAMPLE_PROMPT,
        example_separator="\n",
    )

    # The partialed prompt only needs the remaining variable.
    partialed_prompt = base_prompt.partial(content="foo")
    partialed_rendered = partialed_prompt.format(new_content="party")
    assert partialed_rendered == (
        "This is a test about foo.\n"
        "foo: bar\n"
        "baz: foo\n"
        "Now you try to talk about party."
    )

    # The original prompt is formatted with both variables supplied.
    base_rendered = base_prompt.format(new_content="party", content="bar")
    assert base_rendered == (
        "This is a test about bar.\n"
        "foo: bar\n"
        "baz: foo\n"
        "Now you try to talk about party."
    )
def test_prompt_jinja2_functionality(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Check a jinja2-format few-shot prompt renders prefix, examples, suffix."""
    jinja2_example_prompt, jinja2_examples = example_jinja2_prompt
    few_shot_prompt = FewShotPromptTemplate(
        input_variables=["foo", "bar"],
        prefix="Starting with {{ foo }}",
        suffix="Ending with {{ bar }}",
        examples=jinja2_examples,
        example_prompt=jinja2_example_prompt,
        template_format="jinja2",
    )
    rendered = few_shot_prompt.format(foo="hello", bar="bye")
    # No example_separator is passed here; the expected text has the parts
    # separated by a blank line.
    assert rendered == (
        "Starting with hello\n\n"
        "happy: sad\n\n"
        "tall: short\n\n"
        "Ending with bye"
    )
def test_prompt_jinja2_missing_input_variables(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Check a ``UserWarning`` is emitted for undeclared jinja2 variables."""
    jinja2_example_prompt, jinja2_examples = example_jinja2_prompt
    prefix = "Starting with {{ foo }}"
    suffix = "Ending with {{ bar }}"
    incomplete_declarations = (
        # `bar` is used in the suffix but nothing is declared.
        {"input_variables": [], "suffix": suffix},
        # `foo` is used in the prefix but only `bar` is declared.
        {"input_variables": ["bar"], "suffix": suffix, "prefix": prefix},
    )
    for declaration in incomplete_declarations:
        with pytest.warns(UserWarning):
            FewShotPromptTemplate(
                examples=jinja2_examples,
                example_prompt=jinja2_example_prompt,
                template_format="jinja2",
                **declaration,
            )
def test_prompt_jinja2_extra_input_variables(
    example_jinja2_prompt: Tuple[PromptTemplate, List[Dict[str, str]]]
) -> None:
    """Check a ``UserWarning`` is emitted for unused declared variables."""
    jinja2_example_prompt, jinja2_examples = example_jinja2_prompt
    # `extra` and `thing` are declared but appear in neither template.
    with pytest.warns(UserWarning):
        FewShotPromptTemplate(
            input_variables=["bar", "foo", "extra", "thing"],
            prefix="Starting with {{ foo }}",
            suffix="Ending with {{ bar }}",
            examples=jinja2_examples,
            example_prompt=jinja2_example_prompt,
            template_format="jinja2",
        )