Experimental: Add other threshold types to SemanticChunker (#16807)

**Description**
This PR adds different threshold types to the semantic chunker. I’ve had much
better and more predictable performance when using standard deviations
instead of percentiles.


![Distribution of cosine distances between consecutive sentence embeddings](https://github.com/langchain-ai/langchain/assets/44395485/066e84a8-460e-4da5-9fa1-4ff79a1941c5)

For all the documents I’ve tried, the distribution of distances looks
similar to the one above: a positively skewed, roughly normal distribution.
Every skew I’ve seen is less than 1, which explains why standard deviations
perform well, but I’ve also included IQR in case anyone wants something more
robust.
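
For concreteness, here is a minimal sketch of what each threshold type computes over the distance distribution, mirroring `_calculate_breakpoint_threshold` in the diff below (the `distances` array is synthetic, positively skewed sample data, not real embeddings):

```python
import numpy as np

# Synthetic stand-in for the cosine distances between consecutive
# combined-sentence embeddings (gamma draws are positively skewed).
distances = np.random.default_rng(0).gamma(shape=2.0, scale=0.05, size=200)

# "percentile" (default 95): split at the largest 5% of distances.
percentile_threshold = np.percentile(distances, 95)

# "standard_deviation" (default 3): split at mean + 3 * std.
std_threshold = np.mean(distances) + 3 * np.std(distances)

# "interquartile" (default 1.5): split at mean + 1.5 * IQR, which is
# less sensitive to extreme outliers than the standard deviation.
q1, q3 = np.percentile(distances, [25, 75])
iqr_threshold = np.mean(distances) + 1.5 * (q3 - q1)
```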

Also, by running the percentile method backwards, you can declare a desired
number of chunks and have semantic chunking find an ‘optimal’ split.
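
A hedged usage sketch (the import paths and `OpenAIEmbeddings` backend are assumptions; any `Embeddings` implementation should work):

```python
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import OpenAIEmbeddings  # assumed backend

text = "..."  # hypothetical long document

# Split wherever a distance exceeds mean + 3 standard deviations.
chunker = SemanticChunker(
    OpenAIEmbeddings(),
    breakpoint_threshold_type="standard_deviation",
    breakpoint_threshold_amount=3,
)
chunks = chunker.split_text(text)

# Or run the percentile method backwards: ask for roughly ten chunks
# and let the breakpoint percentile be interpolated from that count.
chunker_by_count = SemanticChunker(OpenAIEmbeddings(), number_of_chunks=10)
```

With `number_of_chunks`, the breakpoint percentile is interpolated linearly between 100 (one chunk) and 0 (one chunk per sentence gap); see the sanity check after the diff.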

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Commit a4896da2a0 (parent ce682f5a09), authored by matt haigh.

```diff
@@ -1,6 +1,6 @@
 import copy
 import re
-from typing import Any, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, cast
 
 import numpy as np
 
 from langchain_community.utils.math import (
@@ -83,6 +83,14 @@ def calculate_cosine_distances(sentences: List[dict]) -> Tuple[List[float], List
     return distances, sentences
 
 
+BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]
+BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
+    "percentile": 95,
+    "standard_deviation": 3,
+    "interquartile": 1.5,
+}
+
+
 class SemanticChunker(BaseDocumentTransformer):
     """Split the text based on semantic similarity.
@@ -95,42 +103,110 @@ class SemanticChunker(BaseDocumentTransformer):
     sentences, and then merges one that are similar in the embedding space.
     """
 
-    def __init__(self, embeddings: Embeddings, add_start_index: bool = False):
+    def __init__(
+        self,
+        embeddings: Embeddings,
+        add_start_index: bool = False,
+        breakpoint_threshold_type: BreakpointThresholdType = "percentile",
+        breakpoint_threshold_amount: Optional[float] = None,
+        number_of_chunks: Optional[int] = None,
+    ):
         self._add_start_index = add_start_index
         self.embeddings = embeddings
+        self.breakpoint_threshold_type = breakpoint_threshold_type
+        self.number_of_chunks = number_of_chunks
+        if breakpoint_threshold_amount is None:
+            self.breakpoint_threshold_amount = BREAKPOINT_DEFAULTS[
+                breakpoint_threshold_type
+            ]
+        else:
+            self.breakpoint_threshold_amount = breakpoint_threshold_amount
 
-    def split_text(self, text: str) -> List[str]:
+    def _calculate_breakpoint_threshold(self, distances: List[float]) -> float:
+        if self.breakpoint_threshold_type == "percentile":
+            return cast(
+                float,
+                np.percentile(distances, self.breakpoint_threshold_amount),
+            )
+        elif self.breakpoint_threshold_type == "standard_deviation":
+            return cast(
+                float,
+                np.mean(distances)
+                + self.breakpoint_threshold_amount * np.std(distances),
+            )
+        elif self.breakpoint_threshold_type == "interquartile":
+            q1, q3 = np.percentile(distances, [25, 75])
+            iqr = q3 - q1
+            return np.mean(distances) + self.breakpoint_threshold_amount * iqr
+        else:
+            raise ValueError(
+                f"Got unexpected `breakpoint_threshold_type`: "
+                f"{self.breakpoint_threshold_type}"
+            )
+
+    def _threshold_from_clusters(self, distances: List[float]) -> float:
+        """
+        Calculate the threshold based on the number of chunks.
+        Inverse of percentile method.
+        """
+        if self.number_of_chunks is None:
+            raise ValueError(
+                "This should never be called if `number_of_chunks` is None."
+            )
+        x1, y1 = len(distances), 0.0
+        x2, y2 = 1.0, 100.0
+
+        x = max(min(self.number_of_chunks, x1), x2)
+
+        # Linear interpolation formula
+        y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)
+        y = min(max(y, 0), 100)
+
+        return cast(float, np.percentile(distances, y))
+
+    def _calculate_sentence_distances(
+        self, single_sentences_list: List[str]
+    ) -> Tuple[List[float], List[dict]]:
         """Split text into multiple components."""
-        # Splitting the essay on '.', '?', and '!'
-        single_sentences_list = re.split(r"(?<=[.?!])\s+", text)
-
-        # having len(single_sentences_list) == 1 would cause the following
-        # np.percentile to fail.
-        if len(single_sentences_list) == 1:
-            return single_sentences_list
-        sentences = [
+        _sentences = [
             {"sentence": x, "index": i} for i, x in enumerate(single_sentences_list)
         ]
-        sentences = combine_sentences(sentences)
+        sentences = combine_sentences(_sentences)
         embeddings = self.embeddings.embed_documents(
             [x["combined_sentence"] for x in sentences]
         )
         for i, sentence in enumerate(sentences):
             sentence["combined_sentence_embedding"] = embeddings[i]
-        distances, sentences = calculate_cosine_distances(sentences)
-        start_index = 0
 
-        # Create a list to hold the grouped sentences
-        chunks = []
-        breakpoint_percentile_threshold = 95
-        breakpoint_distance_threshold = np.percentile(
-            distances, breakpoint_percentile_threshold
-        )  # If you want more chunks, lower the percentile cutoff
+        return calculate_cosine_distances(sentences)
+
+    def split_text(
+        self,
+        text: str,
+    ) -> List[str]:
+        # Splitting the essay on '.', '?', and '!'
+        single_sentences_list = re.split(r"(?<=[.?!])\s+", text)
+
+        # having len(single_sentences_list) == 1 would cause the following
+        # np.percentile to fail.
+        if len(single_sentences_list) == 1:
+            return single_sentences_list
+        distances, sentences = self._calculate_sentence_distances(single_sentences_list)
+
+        if self.number_of_chunks is not None:
+            breakpoint_distance_threshold = self._threshold_from_clusters(distances)
+        else:
+            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
+                distances
+            )
 
         indices_above_thresh = [
             i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
-        ]  # The indices of those breakpoints on your list
+        ]
 
+        chunks = []
+        start_index = 0
 
         # Iterate through the breakpoints to slice the sentences
         for index in indices_above_thresh:
```
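
To make the inverse-percentile behavior concrete, here is a standalone reproduction of the interpolation in `_threshold_from_clusters` (`threshold_percentile` is a hypothetical helper, same math as the method above):

```python
def threshold_percentile(number_of_chunks: int, num_distances: int) -> float:
    # Linear interpolation between (num_distances, 0) and (1, 100),
    # clamped to the valid percentile range.
    x1, y1 = num_distances, 0.0
    x2, y2 = 1.0, 100.0
    x = max(min(number_of_chunks, x1), x2)
    y = y1 + ((y2 - y1) / (x2 - x1)) * (x - x1)
    return min(max(y, 0), 100)

# Asking for 5 chunks out of 20 sentence gaps puts the breakpoint
# threshold at roughly the 79th percentile of the distances.
assert round(threshold_percentile(5, 20), 1) == 78.9
assert threshold_percentile(1, 20) == 100.0  # one chunk: threshold at the max
assert threshold_percentile(20, 20) == 0.0   # as many chunks as gaps
```

`np.percentile` of the distances at that value then yields the distance threshold, exactly as in the diff.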
