mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add Dist Metrics for String Distance Evaluation (#8837)
Co-authored-by: shibuiwilliam <shibuiyusuke@gmail.com>
This commit is contained in:
parent
f76d50d8dc
commit
983678dedc
@ -43,12 +43,16 @@ class StringDistance(str, Enum):
|
|||||||
LEVENSHTEIN: The Levenshtein distance.
|
LEVENSHTEIN: The Levenshtein distance.
|
||||||
JARO: The Jaro distance.
|
JARO: The Jaro distance.
|
||||||
JARO_WINKLER: The Jaro-Winkler distance.
|
JARO_WINKLER: The Jaro-Winkler distance.
|
||||||
|
HAMMING: The Hamming distance.
|
||||||
|
INDEL: The Indel distance.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DAMERAU_LEVENSHTEIN = "damerau_levenshtein"
|
DAMERAU_LEVENSHTEIN = "damerau_levenshtein"
|
||||||
LEVENSHTEIN = "levenshtein"
|
LEVENSHTEIN = "levenshtein"
|
||||||
JARO = "jaro"
|
JARO = "jaro"
|
||||||
JARO_WINKLER = "jaro_winkler"
|
JARO_WINKLER = "jaro_winkler"
|
||||||
|
HAMMING = "hamming"
|
||||||
|
INDEL = "indel"
|
||||||
|
|
||||||
|
|
||||||
class _RapidFuzzChainMixin(Chain):
|
class _RapidFuzzChainMixin(Chain):
|
||||||
@ -99,7 +103,7 @@ class _RapidFuzzChainMixin(Chain):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_metric(distance: str) -> Callable:
|
def _get_metric(distance: str, normalize_score: bool = False) -> Callable:
|
||||||
"""
|
"""
|
||||||
Get the distance metric function based on the distance type.
|
Get the distance metric function based on the distance type.
|
||||||
|
|
||||||
@ -112,17 +116,26 @@ class _RapidFuzzChainMixin(Chain):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the distance metric is invalid.
|
ValueError: If the distance metric is invalid.
|
||||||
"""
|
"""
|
||||||
rf_distance = _load_rapidfuzz()
|
from rapidfuzz import distance as rf_distance
|
||||||
if distance == StringDistance.DAMERAU_LEVENSHTEIN:
|
|
||||||
return rf_distance.DamerauLevenshtein.distance
|
module_map: Dict[str, Any] = {
|
||||||
elif distance == StringDistance.LEVENSHTEIN:
|
StringDistance.DAMERAU_LEVENSHTEIN: rf_distance.DamerauLevenshtein,
|
||||||
return rf_distance.Levenshtein.distance
|
StringDistance.LEVENSHTEIN: rf_distance.Levenshtein,
|
||||||
elif distance == StringDistance.JARO:
|
StringDistance.JARO: rf_distance.Jaro,
|
||||||
return rf_distance.Jaro.distance
|
StringDistance.JARO_WINKLER: rf_distance.JaroWinkler,
|
||||||
elif distance == StringDistance.JARO_WINKLER:
|
StringDistance.HAMMING: rf_distance.Hamming,
|
||||||
return rf_distance.JaroWinkler.distance
|
StringDistance.INDEL: rf_distance.Indel,
|
||||||
|
}
|
||||||
|
if distance not in module_map:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid distance metric: {distance}"
|
||||||
|
f"\nMust be one of: {list(StringDistance)}"
|
||||||
|
)
|
||||||
|
module = module_map[distance]
|
||||||
|
if normalize_score:
|
||||||
|
return module.normalized_distance
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid distance metric: {distance}")
|
return module.distance
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def metric(self) -> Callable:
|
def metric(self) -> Callable:
|
||||||
@ -132,7 +145,9 @@ class _RapidFuzzChainMixin(Chain):
|
|||||||
Returns:
|
Returns:
|
||||||
Callable: The distance metric function.
|
Callable: The distance metric function.
|
||||||
"""
|
"""
|
||||||
return _RapidFuzzChainMixin._get_metric(self.distance)
|
return _RapidFuzzChainMixin._get_metric(
|
||||||
|
self.distance, normalize_score=self.normalize_score
|
||||||
|
)
|
||||||
|
|
||||||
def compute_metric(self, a: str, b: str) -> float:
|
def compute_metric(self, a: str, b: str) -> float:
|
||||||
"""
|
"""
|
||||||
@ -145,13 +160,7 @@ class _RapidFuzzChainMixin(Chain):
|
|||||||
Returns:
|
Returns:
|
||||||
float: The distance between the two strings.
|
float: The distance between the two strings.
|
||||||
"""
|
"""
|
||||||
score = self.metric(a, b)
|
return self.metric(a, b)
|
||||||
if self.normalize_score and self.distance in (
|
|
||||||
StringDistance.DAMERAU_LEVENSHTEIN,
|
|
||||||
StringDistance.LEVENSHTEIN,
|
|
||||||
):
|
|
||||||
score = score / max(len(a), len(b))
|
|
||||||
return score
|
|
||||||
|
|
||||||
|
|
||||||
class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
|
class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
|
||||||
|
@ -30,8 +30,13 @@ async def test_zero_distance_async(distance: StringDistance) -> None:
|
|||||||
|
|
||||||
@pytest.mark.requires("rapidfuzz")
|
@pytest.mark.requires("rapidfuzz")
|
||||||
@pytest.mark.parametrize("distance", list(StringDistance))
|
@pytest.mark.parametrize("distance", list(StringDistance))
|
||||||
def test_zero_distance_pairwise(distance: StringDistance) -> None:
|
@pytest.mark.parametrize("normalize_score", [True, False])
|
||||||
eval_chain = PairwiseStringDistanceEvalChain(distance=distance)
|
def test_zero_distance_pairwise(
|
||||||
|
distance: StringDistance, normalize_score: bool
|
||||||
|
) -> None:
|
||||||
|
eval_chain = PairwiseStringDistanceEvalChain(
|
||||||
|
distance=distance, normalize_score=normalize_score
|
||||||
|
)
|
||||||
string = "三人行则必有我师"
|
string = "三人行则必有我师"
|
||||||
result = eval_chain.evaluate_string_pairs(prediction=string, prediction_b=string)
|
result = eval_chain.evaluate_string_pairs(prediction=string, prediction_b=string)
|
||||||
assert "score" in result
|
assert "score" in result
|
||||||
@ -49,3 +54,60 @@ async def test_zero_distance_pairwise_async(distance: StringDistance) -> None:
|
|||||||
)
|
)
|
||||||
assert "score" in result
|
assert "score" in result
|
||||||
assert result["score"] == 0
|
assert result["score"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("rapidfuzz")
|
||||||
|
@pytest.mark.parametrize("distance", list(StringDistance))
|
||||||
|
@pytest.mark.parametrize("normalize_score", [True, False])
|
||||||
|
def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> None:
|
||||||
|
eval_chain = StringDistanceEvalChain(
|
||||||
|
distance=distance, normalize_score=normalize_score
|
||||||
|
)
|
||||||
|
prediction = "I like to eat apples."
|
||||||
|
reference = "I like apples."
|
||||||
|
result = eval_chain.evaluate_strings(prediction=prediction, reference=reference)
|
||||||
|
assert "score" in result
|
||||||
|
assert 0 < result["score"]
|
||||||
|
if normalize_score:
|
||||||
|
assert result["score"] < 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.requires("rapidfuzz")
|
||||||
|
@pytest.mark.parametrize("distance", list(StringDistance))
|
||||||
|
async def test_non_zero_distance_async(distance: StringDistance) -> None:
|
||||||
|
eval_chain = StringDistanceEvalChain(distance=distance)
|
||||||
|
prediction = "I like to eat apples."
|
||||||
|
reference = "I like apples."
|
||||||
|
result = await eval_chain.aevaluate_strings(
|
||||||
|
prediction=prediction, reference=reference
|
||||||
|
)
|
||||||
|
assert "score" in result
|
||||||
|
assert 0 < result["score"] < 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("rapidfuzz")
|
||||||
|
@pytest.mark.parametrize("distance", list(StringDistance))
|
||||||
|
def test_non_zero_distance_pairwise(distance: StringDistance) -> None:
|
||||||
|
eval_chain = PairwiseStringDistanceEvalChain(distance=distance)
|
||||||
|
prediction = "I like to eat apples."
|
||||||
|
reference = "I like apples."
|
||||||
|
result = eval_chain.evaluate_string_pairs(
|
||||||
|
prediction=prediction, prediction_b=reference
|
||||||
|
)
|
||||||
|
assert "score" in result
|
||||||
|
assert 0 < result["score"] < 1.0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.requires("rapidfuzz")
|
||||||
|
@pytest.mark.parametrize("distance", list(StringDistance))
|
||||||
|
async def test_non_zero_distance_pairwise_async(distance: StringDistance) -> None:
|
||||||
|
eval_chain = PairwiseStringDistanceEvalChain(distance=distance)
|
||||||
|
prediction = "I like to eat apples."
|
||||||
|
reference = "I like apples."
|
||||||
|
result = await eval_chain.aevaluate_string_pairs(
|
||||||
|
prediction=prediction, prediction_b=reference
|
||||||
|
)
|
||||||
|
assert "score" in result
|
||||||
|
assert 0 < result["score"] < 1.0
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Test the loading function for evaluators."""
|
"""Test the loading function for evaluators."""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -26,24 +27,31 @@ def test_load_evaluators(evaluator_type: EvaluatorType) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"evaluator_type",
|
"evaluator_types",
|
||||||
[
|
[
|
||||||
EvaluatorType.LABELED_CRITERIA,
|
[EvaluatorType.LABELED_CRITERIA],
|
||||||
EvaluatorType.LABELED_PAIRWISE_STRING,
|
[EvaluatorType.LABELED_PAIRWISE_STRING],
|
||||||
EvaluatorType.QA,
|
[EvaluatorType.QA],
|
||||||
EvaluatorType.CONTEXT_QA,
|
[EvaluatorType.CONTEXT_QA],
|
||||||
EvaluatorType.COT_QA,
|
[EvaluatorType.COT_QA],
|
||||||
|
[EvaluatorType.COT_QA, EvaluatorType.LABELED_CRITERIA],
|
||||||
|
[
|
||||||
|
EvaluatorType.COT_QA,
|
||||||
|
EvaluatorType.LABELED_CRITERIA,
|
||||||
|
EvaluatorType.LABELED_PAIRWISE_STRING,
|
||||||
|
],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_eval_chain_requires_references(evaluator_type: EvaluatorType) -> None:
|
def test_eval_chain_requires_references(evaluator_types: List[EvaluatorType]) -> None:
|
||||||
"""Test loading evaluators."""
|
"""Test loading evaluators."""
|
||||||
fake_llm = FakeLLM(
|
fake_llm = FakeLLM(
|
||||||
queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
|
queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True
|
||||||
)
|
)
|
||||||
evaluator = load_evaluators(
|
evaluators = load_evaluators(
|
||||||
[evaluator_type],
|
evaluator_types,
|
||||||
llm=fake_llm,
|
llm=fake_llm,
|
||||||
)[0]
|
)
|
||||||
if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)):
|
for evaluator in evaluators:
|
||||||
raise ValueError("Evaluator is not a [pairwise]string evaluator")
|
if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)):
|
||||||
assert evaluator.requires_reference
|
raise ValueError("Evaluator is not a [pairwise]string evaluator")
|
||||||
|
assert evaluator.requires_reference
|
||||||
|
Loading…
Reference in New Issue
Block a user