From 983678dedcd9b13d9d92df77d7c3cfbd702537b0 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Sun, 6 Aug 2023 14:05:00 -0700 Subject: [PATCH] Add Dist Metrics for String Distance Evaluation (#8837) Co-authored-by: shibuiwilliam --- .../evaluation/string_distance/base.py | 47 +++++++------ .../evaluation/string_distance/test_base.py | 66 ++++++++++++++++++- .../unit_tests/evaluation/test_loading.py | 34 ++++++---- 3 files changed, 113 insertions(+), 34 deletions(-) diff --git a/libs/langchain/langchain/evaluation/string_distance/base.py b/libs/langchain/langchain/evaluation/string_distance/base.py index 57e195d8d1..57f2958fe5 100644 --- a/libs/langchain/langchain/evaluation/string_distance/base.py +++ b/libs/langchain/langchain/evaluation/string_distance/base.py @@ -43,12 +43,16 @@ class StringDistance(str, Enum): LEVENSHTEIN: The Levenshtein distance. JARO: The Jaro distance. JARO_WINKLER: The Jaro-Winkler distance. + HAMMING: The Hamming distance. + INDEL: The Indel distance. """ DAMERAU_LEVENSHTEIN = "damerau_levenshtein" LEVENSHTEIN = "levenshtein" JARO = "jaro" JARO_WINKLER = "jaro_winkler" + HAMMING = "hamming" + INDEL = "indel" class _RapidFuzzChainMixin(Chain): @@ -99,7 +103,7 @@ class _RapidFuzzChainMixin(Chain): return result @staticmethod - def _get_metric(distance: str) -> Callable: + def _get_metric(distance: str, normalize_score: bool = False) -> Callable: """ Get the distance metric function based on the distance type. @@ -112,17 +116,26 @@ class _RapidFuzzChainMixin(Chain): Raises: ValueError: If the distance metric is invalid. """ - rf_distance = _load_rapidfuzz() - if distance == StringDistance.DAMERAU_LEVENSHTEIN: - return rf_distance.DamerauLevenshtein.distance - elif distance == StringDistance.LEVENSHTEIN: - return rf_distance.Levenshtein.distance - elif distance == StringDistance.JARO: - return rf_distance.Jaro.distance - elif distance == StringDistance.JARO_WINKLER: - return rf_distance.JaroWinkler.distance + from rapidfuzz import distance as rf_distance + + module_map: Dict[str, Any] = { + StringDistance.DAMERAU_LEVENSHTEIN: rf_distance.DamerauLevenshtein, + StringDistance.LEVENSHTEIN: rf_distance.Levenshtein, + StringDistance.JARO: rf_distance.Jaro, + StringDistance.JARO_WINKLER: rf_distance.JaroWinkler, + StringDistance.HAMMING: rf_distance.Hamming, + StringDistance.INDEL: rf_distance.Indel, + } + if distance not in module_map: + raise ValueError( + f"Invalid distance metric: {distance}" + f"\nMust be one of: {list(StringDistance)}" + ) + module = module_map[distance] + if normalize_score: + return module.normalized_distance else: - raise ValueError(f"Invalid distance metric: {distance}") + return module.distance @property def metric(self) -> Callable: @@ -132,7 +145,9 @@ class _RapidFuzzChainMixin(Chain): Returns: Callable: The distance metric function. """ - return _RapidFuzzChainMixin._get_metric(self.distance) + return _RapidFuzzChainMixin._get_metric( + self.distance, normalize_score=self.normalize_score + ) def compute_metric(self, a: str, b: str) -> float: """ @@ -145,13 +160,7 @@ class _RapidFuzzChainMixin(Chain): Returns: float: The distance between the two strings. """ - score = self.metric(a, b) - if self.normalize_score and self.distance in ( - StringDistance.DAMERAU_LEVENSHTEIN, - StringDistance.LEVENSHTEIN, - ): - score = score / max(len(a), len(b)) - return score + return self.metric(a, b) class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin): diff --git a/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py b/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py index 8fe35c08e4..eff632f454 100644 --- a/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py +++ b/libs/langchain/tests/unit_tests/evaluation/string_distance/test_base.py @@ -30,8 +30,13 @@ async def test_zero_distance_async(distance: StringDistance) -> None: @pytest.mark.requires("rapidfuzz") @pytest.mark.parametrize("distance", list(StringDistance)) -def test_zero_distance_pairwise(distance: StringDistance) -> None: - eval_chain = PairwiseStringDistanceEvalChain(distance=distance) +@pytest.mark.parametrize("normalize_score", [True, False]) +def test_zero_distance_pairwise( + distance: StringDistance, normalize_score: bool +) -> None: + eval_chain = PairwiseStringDistanceEvalChain( + distance=distance, normalize_score=normalize_score + ) string = "三人行则必有我师" result = eval_chain.evaluate_string_pairs(prediction=string, prediction_b=string) assert "score" in result @@ -49,3 +54,60 @@ async def test_zero_distance_pairwise_async(distance: StringDistance) -> None: ) assert "score" in result assert result["score"] == 0 + + +@pytest.mark.requires("rapidfuzz") +@pytest.mark.parametrize("distance", list(StringDistance)) +@pytest.mark.parametrize("normalize_score", [True, False]) +def test_non_zero_distance(distance: StringDistance, normalize_score: bool) -> None: + eval_chain = StringDistanceEvalChain( + distance=distance, normalize_score=normalize_score + ) + prediction = "I like to eat apples." + reference = "I like apples." + result = eval_chain.evaluate_strings(prediction=prediction, reference=reference) + assert "score" in result + assert 0 < result["score"] + if normalize_score: + assert result["score"] < 1.0 + + +@pytest.mark.asyncio +@pytest.mark.requires("rapidfuzz") +@pytest.mark.parametrize("distance", list(StringDistance)) +async def test_non_zero_distance_async(distance: StringDistance) -> None: + eval_chain = StringDistanceEvalChain(distance=distance) + prediction = "I like to eat apples." + reference = "I like apples." + result = await eval_chain.aevaluate_strings( + prediction=prediction, reference=reference + ) + assert "score" in result + assert 0 < result["score"] < 1.0 + + +@pytest.mark.requires("rapidfuzz") +@pytest.mark.parametrize("distance", list(StringDistance)) +def test_non_zero_distance_pairwise(distance: StringDistance) -> None: + eval_chain = PairwiseStringDistanceEvalChain(distance=distance) + prediction = "I like to eat apples." + reference = "I like apples." + result = eval_chain.evaluate_string_pairs( + prediction=prediction, prediction_b=reference + ) + assert "score" in result + assert 0 < result["score"] < 1.0 + + +@pytest.mark.asyncio +@pytest.mark.requires("rapidfuzz") +@pytest.mark.parametrize("distance", list(StringDistance)) +async def test_non_zero_distance_pairwise_async(distance: StringDistance) -> None: + eval_chain = PairwiseStringDistanceEvalChain(distance=distance) + prediction = "I like to eat apples." + reference = "I like apples." + result = await eval_chain.aevaluate_string_pairs( + prediction=prediction, prediction_b=reference + ) + assert "score" in result + assert 0 < result["score"] < 1.0 diff --git a/libs/langchain/tests/unit_tests/evaluation/test_loading.py b/libs/langchain/tests/unit_tests/evaluation/test_loading.py index 160d5ee0d9..11715e9250 100644 --- a/libs/langchain/tests/unit_tests/evaluation/test_loading.py +++ b/libs/langchain/tests/unit_tests/evaluation/test_loading.py @@ -1,4 +1,5 @@ """Test the loading function for evaluators.""" +from typing import List import pytest @@ -26,24 +27,31 @@ def test_load_evaluators(evaluator_type: EvaluatorType) -> None: @pytest.mark.parametrize( - "evaluator_type", + "evaluator_types", [ - EvaluatorType.LABELED_CRITERIA, - EvaluatorType.LABELED_PAIRWISE_STRING, - EvaluatorType.QA, - EvaluatorType.CONTEXT_QA, - EvaluatorType.COT_QA, + [EvaluatorType.LABELED_CRITERIA], + [EvaluatorType.LABELED_PAIRWISE_STRING], + [EvaluatorType.QA], + [EvaluatorType.CONTEXT_QA], + [EvaluatorType.COT_QA], + [EvaluatorType.COT_QA, EvaluatorType.LABELED_CRITERIA], + [ + EvaluatorType.COT_QA, + EvaluatorType.LABELED_CRITERIA, + EvaluatorType.LABELED_PAIRWISE_STRING, + ], ], ) -def test_eval_chain_requires_references(evaluator_type: EvaluatorType) -> None: +def test_eval_chain_requires_references(evaluator_types: List[EvaluatorType]) -> None: """Test loading evaluators.""" fake_llm = FakeLLM( queries={"text": "The meaning of life\nCORRECT"}, sequential_responses=True ) - evaluator = load_evaluators( - [evaluator_type], + evaluators = load_evaluators( + evaluator_types, llm=fake_llm, - )[0] - if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)): - raise ValueError("Evaluator is not a [pairwise]string evaluator") - assert evaluator.requires_reference + ) + for evaluator in evaluators: + if not isinstance(evaluator, (StringEvaluator, PairwiseStringEvaluator)): + raise ValueError("Evaluator is not a [pairwise]string evaluator") + assert evaluator.requires_reference