forked from Archives/langchain
Harrison/text splitter (#5417)
Adds support for keeping separators in the output chunks when using the recursive text splitter.
This commit is contained in:
parent
cf5803e44c
commit
72f99ff953
@ -42,17 +42,17 @@
|
||||
" \n",
|
||||
"def foo():\n",
|
||||
"\n",
|
||||
"def testing_func():\n",
|
||||
"def testing_func_with_long_name():\n",
|
||||
"\n",
|
||||
"def bar():\n",
|
||||
"\"\"\"\n",
|
||||
"python_splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)"
|
||||
"python_splitter = PythonCodeTextSplitter(chunk_size=40, chunk_overlap=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "6cdc55f3",
|
||||
"id": "8cc33770",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@ -62,15 +62,16 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "8cc33770",
|
||||
"id": "f5f70775",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Document(page_content='Foo:\\n\\n def bar():', lookup_str='', metadata={}, lookup_index=0),\n",
|
||||
" Document(page_content='foo():\\n\\ndef testing_func():', lookup_str='', metadata={}, lookup_index=0),\n",
|
||||
" Document(page_content='bar():', lookup_str='', metadata={}, lookup_index=0)]"
|
||||
"[Document(page_content='class Foo:\\n\\n def bar():', metadata={}),\n",
|
||||
" Document(page_content='def foo():', metadata={}),\n",
|
||||
" Document(page_content='def testing_func_with_long_name():', metadata={}),\n",
|
||||
" Document(page_content='def bar():', metadata={})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 4,
|
||||
@ -82,33 +83,10 @@
|
||||
"docs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "de625e08-c440-489d-beed-020b6c53bf69",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['Foo:\\n\\n def bar():', 'foo():\\n\\ndef testing_func():', 'bar():']"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"python_splitter.split_text(python_text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "55aadd84-75ca-48ae-9b84-b39c368488ed",
|
||||
"id": "6e096d42",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
@ -130,7 +108,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.9.1"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
@ -27,6 +28,23 @@ logger = logging.getLogger(__name__)
|
||||
TS = TypeVar("TS", bound="TextSplitter")
|
||||
|
||||
|
||||
def _split_text(text: str, separator: str, keep_separator: bool) -> List[str]:
|
||||
# Now that we have the separator, split the text
|
||||
if separator:
|
||||
if keep_separator:
|
||||
# The parentheses in the pattern keep the delimiters in the result.
|
||||
_splits = re.split(f"({separator})", text)
|
||||
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
||||
if len(_splits) % 2 == 0:
|
||||
splits += _splits[-1:]
|
||||
splits = [_splits[0]] + splits
|
||||
else:
|
||||
splits = text.split(separator)
|
||||
else:
|
||||
splits = list(text)
|
||||
return [s for s in splits if s != ""]
|
||||
|
||||
|
||||
class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
"""Interface for splitting text into chunks."""
|
||||
|
||||
@ -35,8 +53,16 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
chunk_size: int = 4000,
|
||||
chunk_overlap: int = 200,
|
||||
length_function: Callable[[str], int] = len,
|
||||
keep_separator: bool = False,
|
||||
):
|
||||
"""Create a new TextSplitter."""
|
||||
"""Create a new TextSplitter.
|
||||
|
||||
Args:
|
||||
chunk_size: Maximum size of chunks to return
|
||||
chunk_overlap: Overlap in characters between chunks
|
||||
length_function: Function that measures the length of given chunks
|
||||
keep_separator: Whether or not to keep the separator in the chunks
|
||||
"""
|
||||
if chunk_overlap > chunk_size:
|
||||
raise ValueError(
|
||||
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
|
||||
@ -45,6 +71,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_overlap = chunk_overlap
|
||||
self._length_function = length_function
|
||||
self._keep_separator = keep_separator
|
||||
|
||||
@abstractmethod
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
@ -211,11 +238,9 @@ class CharacterTextSplitter(TextSplitter):
|
||||
def split_text(self, text: str) -> List[str]:
    """Split incoming text and return chunks.

    First naively splits the input on the configured separator, then
    merges the pieces back together into chunks of at most the
    configured chunk size.
    """
    # First we naively split the large input into a bunch of smaller ones.
    splits = _split_text(text, self._separator, self._keep_separator)
    # When the separator was kept inside the splits, merging must not
    # insert it again between pieces.
    _separator = "" if self._keep_separator else self._separator
    return self._merge_splits(splits, _separator)
|
||||
|
||||
|
||||
class TokenTextSplitter(TextSplitter):
|
||||
@ -274,45 +299,56 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
that works.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    separators: Optional[List[str]] = None,
    keep_separator: bool = True,
    **kwargs: Any,
):
    """Create a new TextSplitter.

    Args:
        separators: Ordered list of separators to try, from coarsest to
            finest; defaults to ["\\n\\n", "\\n", " ", ""].
        keep_separator: Whether or not to keep the separator in the chunks.
        **kwargs: Forwarded to the ``TextSplitter`` base class
            (chunk_size, chunk_overlap, length_function, ...).
    """
    super().__init__(keep_separator=keep_separator, **kwargs)
    self._separators = separators or ["\n\n", "\n", " ", ""]
