You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/tests/integration_tests/prompts/test_ngram_overlap_example_...

74 lines
2.6 KiB
Python

"""Test functionality related to ngram overlap based selector."""
import pytest
from langchain.prompts.example_selector.ngram_overlap import (
NGramOverlapExampleSelector,
ngram_overlap_score,
)
from langchain.prompts.prompt import PromptTemplate
# Example pool shared by the tests: each entry maps an input sentence to a
# distinct output label so results can be compared by position.
EXAMPLES = [
    {"input": sentence, "output": label}
    for sentence, label in (
        ("See Spot run.", "foo1"),
        ("My dog barks.", "foo2"),
        ("Spot can run.", "foo3"),
    )
]
@pytest.fixture
def selector() -> NGramOverlapExampleSelector:
    """Build the ngram overlap selector used by the tests in this module.

    Returns:
        A selector constructed from ``EXAMPLES`` and a prompt template that
        renders the ``input`` and ``output`` fields of an example.
    """
    example_prompt = PromptTemplate(
        template="Input: {input}\nOutput: {output}",
        input_variables=["input", "output"],
    )
    return NGramOverlapExampleSelector(
        example_prompt=example_prompt,
        examples=EXAMPLES,
    )
def test_selector_valid(selector: NGramOverlapExampleSelector) -> None:
    """Check the order in which the selector returns the stock examples."""
    query = {"input": "Spot can run."}
    # The identical sentence is expected first and "My dog barks." last.
    expected = [EXAMPLES[2], EXAMPLES[0], EXAMPLES[1]]
    assert selector.select_examples(query) == expected
def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None:
    """Check that an example added at runtime is included in selections."""
    extra = {"input": "Spot plays fetch.", "output": "foo4"}
    selector.add_example(extra)
    selected = selector.select_examples({"input": "Spot can run."})
    # The added example is expected between EXAMPLES[0] and EXAMPLES[1].
    assert selected == [EXAMPLES[2], EXAMPLES[0], extra, EXAMPLES[1]]
def test_selector_threshold_zero(selector: NGramOverlapExampleSelector) -> None:
    """Check selection when the selector's threshold is set to 0.0."""
    selector.threshold = 0.0
    selected = selector.select_examples({"input": "Spot can run."})
    # With this threshold EXAMPLES[1] is expected to be left out.
    assert selected == [EXAMPLES[2], EXAMPLES[0]]
def test_selector_threshold_more_than_one(
    selector: NGramOverlapExampleSelector,
) -> None:
    """Check that a threshold just above 1.0 yields no examples."""
    selector.threshold = 1.0 + 1e-9
    selected = selector.select_examples({"input": "Spot can run."})
    assert selected == []
def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None:
    """Tests that ngram_overlap_score returns correct values.

    Scores the sentence "Spot can run." against three candidates and checks
    each result with its own assertion, so a failure identifies the case
    that broke and reports the offending score.

    Args:
        selector: Selector fixture, kept so the test signature is unchanged.
            The score function is called directly with explicit sentence
            lists and is never handed the selector, so the fixture is not
            configured here.
    """
    sentence = "Spot can run."

    none = ngram_overlap_score([sentence], ["My dog barks."])
    some = ngram_overlap_score([sentence], ["See Spot run."])
    complete = ngram_overlap_score([sentence], ["Spot can run."])

    # The boundary scores are compared with a 1e-9 tolerance rather than
    # exact float equality.
    assert abs(none - 0.0) < 1e-9, "expected a score of 0.0 for no overlap"
    assert 0.0 < some < 1.0, "expected a score strictly between 0.0 and 1.0"
    assert abs(complete - 1.0) < 1e-9, "expected a score of 1.0 for a match"