Add Distance Metrics (Hamming, Indel) for String Distance Evaluation (#8837)

Co-authored-by: shibuiwilliam <shibuiyusuke@gmail.com>
This commit is contained in:
William FH 2023-08-06 14:05:00 -07:00 committed by GitHub
parent f76d50d8dc
commit 983678dedc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 113 additions and 34 deletions

View File

@@ -43,12 +43,16 @@ class StringDistance(str, Enum):
LEVENSHTEIN: The Levenshtein distance. LEVENSHTEIN: The Levenshtein distance.
JARO: The Jaro distance. JARO: The Jaro distance.
JARO_WINKLER: The Jaro-Winkler distance. JARO_WINKLER: The Jaro-Winkler distance.
HAMMING: The Hamming distance.
INDEL: The Indel distance.
""" """
DAMERAU_LEVENSHTEIN = "damerau_levenshtein" DAMERAU_LEVENSHTEIN = "damerau_levenshtein"
LEVENSHTEIN = "levenshtein" LEVENSHTEIN = "levenshtein"
JARO = "jaro" JARO = "jaro"
JARO_WINKLER = "jaro_winkler" JARO_WINKLER = "jaro_winkler"
HAMMING = "hamming"
INDEL = "indel"
class _RapidFuzzChainMixin(Chain): class _RapidFuzzChainMixin(Chain):
@@ -99,7 +103,7 @@ class _RapidFuzzChainMixin(Chain):
return result return result
@staticmethod @staticmethod
def _get_metric(distance: str) -> Callable: def _get_metric(distance: str, normalize_score: bool = False) -> Callable:
""" """
Get the distance metric function based on the distance type. Get the distance metric function based on the distance type.
@@ -112,17 +116,26 @@ class _RapidFuzzChainMixin(Chain):
Raises: Raises:
ValueError: If the distance metric is invalid. ValueError: If the distance metric is invalid.
""" """
rf_distance = _load_rapidfuzz() from rapidfuzz import distance as rf_distance
if distance == StringDistance.DAMERAU_LEVENSHTEIN:
return rf_distance.DamerauLevenshtein.distance module_map: Dict[str, Any] = {
elif distance == StringDistance.LEVENSHTEIN: StringDistance.DAMERAU_LEVENSHTEIN: rf_distance.DamerauLevenshtein,
return rf_distance.Levenshtein.distance StringDistance.LEVENSHTEIN: rf_distance.Levenshtein,
elif distance == StringDistance.JARO: StringDistance.JARO: rf_distance.Jaro,
return rf_distance.Jaro.distance StringDistance.JARO_WINKLER: rf_distance.JaroWinkler,
elif distance == StringDistance.JARO_WINKLER: StringDistance.HAMMING: rf_distance.Hamming,
return rf_distance.JaroWinkler.distance StringDistance.INDEL: rf_distance.Indel,
}
if distance not in module_map:
raise ValueError(
f"Invalid distance metric: {distance}"
f"\nMust be one of: {list(StringDistance)}"
)
module = module_map[distance]
if normalize_score:
return module.normalized_distance
else: else:
raise ValueError(f"Invalid distance metric: {distance}") return module.distance
@property @property
def metric(self) -> Callable: def metric(self) -> Callable:
@@ -132,7 +145,9 @@ class _RapidFuzzChainMixin(Chain):
Returns: Returns:
Callable: The distance metric function. Callable: The distance metric function.
""" """
return _RapidFuzzChainMixin._get_metric(self.distance) return _RapidFuzzChainMixin._get_metric(
self.distance, normalize_score=self.normalize_score
)
def compute_metric(self, a: str, b: str) -> float: def compute_metric(self, a: str, b: str) -> float:
""" """
@@ -145,13 +160,7 @@ class _RapidFuzzChainMixin(Chain):
Returns: Returns:
float: The distance between the two strings. float: The distance between the two strings.
""" """
score = self.metric(a, b) return self.metric(a, b)
if self.normalize_score and self.distance in (
StringDistance.DAMERAU_LEVENSHTEIN,
StringDistance.LEVENSHTEIN,
):
score = score / max(len(a), len(b))
return score
class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin): class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):

View File

