Add Dist Metrics for String Distance Evaluation (#8837)

Co-authored-by: shibuiwilliam <shibuiyusuke@gmail.com>
This commit is contained in:
William FH 2023-08-06 14:05:00 -07:00 committed by GitHub
parent f76d50d8dc
commit 983678dedc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 34 deletions

View File

@ -43,12 +43,16 @@ class StringDistance(str, Enum):
LEVENSHTEIN: The Levenshtein distance.
JARO: The Jaro distance.
JARO_WINKLER: The Jaro-Winkler distance.
HAMMING: The Hamming distance.
INDEL: The Indel distance.
"""
DAMERAU_LEVENSHTEIN = "damerau_levenshtein"
LEVENSHTEIN = "levenshtein"
JARO = "jaro"
JARO_WINKLER = "jaro_winkler"
HAMMING = "hamming"
INDEL = "indel"
class _RapidFuzzChainMixin(Chain):
@ -99,7 +103,7 @@ class _RapidFuzzChainMixin(Chain):
return result
@staticmethod
def _get_metric(distance: str) -> Callable:
def _get_metric(distance: str, normalize_score: bool = False) -> Callable:
"""
Get the distance metric function based on the distance type.
@ -112,17 +116,26 @@ class _RapidFuzzChainMixin(Chain):
Raises:
ValueError: If the distance metric is invalid.
"""
rf_distance = _load_rapidfuzz()
if distance == StringDistance.DAMERAU_LEVENSHTEIN:
return rf_distance.DamerauLevenshtein.distance
elif distance == StringDistance.LEVENSHTEIN:
return rf_distance.Levenshtein.distance
elif distance == StringDistance.JARO:
return rf_distance.Jaro.distance
elif distance == StringDistance.JARO_WINKLER:
return rf_distance.JaroWinkler.distance
from rapidfuzz import distance as rf_distance
module_map: Dict[str, Any] = {
StringDistance.DAMERAU_LEVENSHTEIN: rf_distance.DamerauLevenshtein,
StringDistance.LEVENSHTEIN: rf_distance.Levenshtein,
StringDistance.JARO: rf_distance.Jaro,
StringDistance.JARO_WINKLER: rf_distance.JaroWinkler,
StringDistance.HAMMING: rf_distance.Hamming,
StringDistance.INDEL: rf_distance.Indel,
}
if distance not in module_map:
raise ValueError(
f"Invalid distance metric: {distance}"
f"\nMust be one of: {list(StringDistance)}"
)
module = module_map[distance]
if normalize_score:
return module.normalized_distance
else:
raise ValueError(f"Invalid distance metric: {distance}")
return module.distance
@property
def metric(self) -> Callable:
@ -132,7 +145,9 @@ class _RapidFuzzChainMixin(Chain):
Returns:
Callable: The distance metric function.
"""
return _RapidFuzzChainMixin._get_metric(self.distance)
return _RapidFuzzChainMixin._get_metric(
self.distance, normalize_score=self.normalize_score
)
def compute_metric(self, a: str, b: str) -> float:
"""
@ -145,13 +160,7 @@ class _RapidFuzzChainMixin(Chain):
Returns:
float: The distance between the two strings.
"""
score = self.metric(a, b)
if self.normalize_score and self.distance in (
StringDistance.DAMERAU_LEVENSHTEIN,
StringDistance.LEVENSHTEIN,
):
score = score / max(len(a), len(b))
return score
return self.metric(a, b)
class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):

View File

@ -30,8 +30,13 @@ async def test_zero_distance_async(distance: StringDistance) -> None:
@pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("distance", list(StringDistance))
def test_zero_distance_pairwise(distance: StringDistance) -> None:
eval_chain = PairwiseStringDistanceEvalChain(distance=distance)
@pytest.mark.parametrize("normalize_score", [True, False])
def test_zero_distance_pairwise(
distance: StringDistance, normalize_score: bool
) -> None:
eval_chain = PairwiseStringDistanceEvalChain(
distance=distance, normalize_score=normalize_score
)
string = "三人行则必有我师"
result = eval_chain.evaluate_string_pairs(prediction=string, prediction_b=string)
assert "score" in result
@ -49,3 +54,60 @@ async def test_zero_distance_pairwise_async(distance: StringDistance) -> None:
)
assert "score" in result
assert result["score"] == 0
@pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("distance", list(StringDistance))
@pytest.mark.parametrize("normalize_score", [True, False])
def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> None:
    """Check that two different strings yield a strictly positive score.

    Runs for every ``StringDistance`` member, with and without score
    normalization. The ``< 1.0`` upper bound is asserted only when
    normalization is requested.
    """
    evaluator = StringDistanceEvalChain(
        distance=distance, normalize_score=normalize_score
    )
    outcome = evaluator.evaluate_strings(
        prediction="I like to eat apples.", reference="I like apples."
    )
    assert "score" in outcome
    score = outcome["score"]
    assert score > 0
    if normalize_score:
        assert score < 1.0
