# Forked from Archives/langchain (74 lines, 2.6 KiB, Python).
|
"""Test functionality related to ngram overlap based selector."""
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
from langchain.prompts.example_selector.ngram_overlap import (
|
||
|
NGramOverlapExampleSelector,
|
||
|
ngram_overlap_score,
|
||
|
)
|
||
|
from langchain.prompts.prompt import PromptTemplate
|
||
|
|
||
|
# Example records shared by the tests in this module. Each record pairs an
# "input" sentence with an "output" string; the tests rank these records by
# ngram overlap against the sentence "Spot can run.".
EXAMPLES = [
    {"input": sentence, "output": label}
    for sentence, label in (
        ("See Spot run.", "foo1"),
        ("My dog barks.", "foo2"),
        ("Spot can run.", "foo3"),
    )
]
|
||
|
|
||
|
|
||
|
@pytest.fixture
def selector() -> NGramOverlapExampleSelector:
    """Get ngram overlap based selector to use in tests.

    The fixture has the default (function) scope, so every test that requests
    it gets its own freshly built selector.

    ``EXAMPLES`` is handed over as a shallow copy. ``test_selector_add_example``
    calls ``add_example`` on the selector, and the copy guarantees that this
    can never append to the module-level list the other tests index into,
    regardless of whether the selector copies its ``examples`` input itself.
    """
    example_prompt = PromptTemplate(
        input_variables=["input", "output"],
        template="Input: {input}\nOutput: {output}",
    )
    return NGramOverlapExampleSelector(
        # Copy: keep the shared module-level list out of the selector's hands.
        examples=list(EXAMPLES),
        example_prompt=example_prompt,
    )
|
||
|
|
||
|
|
||
|
def test_selector_valid(selector: NGramOverlapExampleSelector) -> None:
    """Test NGramOverlapExampleSelector can select examples.

    The fixture's threshold is left untouched, and all three examples are
    expected back, ordered from the identical sentence ("Spot can run.") through
    the partial match ("See Spot run.") to the one sharing no words
    ("My dog barks.").
    """
    expected = [EXAMPLES[2], EXAMPLES[0], EXAMPLES[1]]
    assert selector.select_examples({"input": "Spot can run."}) == expected
|
||
|
|
||
|
|
||
|
def test_selector_add_example(selector: NGramOverlapExampleSelector) -> None:
    """Test NGramOverlapExampleSelector can add an example.

    The added example shares only the word "Spot" with the query, so it is
    expected to rank after the two closer matches but ahead of the example
    that shares no words with the query.
    """
    added = {"input": "Spot plays fetch.", "output": "foo4"}
    selector.add_example(added)

    result = selector.select_examples({"input": "Spot can run."})

    assert result == [EXAMPLES[2], EXAMPLES[0], added, EXAMPLES[1]]
|
||
|
|
||
|
|
||
|
def test_selector_threshold_zero(selector: NGramOverlapExampleSelector) -> None:
    """Tests NGramOverlapExampleSelector threshold set to 0.0.

    With a zero threshold the example sharing no words with the query
    ("My dog barks.") is expected to be dropped, while the overlapping
    examples are kept in ranked order.
    """
    selector.threshold = 0.0
    ranked = selector.select_examples({"input": "Spot can run."})
    assert ranked == [EXAMPLES[2], EXAMPLES[0]]
|
||
|
|
||
|
|
||
|
def test_selector_threshold_more_than_one(
    selector: NGramOverlapExampleSelector,
) -> None:
    """Tests NGramOverlapExampleSelector threshold greater than 1.0.

    An identical sentence scores 1.0, so a threshold just above 1.0 is
    expected to reject every example and yield an empty list.
    """
    selector.threshold = 1.0 + 1e-9
    assert selector.select_examples({"input": "Spot can run."}) == []
|
||
|
|
||
|
|
||
|
def test_ngram_overlap_score(selector: NGramOverlapExampleSelector) -> None:
    """Tests that ngram_overlap_score returns correct values.

    Covers the three regimes for the sentence "Spot can run.":

    * no shared words ("My dog barks.") -> 0.0
    * partial overlap ("See Spot run.") -> strictly between 0.0 and 1.0
    * identical sentence -> 1.0

    ``ngram_overlap_score`` is a free function, so the ``selector`` fixture is
    not used; the parameter is kept only to leave the test signature unchanged.
    """
    sentence = "Spot can run."

    none = ngram_overlap_score([sentence], ["My dog barks."])
    some = ngram_overlap_score([sentence], ["See Spot run."])
    complete = ngram_overlap_score([sentence], ["Spot can run."])

    # Separate asserts (instead of comparing a list of booleans) so a failure
    # reports which case broke and the actual score that was produced.
    assert none == pytest.approx(0.0, abs=1e-9)
    assert 0.0 < some < 1.0
    assert complete == pytest.approx(1.0, abs=1e-9)
|