Harrison/text splitter (#5417)

Adds support for keeping the separators in the resulting chunks when
using the recursive text splitter.
This commit is contained in:
Harrison Chase 2023-05-29 16:56:31 -07:00 committed by GitHub
parent cf5803e44c
commit 72f99ff953
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 59 deletions

View File

@ -42,17 +42,17 @@
" \n", " \n",
"def foo():\n", "def foo():\n",
"\n", "\n",
"def testing_func():\n", "def testing_func_with_long_name():\n",
"\n", "\n",
"def bar():\n", "def bar():\n",
"\"\"\"\n", "\"\"\"\n",
"python_splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)" "python_splitter = PythonCodeTextSplitter(chunk_size=40, chunk_overlap=0)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"id": "6cdc55f3", "id": "8cc33770",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -62,15 +62,16 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 4,
"id": "8cc33770", "id": "f5f70775",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"[Document(page_content='Foo:\\n\\n def bar():', lookup_str='', metadata={}, lookup_index=0),\n", "[Document(page_content='class Foo:\\n\\n def bar():', metadata={}),\n",
" Document(page_content='foo():\\n\\ndef testing_func():', lookup_str='', metadata={}, lookup_index=0),\n", " Document(page_content='def foo():', metadata={}),\n",
" Document(page_content='bar():', lookup_str='', metadata={}, lookup_index=0)]" " Document(page_content='def testing_func_with_long_name():', metadata={}),\n",
" Document(page_content='def bar():', metadata={})]"
] ]
}, },
"execution_count": 4, "execution_count": 4,
@ -82,33 +83,10 @@
"docs" "docs"
] ]
}, },
{
"cell_type": "code",
"execution_count": 3,
"id": "de625e08-c440-489d-beed-020b6c53bf69",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"['Foo:\\n\\n def bar():', 'foo():\\n\\ndef testing_func():', 'bar():']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"python_splitter.split_text(python_text)"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "55aadd84-75ca-48ae-9b84-b39c368488ed", "id": "6e096d42",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [] "source": []
@ -130,7 +108,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.6" "version": "3.9.1"
}, },
"vscode": { "vscode": {
"interpreter": { "interpreter": {

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import copy import copy
import logging import logging
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import ( from typing import (
AbstractSet, AbstractSet,
@ -27,6 +28,23 @@ logger = logging.getLogger(__name__)
TS = TypeVar("TS", bound="TextSplitter") TS = TypeVar("TS", bound="TextSplitter")
def _split_text(text: str, separator: str, keep_separator: bool) -> List[str]:
# Now that we have the separator, split the text
if separator:
if keep_separator:
# The parentheses in the pattern keep the delimiters in the result.
_splits = re.split(f"({separator})", text)
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
if len(_splits) % 2 == 0:
splits += _splits[-1:]
splits = [_splits[0]] + splits
else:
splits = text.split(separator)
else:
splits = list(text)
return [s for s in splits if s != ""]
class TextSplitter(BaseDocumentTransformer, ABC): class TextSplitter(BaseDocumentTransformer, ABC):
"""Interface for splitting text into chunks.""" """Interface for splitting text into chunks."""
@ -35,8 +53,16 @@ class TextSplitter(BaseDocumentTransformer, ABC):
chunk_size: int = 4000, chunk_size: int = 4000,
chunk_overlap: int = 200, chunk_overlap: int = 200,
length_function: Callable[[str], int] = len, length_function: Callable[[str], int] = len,
keep_separator: bool = False,
): ):
"""Create a new TextSplitter.""" """Create a new TextSplitter.
Args:
chunk_size: Maximum size of chunks to return
chunk_overlap: Overlap in characters between chunks
length_function: Function that measures the length of given chunks
keep_separator: Whether or not to keep the separator in the chunks
"""
if chunk_overlap > chunk_size: if chunk_overlap > chunk_size:
raise ValueError( raise ValueError(
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
@ -45,6 +71,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
self._chunk_size = chunk_size self._chunk_size = chunk_size
self._chunk_overlap = chunk_overlap self._chunk_overlap = chunk_overlap
self._length_function = length_function self._length_function = length_function
self._keep_separator = keep_separator
@abstractmethod @abstractmethod
def split_text(self, text: str) -> List[str]: def split_text(self, text: str) -> List[str]:
@ -211,11 +238,9 @@ class CharacterTextSplitter(TextSplitter):
def split_text(self, text: str) -> List[str]: def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks.""" """Split incoming text and return chunks."""
# First we naively split the large input into a bunch of smaller ones. # First we naively split the large input into a bunch of smaller ones.
if self._separator: splits = _split_text(text, self._separator, self._keep_separator)
splits = text.split(self._separator) _separator = "" if self._keep_separator else self._separator
else: return self._merge_splits(splits, _separator)
splits = list(text)
return self._merge_splits(splits, self._separator)
class TokenTextSplitter(TextSplitter): class TokenTextSplitter(TextSplitter):
@ -274,45 +299,56 @@ class RecursiveCharacterTextSplitter(TextSplitter):
that works. that works.
""" """
def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any): def __init__(
self,
separators: Optional[List[str]] = None,
keep_separator: bool = True,
**kwargs: Any,
):
"""Create a new TextSplitter.""" """Create a new TextSplitter."""
super().__init__(**kwargs) super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or ["\n\n", "\n", " ", ""] self._separators = separators or ["\n\n", "\n", " ", ""]
def split_text(self, text: str) -> List[str]: def _split_text(self, text: str, separators: List[str]) -> List[str]:
"""Split incoming text and return chunks.""" """Split incoming text and return chunks."""
final_chunks = [] final_chunks = []
# Get appropriate separator to use # Get appropriate separator to use
separator = self._separators[-1] separator = separators[-1]
for _s in self._separators: new_separators = None
for i, _s in enumerate(separators):
if _s == "": if _s == "":
separator = _s separator = _s
break break
if _s in text: if _s in text:
separator = _s separator = _s
new_separators = separators[i + 1 :]
break break
# Now that we have the separator, split the text
if separator: splits = _split_text(text, separator, self._keep_separator)
splits = text.split(separator)
else:
splits = list(text)
# Now go merging things, recursively splitting longer texts. # Now go merging things, recursively splitting longer texts.
_good_splits = [] _good_splits = []
_separator = "" if self._keep_separator else separator
for s in splits: for s in splits:
if self._length_function(s) < self._chunk_size: if self._length_function(s) < self._chunk_size:
_good_splits.append(s) _good_splits.append(s)
else: else:
if _good_splits: if _good_splits:
merged_text = self._merge_splits(_good_splits, separator) merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text) final_chunks.extend(merged_text)
_good_splits = [] _good_splits = []
other_info = self.split_text(s) if new_separators is None:
final_chunks.extend(other_info) final_chunks.append(s)
else:
other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info)
if _good_splits: if _good_splits:
merged_text = self._merge_splits(_good_splits, separator) merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text) final_chunks.extend(merged_text)
return final_chunks return final_chunks
def split_text(self, text: str) -> List[str]:
# Public entry point: delegate to _split_text, starting with the full list of
# separators stored on the instance (self._separators).
return self._split_text(text, self._separators)
class NLTKTextSplitter(TextSplitter): class NLTKTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at sentences using NLTK.""" """Implementation of splitting text that looks at sentences using NLTK."""

View File

@ -4,9 +4,23 @@ import pytest
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.text_splitter import ( from langchain.text_splitter import (
CharacterTextSplitter, CharacterTextSplitter,
PythonCodeTextSplitter,
RecursiveCharacterTextSplitter, RecursiveCharacterTextSplitter,
) )
FAKE_PYTHON_TEXT = """
class Foo:
def bar():
def foo():
def testing_func():
def bar():
"""
def test_character_text_splitter() -> None: def test_character_text_splitter() -> None:
"""Test splitting by character count.""" """Test splitting by character count."""
@ -135,15 +149,16 @@ Bye!\n\n-H."""
"Okay then", "Okay then",
"f f f f.", "f f f f.",
"This is a", "This is a",
"a weird", "weird",
"text to", "text to",
"write, but", "write,",
"gotta test", "but gotta",
"the", "test the",
"splittingg", "splitting",
"ggg", "gggg",
"some how.", "some how.",
"Bye!\n\n-H.", "Bye!",
"-H.",
] ]
assert output == expected_output assert output == expected_output
@ -168,3 +183,14 @@ def test_split_documents() -> None:
Document(page_content="z", metadata={"source": "1"}), Document(page_content="z", metadata={"source": "1"}),
] ]
assert splitter.split_documents(docs) == expected_output assert splitter.split_documents(docs) == expected_output
# Splits FAKE_PYTHON_TEXT with PythonCodeTextSplitter (chunk_size=30, no overlap)
# and checks the exact list of chunks returned by split_text.
def test_python_text_splitter() -> None:
splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)
splits = splitter.split_text(FAKE_PYTHON_TEXT)
# Each expected chunk still begins with its "class"/"def" keyword.
split_0 = """class Foo:\n\n def bar():"""
split_1 = """def foo():"""
split_2 = """def testing_func():"""
split_3 = """def bar():"""
expected_splits = [split_0, split_1, split_2, split_3]
assert splits == expected_splits