smart text splitter (#530)

smart text splitter that iteratively tries different separators until it
works!
harrison/pinecone-try-except
Harrison Chase 1 year ago committed by GitHub
parent 8dfad874a2
commit 1192cc0767
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -90,6 +90,61 @@
"print(texts[0])"
]
},
{
"cell_type": "markdown",
"id": "1be00b73",
"metadata": {},
"source": [
"## Recursive Character Text Splitting\n",
"Sometimes, it's not enough to split on just one character. This text splitter uses a whole list of characters and recursive splits them down until they are under the limit."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1ac6376d",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import RecursiveCharacterTextSplitter"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6787b13b",
"metadata": {},
"outputs": [],
"source": [
"text_splitter = RecursiveCharacterTextSplitter(\n",
" # Set a really small chunk size, just to show.\n",
" chunk_size = 100,\n",
" chunk_overlap = 20,\n",
" length_function = len,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4f0e7d9b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet.\n",
"and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n"
]
}
],
"source": [
"texts = text_splitter.split_text(state_of_the_union)\n",
"print(texts[0])\n",
"print(texts[1])"
]
},
{
"cell_type": "markdown",
"id": "87a71115",

@ -15,7 +15,6 @@ class TextSplitter(ABC):
def __init__(
self,
separator: str = "\n\n",
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: Callable[[str], int] = len,
@ -26,7 +25,6 @@ class TextSplitter(ABC):
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
f"({chunk_size}), should be smaller."
)
self._separator = separator
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
@ -46,7 +44,7 @@ class TextSplitter(ABC):
documents.append(Document(page_content=chunk, metadata=_metadatas[i]))
return documents
def _merge_splits(self, splits: Iterable[str]) -> List[str]:
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
# We now want to combine these smaller pieces into medium size
# chunks to send to the LLM.
docs = []
@ -61,13 +59,18 @@ class TextSplitter(ABC):
f"which is longer than the specified {self._chunk_size}"
)
if len(current_doc) > 0:
docs.append(self._separator.join(current_doc))
while total > self._chunk_overlap:
docs.append(separator.join(current_doc))
# Keep on popping if:
# - we have a larger chunk than in the chunk overlap
# - or if we still have any chunks and the length is long
while total > self._chunk_overlap or (
total + _len > self._chunk_size and total > 0
):
total -= self._length_function(current_doc[0])
current_doc = current_doc[1:]
current_doc.append(d)
total += _len
docs.append(self._separator.join(current_doc))
docs.append(separator.join(current_doc))
return docs
@classmethod
@ -116,21 +119,74 @@ class TextSplitter(ABC):
class CharacterTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at characters."""
def __init__(self, separator: str = "\n\n", **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._separator = separator
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
# First we naively split the large input into a bunch of smaller ones.
splits = text.split(self._separator)
return self._merge_splits(splits)
if self._separator:
splits = text.split(self._separator)
else:
splits = list(text)
return self._merge_splits(splits, self._separator)
class RecursiveCharacterTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at characters.
Recursively tries to split by different characters to find one
that works.
"""
def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._separators = separators or ["\n\n", "\n", " ", ""]
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = self._separators[-1]
for _s in self._separators:
if _s == "":
separator = _s
break
if _s in text:
separator = _s
break
# Now that we have the separator, split the text
if separator:
splits = text.split(separator)
else:
splits = list(text)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
for s in splits:
if len(s) < self._chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
final_chunks.extend(merged_text)
_good_splits = []
other_info = self.split_text(s)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
final_chunks.extend(merged_text)
return final_chunks
class NLTKTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at sentences using NLTK."""
def __init__(
self, separator: str = "\n\n", chunk_size: int = 4000, chunk_overlap: int = 200
):
def __init__(self, separator: str = "\n\n", **kwargs: Any):
"""Initialize the NLTK splitter."""
super(NLTKTextSplitter, self).__init__(separator, chunk_size, chunk_overlap)
super().__init__(**kwargs)
try:
from nltk.tokenize import sent_tokenize
@ -139,26 +195,23 @@ class NLTKTextSplitter(TextSplitter):
raise ImportError(
"NLTK is not installed, please install it with `pip install nltk`."
)
self._separator = separator
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
# First we naively split the large input into a bunch of smaller ones.
splits = self._tokenizer(text)
return self._merge_splits(splits)
return self._merge_splits(splits, self._separator)
class SpacyTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at sentences using Spacy."""
def __init__(
self,
separator: str = "\n\n",
pipeline: str = "en_core_web_sm",
chunk_size: int = 4000,
chunk_overlap: int = 200,
self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
):
"""Initialize the spacy text splitter."""
super(SpacyTextSplitter, self).__init__(separator, chunk_size, chunk_overlap)
super.__init__(**kwargs)
try:
import spacy
except ImportError:
@ -166,8 +219,9 @@ class SpacyTextSplitter(TextSplitter):
"Spacy is not installed, please install it with `pip install spacy`."
)
self._tokenizer = spacy.load(pipeline)
self._separator = separator
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
splits = (str(s) for s in self._tokenizer(text).sents)
return self._merge_splits(splits)
return self._merge_splits(splits, self._separator)

@ -2,7 +2,10 @@
import pytest
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain.text_splitter import (
CharacterTextSplitter,
RecursiveCharacterTextSplitter,
)
def test_character_text_splitter() -> None:
@ -23,6 +26,15 @@ def test_character_text_splitter_long() -> None:
assert output == expected_output
def test_character_text_splitter_short_words_first() -> None:
"""Test splitting by character count when shorter words are first."""
text = "a a foo bar baz"
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
output = splitter.split_text(text)
expected_output = ["a a", "foo", "bar", "baz"]
assert output == expected_output
def test_character_text_splitter_longer_words() -> None:
"""Test splitting by characters when splits not found easily."""
text = "foo bar baz 123"
@ -62,3 +74,33 @@ def test_create_documents_with_metadata() -> None:
Document(page_content="baz", metadata={"source": "2"}),
]
assert docs == expected_docs
def test_iterative_text_splitter() -> None:
"""Test iterative text splitter."""
text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.
This is a weird text to write, but gotta test the splittingggg some how.
Bye!\n\n-H."""
splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=1)
output = splitter.split_text(text)
expected_output = [
"Hi.",
"I'm",
"Harrison.",
"How? Are?",
"You?",
"Okay then f",
"f f f f.",
"This is a",
"a weird",
"text to",
"write, but",
"gotta test",
"the",
"splitting",
"gggg",
"some how.",
"Bye!\n\n-H.",
]
assert output == expected_output

Loading…
Cancel
Save