|
||||
|
||||
def _split_text(self, text: str, separators: List[str]) -> List[str]:
    """Split incoming text and return chunks.

    Tries each separator in order; pieces that are still longer than the
    chunk size are split again recursively with the remaining (finer)
    separators.
    """
    final_chunks = []
    # Pick the first separator that appears in the text, falling back to
    # the last (finest) one.  new_separators holds what is left to
    # recurse with; None means there is nothing finer to try.
    separator = separators[-1]
    new_separators = None
    for i, _s in enumerate(separators):
        if _s == "":
            separator = _s
            break
        if _s in text:
            separator = _s
            new_separators = separators[i + 1 :]
            break

    splits = _split_text(text, separator, self._keep_separator)
    # Now go merging things, recursively splitting longer texts.
    _good_splits = []
    # If the separator was kept inside the splits, don't re-insert it
    # when merging.
    _separator = "" if self._keep_separator else separator
    for s in splits:
        if self._length_function(s) < self._chunk_size:
            _good_splits.append(s)
        else:
            # Flush the accumulated small pieces before handling the
            # oversized one, so ordering is preserved.
            if _good_splits:
                merged_text = self._merge_splits(_good_splits, _separator)
                final_chunks.extend(merged_text)
                _good_splits = []
            if new_separators is None:
                # No finer separators left: emit the oversized piece as-is.
                final_chunks.append(s)
            else:
                other_info = self._split_text(s, new_separators)
                final_chunks.extend(other_info)
    if _good_splits:
        merged_text = self._merge_splits(_good_splits, _separator)
        final_chunks.extend(merged_text)
    return final_chunks
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
return self._split_text(text, self._separators)
|
||||
|
||||
|
||||
class NLTKTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at sentences using NLTK."""
|
||||
|
@ -4,9 +4,23 @@ import pytest
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.text_splitter import (
|
||||
CharacterTextSplitter,
|
||||
PythonCodeTextSplitter,
|
||||
RecursiveCharacterTextSplitter,
|
||||
)
|
||||
|
||||
FAKE_PYTHON_TEXT = """
|
||||
class Foo:
|
||||
|
||||
def bar():
|
||||
|
||||
|
||||
def foo():
|
||||
|
||||
def testing_func():
|
||||
|
||||
def bar():
|
||||
"""
|
||||
|
||||
|
||||
def test_character_text_splitter() -> None:
|
||||
"""Test splitting by character count."""
|
||||
@ -135,15 +149,16 @@ Bye!\n\n-H."""
|
||||
"Okay then",
|
||||
"f f f f.",
|
||||
"This is a",
|
||||
"a weird",
|
||||
"weird",
|
||||
"text to",
|
||||
"write, but",
|
||||
"gotta test",
|
||||
"the",
|
||||
"splittingg",
|
||||
"ggg",
|
||||
"write,",
|
||||
"but gotta",
|
||||
"test the",
|
||||
"splitting",
|
||||
"gggg",
|
||||
"some how.",
|
||||
"Bye!\n\n-H.",
|
||||
"Bye!",
|
||||
"-H.",
|
||||
]
|
||||
assert output == expected_output
|
||||
|
||||
@ -168,3 +183,14 @@ def test_split_documents() -> None:
|
||||
Document(page_content="z", metadata={"source": "1"}),
|
||||
]
|
||||
assert splitter.split_documents(docs) == expected_output
|
||||
|
||||
|
||||
def test_python_text_splitter() -> None:
    """PythonCodeTextSplitter chunks FAKE_PYTHON_TEXT at def/class boundaries."""
    splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)
    expected_splits = [
        """class Foo:\n\n def bar():""",
        """def foo():""",
        """def testing_func():""",
        """def bar():""",
    ]
    assert splitter.split_text(FAKE_PYTHON_TEXT) == expected_splits
|
||||
|
Loading…
Reference in New Issue
Block a user