@ -1,6 +1,6 @@
import copy
import re
from typing import Any , Iterable, List , Optional, Sequence , Tuple
from typing import Any , Dict, Iterable, List , Literal, Optional, Sequence , Tuple , cast
import numpy as np
from langchain_community . utils . math import (
@ -83,6 +83,14 @@ def calculate_cosine_distances(sentences: List[dict]) -> Tuple[List[float], List
return distances , sentences
# Supported strategies for picking the breakpoint distance threshold.
BreakpointThresholdType = Literal["percentile", "standard_deviation", "interquartile"]

# Default `breakpoint_threshold_amount` per strategy: a percentile for
# "percentile", a multiplier on the standard deviation for
# "standard_deviation", and an IQR multiplier for "interquartile".
BREAKPOINT_DEFAULTS: Dict[BreakpointThresholdType, float] = {
    "percentile": 95,
    "standard_deviation": 3,
    "interquartile": 1.5,
}
class SemanticChunker ( BaseDocumentTransformer ) :
""" Split the text based on semantic similarity.
@ -95,42 +103,110 @@ class SemanticChunker(BaseDocumentTransformer):
sentences , and then merges one that are similar in the embedding space .
"""
def __init__ ( self , embeddings : Embeddings , add_start_index : bool = False ) :
def __init__ (
self ,
embeddings : Embeddings ,
add_start_index : bool = False ,
breakpoint_threshold_type : BreakpointThresholdType = " percentile " ,
breakpoint_threshold_amount : Optional [ float ] = None ,
number_of_chunks : Optional [ int ] = None ,
) :
self . _add_start_index = add_start_index
self . embeddings = embeddings
def split_text ( self , text : str ) - > List [ str ] :
self . breakpoint_threshold_type = breakpoint_threshold_type
self . number_of_chunks = number_of_chunks
if breakpoint_threshold_amount is None :
self . breakpoint_threshold_amount = BREAKPOINT_DEFAULTS [
breakpoint_threshold_type
]
else :
self . breakpoint_threshold_amount = breakpoint_threshold_amount
def _calculate_breakpoint_threshold ( self , distances : List [ float ] ) - > float :
if self . breakpoint_threshold_type == " percentile " :
return cast (
float ,
np . percentile ( distances , self . breakpoint_threshold_amount ) ,
)
elif self . breakpoint_threshold_type == " standard_deviation " :
return cast (
float ,
np . mean ( distances )
+ self . breakpoint_threshold_amount * np . std ( distances ) ,
)
elif self . breakpoint_threshold_type == " interquartile " :
q1 , q3 = np . percentile ( distances , [ 25 , 75 ] )
iqr = q3 - q1
return np . mean ( distances ) + self . breakpoint_threshold_amount * iqr
else :
raise ValueError (
f " Got unexpected `breakpoint_threshold_type`: "
f " { self . breakpoint_threshold_type } "
)
def _threshold_from_clusters ( self , distances : List [ float ] ) - > float :
"""
Calculate the threshold based on the number of chunks .
Inverse of percentile method .
"""
if self . number_of_chunks is None :
raise ValueError (
" This should never be called if `number_of_chunks` is None. "
)
x1 , y1 = len ( distances ) , 0.0
x2 , y2 = 1.0 , 100.0
x = max ( min ( self . number_of_chunks , x1 ) , x2 )
# Linear interpolation formula
y = y1 + ( ( y2 - y1 ) / ( x2 - x1 ) ) * ( x - x1 )
y = min ( max ( y , 0 ) , 100 )
return cast ( float , np . percentile ( distances , y ) )
def _calculate_sentence_distances (
self , single_sentences_list : List [ str ]
) - > Tuple [ List [ float ] , List [ dict ] ] :
""" Split text into multiple components. """
# Splitting the essay on '.', '?', and '!'
single_sentences_list = re . split ( r " (?<=[.?!]) \ s+ " , text )
# having len(single_sentences_list) == 1 would cause the following
# np.percentile to fail.
if len ( single_sentences_list ) == 1 :
return single_sentences_list
sentences = [
_sentences = [
{ " sentence " : x , " index " : i } for i , x in enumerate ( single_sentences_list )
]
sentences = combine_sentences ( sentences )
sentences = combine_sentences ( _sentences )
embeddings = self . embeddings . embed_documents (
[ x [ " combined_sentence " ] for x in sentences ]
)
for i , sentence in enumerate ( sentences ) :
sentence [ " combined_sentence_embedding " ] = embeddings [ i ]
distances , sentences = calculate_cosine_distances ( sentences )
start_index = 0
# Create a list to hold the grouped sentences
chunks = [ ]
breakpoint_percentile_threshold = 95
breakpoint_distance_threshold = np . percentile (
distances , breakpoint_percentile_threshold
) # If you want more chunks, lower the percentile cutoff
return calculate_cosine_distances ( sentences )
def split_text (
self ,
text : str ,
) - > List [ str ] :
# Splitting the essay on '.', '?', and '!'
single_sentences_list = re . split ( r " (?<=[.?!]) \ s+ " , text )
# having len(single_sentences_list) == 1 would cause the following
# np.percentile to fail.
if len ( single_sentences_list ) == 1 :
return single_sentences_list
distances , sentences = self . _calculate_sentence_distances ( single_sentences_list )
if self . number_of_chunks is not None :
breakpoint_distance_threshold = self . _threshold_from_clusters ( distances )
else :
breakpoint_distance_threshold = self . _calculate_breakpoint_threshold (
distances
)
indices_above_thresh = [
i for i , x in enumerate ( distances ) if x > breakpoint_distance_threshold
] # The indices of those breakpoints on your list
]
chunks = [ ]
start_index = 0
# Iterate through the breakpoints to slice the sentences
for index in indices_above_thresh :