mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add Dist Metrics for String Distance Evaluation (#8837)
Co-authored-by: shibuiwilliam <shibuiyusuke@gmail.com>
This commit is contained in:
parent
f76d50d8dc
commit
983678dedc
@ -43,12 +43,16 @@ class StringDistance(str, Enum):
|
||||
LEVENSHTEIN: The Levenshtein distance.
|
||||
JARO: The Jaro distance.
|
||||
JARO_WINKLER: The Jaro-Winkler distance.
|
||||
HAMMING: The Hamming distance.
|
||||
INDEL: The Indel distance.
|
||||
"""
|
||||
|
||||
DAMERAU_LEVENSHTEIN = "damerau_levenshtein"
|
||||
LEVENSHTEIN = "levenshtein"
|
||||
JARO = "jaro"
|
||||
JARO_WINKLER = "jaro_winkler"
|
||||
HAMMING = "hamming"
|
||||
INDEL = "indel"
|
||||
|
||||
|
||||
class _RapidFuzzChainMixin(Chain):
|
||||
@ -99,7 +103,7 @@ class _RapidFuzzChainMixin(Chain):
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_metric(distance: str) -> Callable:
|
||||
def _get_metric(distance: str, normalize_score: bool = False) -> Callable:
|
||||
"""
|
||||
Get the distance metric function based on the distance type.
|
||||
|
||||
@ -112,17 +116,26 @@ class _RapidFuzzChainMixin(Chain):
|
||||
Raises:
|
||||
ValueError: If the distance metric is invalid.
|
||||
"""
|
||||
rf_distance = _load_rapidfuzz()
|
||||
if distance == StringDistance.DAMERAU_LEVENSHTEIN:
|
||||
return rf_distance.DamerauLevenshtein.distance
|
||||
elif distance == StringDistance.LEVENSHTEIN:
|
||||
return rf_distance.Levenshtein.distance
|
||||
elif distance == StringDistance.JARO:
|
||||
return rf_distance.Jaro.distance
|
||||
elif distance == StringDistance.JARO_WINKLER:
|
||||
return rf_distance.JaroWinkler.distance
|
||||
from rapidfuzz import distance as rf_distance
|
||||
|
||||
module_map: Dict[str, Any] = {
|
||||
StringDistance.DAMERAU_LEVENSHTEIN: rf_distance.DamerauLevenshtein,
|
||||
StringDistance.LEVENSHTEIN: rf_distance.Levenshtein,
|
||||
StringDistance.JARO: rf_distance.Jaro,
|
||||
StringDistance.JARO_WINKLER: rf_distance.JaroWinkler,
|
||||
StringDistance.HAMMING: rf_distance.Hamming,
|
||||
StringDistance.INDEL: rf_distance.Indel,
|
||||
}
|
||||
if distance not in module_map:
|
||||
raise ValueError(
|
||||
f"Invalid distance metric: {distance}"
|
||||
f"\nMust be one of: {list(StringDistance)}"
|
||||
)
|
||||
module = module_map[distance]
|
||||
if normalize_score:
|
||||
return module.normalized_distance
|
||||
else:
|
||||
raise ValueError(f"Invalid distance metric: {distance}")
|
||||
return module.distance
|
||||
|
||||
@property
|
||||
def metric(self) -> Callable:
|
||||
@ -132,7 +145,9 @@ class _RapidFuzzChainMixin(Chain):
|
||||
Returns:
|
||||
Callable: The distance metric function.
|
||||
"""
|
||||
return _RapidFuzzChainMixin._get_metric(self.distance)
|
||||
return _RapidFuzzChainMixin._get_metric(
|
||||
self.distance, normalize_score=self.normalize_score
|
||||
)
|
||||
|
||||
def compute_metric(self, a: str, b: str) -> float:
|
||||
"""
|
||||
@ -145,13 +160,7 @@ class _RapidFuzzChainMixin(Chain):
|
||||
Returns:
|
||||
float: The distance between the two strings.
|
||||
"""
|
||||
score = self.metric(a, b)
|
||||
if self.normalize_score and self.distance in (
|
||||
StringDistance.DAMERAU_LEVENSHTEIN,
|
||||
StringDistance.LEVENSHTEIN,
|
||||
):
|
||||
score = score / max(len(a), len(b))
|
||||
return score
|
||||
return self.metric(a, b)
|
||||
|
||||
|
||||
class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
|
||||
|
@ -30,8 +30,13 @@ async def test_zero_distance_async(distance: StringDistance) -> None:
|
||||
|
||||
@pytest.mark.requires("rapidfuzz")
|
||||
@pytest.mark.parametrize("distance", list(StringDistance))
|
||||
def test_zero_distance_pairwise(distance: StringDistance) -> None:
|
||||
eval_chain = PairwiseStringDistanceEvalChain(distance=distance)
|
||||
@pytest.mark.parametrize("normalize_score", [True, False])
|
||||
def test_zero_distance_pairwise(
|
||||
distance: StringDistance, normalize_score: bool
|
||||
) -> None:
|
||||
eval_chain = PairwiseStringDistanceEvalChain(
|
||||
distance=distance, normalize_score=normalize_score
|
||||
)
|
||||
string = "三人行则必有我师"
|
||||
result = eval_chain.evaluate_string_pairs(prediction=string, prediction_b=string)
|
||||
assert "score" in result
|
||||
@ -49,3 +54,60 @@ async def test_zero_distance_pairwise_async(distance: StringDistance) -> None:
|
||||
)
|
||||
assert "score" in result
|
||||
assert result["score"] == 0
|
||||
|
||||
|
||||
@pytest.mark.requires("rapidfuzz")
|
||||
@pytest.mark.parametrize("distance", list(StringDistance))
|
||||
@pytest.mark.parametrize("normalize_score", [True, False])
|
||||
def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> None:
|
||||
eval_chain = StringDistanceEvalChain(
|
||||
distance=distance, normalize_score=normalize_score
|
||||
)
|
||||
prediction = "I like to eat apples."
|
||||
reference = "I like apples."
|
||||
result = eval_chain.evaluate_strings(prediction=prediction, reference=reference)
|
||||
assert "score" in result
|
||||
assert 0 < result["score"]
|
||||
if normalize_score:
|
||||
assert result["score"] < 1.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("rapidfuzz")
|
||||
@pytest.mark.parametrize("distance", list(StringDistance))
|
||||
async def test_non_zero_distance_async(distance: StringDistance) -> None:
|
||||
eval_chain = StringDistanceEvalChain(distance=distance)
|
||||
prediction = "I like to eat apples."
|
||||
reference = "I like apples."
|
||||
result = await eval_chain.aevaluate_strings(
|
||||
prediction=prediction, reference=reference
|
||||
)
|
||||
assert "score" in result
|
||||
assert 0 < result["score"] < 1.0
|
||||
|
||||
|
||||
@pytest.mark.requires("rapidfuzz")
|
||||
@pytest.mark.parametrize("distance", list(StringDistance))
|
||||
def test_non_zero_distance_pairwise(distance: StringDistance) -> None:
|
||||
eval_chain = PairwiseStringDistanceEvalChain(distance=distance)
|
||||
prediction = "I like to eat apples."
|
||||
reference = "I like apples."
|
||||
result = eval_chain.evaluate_string_pairs(
|
||||
prediction=prediction, prediction_b=reference
|
||||
)
|
||||
assert "score" in result
|
||||
assert 0 < result["score"] < 1.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.requires("rapidfuzz")
|
||||
@pytest.mark.parametrize("distance", list(StringDistance))
|
||||
async def test_non_zero_distance_pairwise_async(distance: StringDistance) -> None:
|
||||
eval_chain = PairwiseStringDistanceEvalChain(distance=distance)
|
||||
prediction = "I like to eat apples."
|
||||
reference = "I like apples."
|
||||
result = await eval_chain.aevaluate_string_pairs(
|
||||
prediction=prediction, prediction_b=reference
|
||||
)
|
||||
assert "score" in result
|
||||
assert 0 < result["score"] < 1.0
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Test the loading function for evaluators."""
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
@ -26,24 +27,31 @@ def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"evaluator_type",
|
||||
"evaluator_types",
|
||||
[
|
||||
EvaluatorType.LABELED_CRITERIA,
|
||||
EvaluatorType.LABELED_PAIRWISE_STRING,
|
||||
EvaluatorType.QA,
|
||||
EvaluatorType.CONTEXT_QA,
|
||||
EvaluatorType.COT_QA,
|
||||
[EvaluatorType.LABELED_CRITERIA],
|
||||
[EvaluatorType.LABELED_PAIRWISE_STRING],
|
||||
[EvaluatorType.QA],
|
||||
[EvaluatorType.CONTEXT_QA],
|
||||
[EvaluatorType.COT_QA],
|
||||
[EvaluatorType.COT_QA, EvaluatorType.LABELED_CRITERIA],
|
||||
[
|
||||
EvaluatorType.COT_QA,
|
||||
EvaluatorType.LABELED_CRITERIA,
|
||||
EvaluatorType.LABELED_PAIRWISE_STRING,
|
||||
],
|
||||
],
|
||||
)
|
||||
def test_eval_chain_requires_references(evaluator_type: EvaluatorType) -> None:
|
||||
def test_eval_chain_requires_references(evaluator_types: List[EvaluatorType]) -> None:
|
||||
"""Test loading evaluators."""
|
||||
fake_llm = FakeLLM(
|
||||
queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
|
||||
)
|
||||
evaluator = load_evaluators(
|
||||
[evaluator_type],
|
||||
evaluators = load_evaluators(
|
||||
evaluator_types,
|
||||
llm=fake_llm,
|
||||
)[0]
|
||||
if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)):
|
||||
raise ValueError("Evaluator is not a [pairwise]string evaluator")
|
||||
assert evaluator.requires_reference
|
||||
)
|
||||
for evaluator in evaluators:
|
||||
if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)):
|
||||
raise ValueError("Evaluator is not a [pairwise]string evaluator")
|
||||
assert evaluator.requires_reference
|
||||
|
Loading…
Reference in New Issue
Block a user