Harrison/text splitter (#5417)

adds support for keeping separators around when using recursive text
splitter
This commit is contained in:
Harrison Chase 2023-05-29 16:56:31 -07:00 committed by GitHub
parent cf5803e44c
commit 72f99ff953
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 59 deletions

View File

@ -42,17 +42,17 @@
" \n",
"def foo():\n",
"\n",
"def testing_func():\n",
"def testing_func_with_long_name():\n",
"\n",
"def bar():\n",
"\"\"\"\n",
"python_splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)"
"python_splitter = PythonCodeTextSplitter(chunk_size=40, chunk_overlap=0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6cdc55f3",
"id": "8cc33770",
"metadata": {},
"outputs": [],
"source": [
@ -62,15 +62,16 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "8cc33770",
"id": "f5f70775",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='Foo:\\n\\n def bar():', lookup_str='', metadata={}, lookup_index=0),\n",
" Document(page_content='foo():\\n\\ndef testing_func():', lookup_str='', metadata={}, lookup_index=0),\n",
" Document(page_content='bar():', lookup_str='', metadata={}, lookup_index=0)]"
"[Document(page_content='class Foo:\\n\\n def bar():', metadata={}),\n",
" Document(page_content='def foo():', metadata={}),\n",
" Document(page_content='def testing_func_with_long_name():', metadata={}),\n",
" Document(page_content='def bar():', metadata={})]"
]
},
"execution_count": 4,
@ -82,33 +83,10 @@
"docs"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "de625e08-c440-489d-beed-020b6c53bf69",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"['Foo:\\n\\n def bar():', 'foo():\\n\\ndef testing_func():', 'bar():']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"python_splitter.split_text(python_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55aadd84-75ca-48ae-9b84-b39c368488ed",
"id": "6e096d42",
"metadata": {},
"outputs": [],
"source": []
@ -130,7 +108,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.9.1"
},
"vscode": {
"interpreter": {

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import copy
import logging
import re
from abc import ABC, abstractmethod
from typing import (
AbstractSet,
@ -27,6 +28,23 @@ logger = logging.getLogger(__name__)
TS = TypeVar("TS", bound="TextSplitter")
def _split_text(text: str, separator: str, keep_separator: bool) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
else:
splits = text.split(separator)
else:
splits = list(text)
return [s for s in splits if s != ""]
class TextSplitter(BaseDocumentTransformer, ABC):
"""Interface for splitting text into chunks."""
@ -35,8 +53,16 @@ class TextSplitter(BaseDocumentTransformer, ABC):
chunk_size: int = 4000,
chunk_overlap: int = 200,
length_function: Callable[[str], int] = len,
keep_separator: bool = False,
):
"""Create a new TextSplitter."""
"""Create a new TextSplitter.
Args:
chunk_size: Maximum size of chunks to return
chunk_overlap: Overlap in characters between chunks
length_function: Function that measures the length of given chunks
keep_separator: Whether or not to keep the separator in the chunks
"""
if chunk_overlap > chunk_size:
raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
@ -45,6 +71,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap
self._length_function = length_function
self._keep_separator = keep_separator
@abstractmethod
def split_text(self, text: str) -> List[str]:
@ -211,11 +238,9 @@ class CharacterTextSplitter(TextSplitter):
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
# First we naively split the large input into a bunch of smaller ones.
if self._separator:
splits = text.split(self._separator)
else:
splits = list(text)
return self._merge_splits(splits, self._separator)
splits = _split_text(text, self._separator, self._keep_separator)
_separator = "" if self._keep_separator else self._separator
return self._merge_splits(splits, _separator)
class TokenTextSplitter(TextSplitter):
@ -274,45 +299,56 @@ class RecursiveCharacterTextSplitter(TextSplitter):
that works.
"""
def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
def __init__(
self,
separators: Optional[List[str]] = None,
keep_separator: bool = True,
**kwargs: Any,
):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or ["\n\n", "\n", " ", ""]
def split_text(self, text: str) -> List[str]:
def _split_text(self, text: str, separators: List[str]) -> List[str]:
"""Split incoming text and return chunks."""
final_chunks = []
# Get appropriate separator to use
separator = self._separators[-1]
for _s in self._separators:
separator = separators[-1]
new_separators = None
for i, _s in enumerate(separators):
if _s == "":
separator = _s
break
if _s in text:
separator = _s
new_separators = separators[i + 1 :]
break
# Now that we have the separator, split the text
if separator:
splits = text.split(separator)
else:
splits = list(text)
splits = _split_text(text, separator, self._keep_separator)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
_separator = "" if self._keep_separator else separator
for s in splits:
if self._length_function(s) < self._chunk_size:
_good_splits.append(s)
else:
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
_good_splits = []
other_info = self.split_text(s)
if new_separators is None:
final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text)
return final_chunks
def split_text(self, text: str) -> List[str]:
return self._split_text(text, self._separators)
class NLTKTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at sentences using NLTK."""

View File

@ -4,9 +4,23 @@ import pytest
from langchain.docstore.document import Document
from langchain.text_splitter import (
CharacterTextSplitter,
PythonCodeTextSplitter,
RecursiveCharacterTextSplitter,
)
FAKE_PYTHON_TEXT = """
class Foo:
def bar():
def foo():
def testing_func():
def bar():
"""
def test_character_text_splitter() -> None:
"""Test splitting by character count."""
@ -135,15 +149,16 @@ Bye!\n\n-H."""
"Okay then",
"f f f f.",
"This is a",
"a weird",
"weird",
"text to",
"write, but",
"gotta test",
"the",
"splittingg",
"ggg",
"write,",
"but gotta",
"test the",
"splitting",
"gggg",
"some how.",
"Bye!\n\n-H.",
"Bye!",
"-H.",
]
assert output == expected_output
@ -168,3 +183,14 @@ def test_split_documents() -> None:
Document(page_content="z", metadata={"source": "1"}),
]
assert splitter.split_documents(docs) == expected_output
def test_python_text_splitter() -> None:
splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)
splits = splitter.split_text(FAKE_PYTHON_TEXT)
split_0 = """class Foo:\n\n def bar():"""
split_1 = """def foo():"""
split_2 = """def testing_func():"""
split_3 = """def bar():"""
expected_splits = [split_0, split_1, split_2, split_3]
assert splits == expected_splits