@pytest.mark.asyncio
@pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("distance", list(StringDistance))
@pytest.mark.parametrize("normalize_score", [True, False])
async def test_non_zero_distance_async(
    distance: StringDistance, normalize_score: bool
) -> None:
    """Check that the async evaluator scores two different strings above zero.

    ``normalize_score`` is passed explicitly (both values are exercised) so
    that the ``< 1.0`` upper bound is asserted only when normalization was
    actually requested, instead of relying on the chain's default for that
    field. This mirrors the synchronous ``test_non_zero_distance``.
    """
    eval_chain = StringDistanceEvalChain(
        distance=distance, normalize_score=normalize_score
    )
    prediction = "I like to eat apples."
    reference = "I like apples."
    result = await eval_chain.aevaluate_strings(
        prediction=prediction, reference=reference
    )
    assert "score" in result
    assert 0 < result["score"]
    if normalize_score:
        # The upper bound is only checked for scores requested as normalized.
        assert result["score"] < 1.0
@pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("distance", list(StringDistance))
@pytest.mark.parametrize("normalize_score", [True, False])
def test_non_zero_distance_pairwise(
    distance: StringDistance, normalize_score: bool
) -> None:
    """Check that the pairwise evaluator scores two different strings above zero.

    ``normalize_score`` is passed explicitly (both values are exercised) so
    that the ``< 1.0`` upper bound is asserted only when normalization was
    actually requested, instead of relying on the chain's default for that
    field. This mirrors the synchronous ``test_non_zero_distance``.
    """
    eval_chain = PairwiseStringDistanceEvalChain(
        distance=distance, normalize_score=normalize_score
    )
    prediction = "I like to eat apples."
    reference = "I like apples."
    result = eval_chain.evaluate_string_pairs(
        prediction=prediction, prediction_b=reference
    )
    assert "score" in result
    assert 0 < result["score"]
    if normalize_score:
        # The upper bound is only checked for scores requested as normalized.
        assert result["score"] < 1.0
@pytest.mark.asyncio
@pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("distance", list(StringDistance))
@pytest.mark.parametrize("normalize_score", [True, False])
async def test_non_zero_distance_pairwise_async(
    distance: StringDistance, normalize_score: bool
) -> None:
    """Check that the async pairwise evaluator scores different strings above zero.

    ``normalize_score`` is passed explicitly (both values are exercised) so
    that the ``< 1.0`` upper bound is asserted only when normalization was
    actually requested, instead of relying on the chain's default for that
    field. This mirrors the synchronous ``test_non_zero_distance``.
    """
    eval_chain = PairwiseStringDistanceEvalChain(
        distance=distance, normalize_score=normalize_score
    )
    prediction = "I like to eat apples."
    reference = "I like apples."
    result = await eval_chain.aevaluate_string_pairs(
        prediction=prediction, prediction_b=reference
    )
    assert "score" in result
    assert 0 < result["score"]
    if normalize_score:
        # The upper bound is only checked for scores requested as normalized.
        assert result["score"] < 1.0

View File

@ -1,4 +1,5 @@
"""Test the loading function for evaluators."""
from typing import List
import pytest
@ -26,24 +27,31 @@ def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
@pytest.mark.parametrize(
"evaluator_type",
"evaluator_types",
[
EvaluatorType.LABELED_CRITERIA,
EvaluatorType.LABELED_PAIRWISE_STRING,
EvaluatorType.QA,
EvaluatorType.CONTEXT_QA,
EvaluatorType.COT_QA,
[EvaluatorType.LABELED_CRITERIA],
[EvaluatorType.LABELED_PAIRWISE_STRING],
[EvaluatorType.QA],
[EvaluatorType.CONTEXT_QA],
[EvaluatorType.COT_QA],
[EvaluatorType.COT_QA, EvaluatorType.LABELED_CRITERIA],
[
EvaluatorType.COT_QA,
EvaluatorType.LABELED_CRITERIA,
EvaluatorType.LABELED_PAIRWISE_STRING,
],
],
)
def test_eval_chain_requires_references(evaluator_type: EvaluatorType) -> None:
def test_eval_chain_requires_references(evaluator_types: List[EvaluatorType]) -> None:
"""Test loading evaluators."""
fake_llm = FakeLLM(
queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
)
evaluator = load_evaluators(
[evaluator_type],
evaluators = load_evaluators(
evaluator_types,
llm=fake_llm,
)[0]
if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)):
raise ValueError("Evaluator is not a [pairwise]string evaluator")
assert evaluator.requires_reference
)
for evaluator in evaluators:
if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)):
raise ValueError("Evaluator is not a [pairwise]string evaluator")
assert evaluator.requires_reference