@@ -30,8 +30,13 @@ async def test_zero_distance_async(distance: StringDistance) -> None:
@pytest.mark.requires("rapidfuzz") @pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("distance", list(StringDistance)) @pytest.mark.parametrize("distance", list(StringDistance))
def test_zero_distance_pairwise(distance: StringDistance) -> None: @pytest.mark.parametrize("normalize_score", [True, False])
eval_chain = PairwiseStringDistanceEvalChain(distance=distance) def test_zero_distance_pairwise(
distance: StringDistance, normalize_score: bool
) -> None:
eval_chain = PairwiseStringDistanceEvalChain(
distance=distance, normalize_score=normalize_score
)
string = "三人行则必有我师" string = "三人行则必有我师"
result = eval_chain.evaluate_string_pairs(prediction=string, prediction_b=string) result = eval_chain.evaluate_string_pairs(prediction=string, prediction_b=string)
assert "score" in result assert "score" in result
@@ -49,3 +54,60 @@ async def test_zero_distance_pairwise_async(distance: StringDistance) -> None:
) )
assert "score" in result assert "score" in result
assert result["score"] == 0 assert result["score"] == 0
@pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("distance", list(StringDistance))
@pytest.mark.parametrize("normalize_score", [True, False])
def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> None:
eval_chain = StringDistanceEvalChain(
distance=distance, normalize_score=normalize_score
)
prediction = "I like to eat apples."
reference = "I like apples."
result = eval_chain.evaluate_strings(prediction=prediction, reference=reference)
assert "score" in result
assert 0 < result["score"]
if normalize_score:
assert result["score"] < 1.0
@pytest.mark.asyncio
@pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("distance", list(StringDistance))
async def test_non_zero_distance_async(distance: StringDistance) -> None:
eval_chain = StringDistanceEvalChain(distance=distance)
prediction = "I like to eat apples."
reference = "I like apples."
result = await eval_chain.aevaluate_strings(
prediction=prediction, reference=reference
)
assert "score" in result
assert 0 < result["score"] < 1.0
@pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("distance", list(StringDistance))
def test_non_zero_distance_pairwise(distance: StringDistance) -> None:
eval_chain = PairwiseStringDistanceEvalChain(distance=distance)
prediction = "I like to eat apples."
reference = "I like apples."
result = eval_chain.evaluate_string_pairs(
prediction=prediction, prediction_b=reference
)
assert "score" in result
assert 0 < result["score"] < 1.0
@pytest.mark.asyncio
@pytest.mark.requires("rapidfuzz")
@pytest.mark.parametrize("distance", list(StringDistance))
async def test_non_zero_distance_pairwise_async(distance: StringDistance) -> None:
eval_chain = PairwiseStringDistanceEvalChain(distance=distance)
prediction = "I like to eat apples."
reference = "I like apples."
result = await eval_chain.aevaluate_string_pairs(
prediction=prediction, prediction_b=reference
)
assert "score" in result
assert 0 < result["score"] < 1.0

View File

@@ -1,4 +1,5 @@
"""Test the loading function for evaluators.""" """Test the loading function for evaluators."""
from typing import List
import pytest import pytest
@@ -26,24 +27,31 @@ def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"evaluator_type", "evaluator_types",
[ [
EvaluatorType.LABELED_CRITERIA, [EvaluatorType.LABELED_CRITERIA],
EvaluatorType.LABELED_PAIRWISE_STRING, [EvaluatorType.LABELED_PAIRWISE_STRING],
EvaluatorType.QA, [EvaluatorType.QA],
EvaluatorType.CONTEXT_QA, [EvaluatorType.CONTEXT_QA],
EvaluatorType.COT_QA, [EvaluatorType.COT_QA],
[EvaluatorType.COT_QA, EvaluatorType.LABELED_CRITERIA],
[
EvaluatorType.COT_QA,
EvaluatorType.LABELED_CRITERIA,
EvaluatorType.LABELED_PAIRWISE_STRING,
],
], ],
) )
def test_eval_chain_requires_references(evaluator_type: EvaluatorType) -> None: def test_eval_chain_requires_references(evaluator_types: List[EvaluatorType]) -> None:
"""Test loading evaluators.""" """Test loading evaluators."""
fake_llm = FakeLLM( fake_llm = FakeLLM(
queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
) )
evaluator = load_evaluators( evaluators = load_evaluators(
[evaluator_type], evaluator_types,
llm=fake_llm, llm=fake_llm,
)[0] )
if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)): for evaluator in evaluators:
raise ValueError("Evaluator is not a [pairwise]string evaluator") if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)):
assert evaluator.requires_reference raise ValueError("Evaluator is not a [pairwise]string evaluator")
assert evaluator.requires_reference