|
|
|
@ -43,12 +43,16 @@ class StringDistance(str, Enum):
|
|
|
|
|
LEVENSHTEIN: The Levenshtein distance.
|
|
|
|
|
JARO: The Jaro distance.
|
|
|
|
|
JARO_WINKLER: The Jaro-Winkler distance.
|
|
|
|
|
HAMMING: The Hamming distance.
|
|
|
|
|
INDEL: The Indel distance.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
DAMERAU_LEVENSHTEIN = "damerau_levenshtein"
|
|
|
|
|
LEVENSHTEIN = "levenshtein"
|
|
|
|
|
JARO = "jaro"
|
|
|
|
|
JARO_WINKLER = "jaro_winkler"
|
|
|
|
|
HAMMING = "hamming"
|
|
|
|
|
INDEL = "indel"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _RapidFuzzChainMixin(Chain):
|
|
|
|
@ -99,7 +103,7 @@ class _RapidFuzzChainMixin(Chain):
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _get_metric(distance: str) -> Callable:
|
|
|
|
|
def _get_metric(distance: str, normalize_score: bool = False) -> Callable:
|
|
|
|
|
"""
|
|
|
|
|
Get the distance metric function based on the distance type.
|
|
|
|
|
|
|
|
|
@ -112,17 +116,26 @@ class _RapidFuzzChainMixin(Chain):
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: If the distance metric is invalid.
|
|
|
|
|
"""
|
|
|
|
|
rf_distance = _load_rapidfuzz()
|
|
|
|
|
if distance == StringDistance.DAMERAU_LEVENSHTEIN:
|
|
|
|
|
return rf_distance.DamerauLevenshtein.distance
|
|
|
|
|
elif distance == StringDistance.LEVENSHTEIN:
|
|
|
|
|
return rf_distance.Levenshtein.distance
|
|
|
|
|
elif distance == StringDistance.JARO:
|
|
|
|
|
return rf_distance.Jaro.distance
|
|
|
|
|
elif distance == StringDistance.JARO_WINKLER:
|
|
|
|
|
return rf_distance.JaroWinkler.distance
|
|
|
|
|
from rapidfuzz import distance as rf_distance
|
|
|
|
|
|
|
|
|
|
module_map: Dict[str, Any] = {
|
|
|
|
|
StringDistance.DAMERAU_LEVENSHTEIN: rf_distance.DamerauLevenshtein,
|
|
|
|
|
StringDistance.LEVENSHTEIN: rf_distance.Levenshtein,
|
|
|
|
|
StringDistance.JARO: rf_distance.Jaro,
|
|
|
|
|
StringDistance.JARO_WINKLER: rf_distance.JaroWinkler,
|
|
|
|
|
StringDistance.HAMMING: rf_distance.Hamming,
|
|
|
|
|
StringDistance.INDEL: rf_distance.Indel,
|
|
|
|
|
}
|
|
|
|
|
if distance not in module_map:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Invalid distance metric: {distance}"
|
|
|
|
|
f"\nMust be one of: {list(StringDistance)}"
|
|
|
|
|
)
|
|
|
|
|
module = module_map[distance]
|
|
|
|
|
if normalize_score:
|
|
|
|
|
return module.normalized_distance
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Invalid distance metric: {distance}")
|
|
|
|
|
return module.distance
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def metric(self) -> Callable:
|
|
|
|
@ -132,7 +145,9 @@ class _RapidFuzzChainMixin(Chain):
|
|
|
|
|
Returns:
|
|
|
|
|
Callable: The distance metric function.
|
|
|
|
|
"""
|
|
|
|
|
return _RapidFuzzChainMixin._get_metric(self.distance)
|
|
|
|
|
return _RapidFuzzChainMixin._get_metric(
|
|
|
|
|
self.distance, normalize_score=self.normalize_score
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def compute_metric(self, a: str, b: str) -> float:
|
|
|
|
|
"""
|
|
|
|
@ -145,13 +160,7 @@ class _RapidFuzzChainMixin(Chain):
|
|
|
|
|
Returns:
|
|
|
|
|
float: The distance between the two strings.
|
|
|
|
|
"""
|
|
|
|
|
score = self.metric(a, b)
|
|
|
|
|
if self.normalize_score and self.distance in (
|
|
|
|
|
StringDistance.DAMERAU_LEVENSHTEIN,
|
|
|
|
|
StringDistance.LEVENSHTEIN,
|
|
|
|
|
):
|
|
|
|
|
score = score / max(len(a), len(b))
|
|
|
|
|
return score
|
|
|
|
|
return self.metric(a, b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
|
|
|
|
|