|
|
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import copy
|
|
|
|
|
import logging
|
|
|
|
|
import re
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from typing import (
|
|
|
|
|
AbstractSet,
|
|
|
|
@@ -27,6 +28,23 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
TS = TypeVar("TS", bound="TextSplitter")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _split_text(text: str, separator: str, keep_separator: bool) -> List[str]:
|
|
|
|
|
# Now that we have the separator, split the text
|
|
|
|
|
if separator:
|
|
|
|
|
if keep_separator:
|
|
|
|
|
# The parentheses in the pattern keep the delimiters in the result.
|
|
|
|
|
_splits = re.split(f"({separator})", text)
|
|
|
|
|
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
|
|
|
|
if len(_splits) % 2 == 0:
|
|
|
|
|
splits += _splits[-1:]
|
|
|
|
|
splits = [_splits[0]] + splits
|
|
|
|
|
else:
|
|
|
|
|
splits = text.split(separator)
|
|
|
|
|
else:
|
|
|
|
|
splits = list(text)
|
|
|
|
|
return [s for s in splits if s != ""]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextSplitter(BaseDocumentTransformer, ABC):
|
|
|
|
|
"""Interface for splitting text into chunks."""
|
|
|
|
|
|
|
|
|
@@ -35,8 +53,16 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|
|
|
|
chunk_size: int = 4000,
|
|
|
|
|
chunk_overlap: int = 200,
|
|
|
|
|
length_function: Callable[[str], int] = len,
|
|
|
|
|
keep_separator: bool = False,
|
|
|
|
|
):
|
|
|
|
|
"""Create a new TextSplitter."""
|
|
|
|
|
"""Create a new TextSplitter.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
chunk_size: Maximum size of chunks to return
|
|
|
|
|
chunk_overlap: Overlap in characters between chunks
|
|
|
|
|
length_function: Function that measures the length of given chunks
|
|
|
|
|
keep_separator: Whether or not to keep the separator in the chunks
|
|
|
|
|
"""
|
|
|
|
|
if chunk_overlap > chunk_size:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
|
|
|
|
@@ -45,6 +71,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|
|
|
|
self._chunk_size = chunk_size
|
|
|
|
|
self._chunk_overlap = chunk_overlap
|
|
|
|
|
self._length_function = length_function
|
|
|
|
|
self._keep_separator = keep_separator
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def split_text(self, text: str) -> List[str]:
|
|
|
|
@@ -211,11 +238,9 @@ class CharacterTextSplitter(TextSplitter):
|
|
|
|
|
def split_text(self, text: str) -> List[str]:
    """Split incoming text and return chunks.

    Delegates the raw splitting to the module-level ``_split_text`` helper
    (which honors ``self._keep_separator``), then merges the pieces back into
    chunks bounded by the configured chunk size.
    """
    # First we naively split the large input into a bunch of smaller ones.
    splits = _split_text(text, self._separator, self._keep_separator)
    # When the separator was kept inside the splits, merge with the empty
    # string so the separator is not duplicated between pieces.
    _separator = "" if self._keep_separator else self._separator
    return self._merge_splits(splits, _separator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenTextSplitter(TextSplitter):
|
|
|
|
@@ -274,45 +299,56 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
|
|
|
|
that works.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
    self,
    separators: Optional[List[str]] = None,
    keep_separator: bool = True,
    **kwargs: Any,
):
    """Create a new TextSplitter.

    Args:
        separators: Ordered list of separators to try when splitting;
            defaults to ["\n\n", "\n", " ", ""] (paragraphs first, then
            lines, words, and finally individual characters).
        keep_separator: Whether to keep the separator in the resulting
            chunks; defaults to True for this splitter.
        **kwargs: Forwarded to the base ``TextSplitter`` constructor.
    """
    super().__init__(keep_separator=keep_separator, **kwargs)
    self._separators = separators or ["\n\n", "\n", " ", ""]
|
|
|
|
|
|
|
|
|
|
def _split_text(self, text: str, separators: List[str]) -> List[str]:
    """Split incoming text and return chunks.

    Tries each separator in ``separators`` in order, splits on the first one
    found in ``text``, and recursively re-splits any piece that is still
    larger than the chunk size using the remaining separators.

    Args:
        text: The text to split.
        separators: Ordered candidate separators; ``""`` means split into
            individual characters and is terminal.

    Returns:
        The list of merged chunks.
    """
    final_chunks = []
    # Get appropriate separator to use: the first candidate that actually
    # appears in the text (or "" which always applies). Remember the
    # remaining candidates for recursive re-splitting of oversized pieces.
    separator = separators[-1]
    new_separators = None
    for i, _s in enumerate(separators):
        if _s == "":
            separator = _s
            break
        if _s in text:
            separator = _s
            new_separators = separators[i + 1 :]
            break

    splits = _split_text(text, separator, self._keep_separator)

    # Now go merging things, recursively splitting longer texts.
    _good_splits = []
    # If the separator was kept inside the splits, merge with "" so it is
    # not duplicated between pieces.
    _separator = "" if self._keep_separator else separator
    for s in splits:
        if self._length_function(s) < self._chunk_size:
            _good_splits.append(s)
        else:
            # Flush accumulated small splits before handling the big one.
            if _good_splits:
                merged_text = self._merge_splits(_good_splits, _separator)
                final_chunks.extend(merged_text)
                _good_splits = []
            if new_separators is None:
                # No finer separators left; emit the oversized piece as-is.
                final_chunks.append(s)
            else:
                other_info = self._split_text(s, new_separators)
                final_chunks.extend(other_info)
    if _good_splits:
        merged_text = self._merge_splits(_good_splits, _separator)
        final_chunks.extend(merged_text)
    return final_chunks
|
|
|
|
|
|
|
|
|
|
def split_text(self, text: str) -> List[str]:
    """Split incoming text and return chunks, trying all configured separators."""
    return self._split_text(text, self._separators)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NLTKTextSplitter(TextSplitter):
|
|
|
|
|
"""Implementation of splitting text that looks at sentences using NLTK."""
|
|
|
|
|