SemanticChunker: Feature Addition ("Semantic Splitting with gradient") (#22895)

```SemanticChunker``` currently provides three methods to split texts semantically:
- percentile
- standard_deviation
- interquartile

I propose a new method, ```gradient```. In this method, the gradient of the distances between consecutive sentences is computed, and the percentile rule is then applied to that gradient array. This is useful when chunks are highly correlated with each other or are specific to a domain, e.g. legal or medical: applying anomaly detection to the gradient widens the distribution, making it easier to identify boundaries in highly semantic data (see the sketch below).
I have tested this change on a set of 10 domain-specific documents (mostly legal).
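
Below is a minimal NumPy sketch of the idea with made-up distance values (not the library implementation): the percentile rule is applied to the gradient of the distance curve rather than to the raw distances, so breakpoints stand out even when absolute distances sit in a narrow band.

```python
import numpy as np

# Hypothetical cosine distances between consecutive sentence embeddings.
# In domain-specific text these tend to sit in a narrow band, so a raw
# percentile threshold separates very little.
distances = [0.41, 0.43, 0.42, 0.44, 0.58, 0.42, 0.43, 0.61, 0.42]

# The gradient responds to *changes* in distance rather than to absolute
# values, widening the distribution.
distance_gradient = np.gradient(distances, range(len(distances)))

# Apply the usual percentile rule to the gradient instead of the distances.
threshold = np.percentile(distance_gradient, 95)
breakpoints = [i for i, g in enumerate(distance_gradient) if g > threshold]
print(breakpoints)  # indices treated as chunk boundaries
```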

Details:
- **Issue:** Improvement
- **Dependencies:** N/A
- **Twitter handle:** [x.com/prajapat_ravi](https://x.com/prajapat_ravi)


@hwchase17

---------

Co-authored-by: Raviraj Prajapat <raviraj.prajapat@sirionlabs.com>
Co-authored-by: isaac hershenson <ihershenson@hmc.edu>

```diff
@@ -297,13 +297,67 @@
     "print(len(docs))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Gradient\n",
+    "\n",
+    "In this method, the gradient of distance is used to split chunks along with the percentile method.\n",
+    "This method is useful when chunks are highly correlated with each other or specific to a domain, e.g. legal or medical. The idea is to apply anomaly detection on the gradient array so that the distribution becomes wider, making it easier to identify boundaries in highly semantic data."
+   ],
+   "metadata": {
+    "collapsed": false
+   },
+   "id": "423c6e099e94ca69"
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "id": "b1f65472",
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "text_splitter = SemanticChunker(\n",
+    "    OpenAIEmbeddings(), breakpoint_threshold_type=\"gradient\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman.\n"
+     ]
+    }
+   ],
+   "source": [
+    "docs = text_splitter.create_documents([state_of_the_union])\n",
+    "print(docs[0].page_content)"
+   ],
+   "metadata": {},
+   "id": "e9f393d316ce1f6c"
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "26\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(len(docs))"
+   ],
+   "metadata": {},
+   "id": "a407cd57f02a0db4"
+  }
  ],
  "metadata": {
```

```diff
@@ -84,11 +84,14 @@ def calculate_cosine_distances(sentences: List[dict]) -> Tuple[List[float], List
     return distances, sentences


-BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]
+BreakpointThresholdType = Literal[
+    "percentile", "standard_deviation", "interquartile", "gradient"
+]

 BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
     "percentile": 95,
     "standard_deviation": 3,
     "interquartile": 1.5,
+    "gradient": 95,
 }

```
```diff
@@ -127,23 +130,34 @@ class SemanticChunker(BaseDocumentTransformer):
         else:
             self.breakpoint_threshold_amount = breakpoint_threshold_amount

-    def _calculate_breakpoint_threshold(self, distances: List[float]) -> float:
+    def _calculate_breakpoint_threshold(
+        self, distances: List[float]
+    ) -> Tuple[float, List[float]]:
         if self.breakpoint_threshold_type == "percentile":
             return cast(
                 float,
                 np.percentile(distances, self.breakpoint_threshold_amount),
-            )
+            ), distances
         elif self.breakpoint_threshold_type == "standard_deviation":
             return cast(
                 float,
                 np.mean(distances)
                 + self.breakpoint_threshold_amount * np.std(distances),
-            )
+            ), distances
         elif self.breakpoint_threshold_type == "interquartile":
             q1, q3 = np.percentile(distances, [25, 75])
             iqr = q3 - q1
-            return np.mean(distances) + self.breakpoint_threshold_amount * iqr
+            return np.mean(
+                distances
+            ) + self.breakpoint_threshold_amount * iqr, distances
+        elif self.breakpoint_threshold_type == "gradient":
+            # Calculate the threshold based on the distribution of the gradient of the distance array.  # noqa: E501
+            distance_gradient = np.gradient(distances, range(0, len(distances)))
+            return cast(
+                float,
+                np.percentile(distance_gradient, self.breakpoint_threshold_amount),
+            ), distance_gradient
         else:
             raise ValueError(
                 f"Got unexpected `breakpoint_threshold_type`: "
```
```diff
@@ -201,13 +215,17 @@ class SemanticChunker(BaseDocumentTransformer):
         distances, sentences = self._calculate_sentence_distances(single_sentences_list)

         if self.number_of_chunks is not None:
             breakpoint_distance_threshold = self._threshold_from_clusters(distances)
+            breakpoint_array = distances
         else:
-            breakpoint_distance_threshold = self._calculate_breakpoint_threshold(
-                distances
-            )
+            (
+                breakpoint_distance_threshold,
+                breakpoint_array,
+            ) = self._calculate_breakpoint_threshold(distances)
         indices_above_thresh = [
-            i for i, x in enumerate(distances) if x > breakpoint_distance_threshold
+            i
+            for i, x in enumerate(breakpoint_array)
+            if x > breakpoint_distance_threshold
         ]

         chunks = []
```
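
End to end, the new option is used like any other `breakpoint_threshold_type` (a usage sketch assuming an OpenAI API key is configured and `state_of_the_union.txt` is on disk, as in the existing notebook):

```python
from langchain_experimental.text_splitter import SemanticChunker
from langchain_openai import OpenAIEmbeddings

# Sample document used throughout the semantic-chunker notebook.
with open("state_of_the_union.txt") as f:
    state_of_the_union = f.read()

# "gradient" reuses the percentile default, so with no explicit
# breakpoint_threshold_amount it thresholds at the 95th percentile
# of the distance gradient.
text_splitter = SemanticChunker(
    OpenAIEmbeddings(), breakpoint_threshold_type="gradient"
)

docs = text_splitter.create_documents([state_of_the_union])
print(len(docs))
print(docs[0].page_content)
```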
