forked from Archives/langchain
Harrison/text splitter (#5417)
adds support for keeping separators around when using recursive text splitter
This commit is contained in:
parent
cf5803e44c
commit
72f99ff953
@ -42,17 +42,17 @@
|
|||||||
" \n",
|
" \n",
|
||||||
"def foo():\n",
|
"def foo():\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def testing_func():\n",
|
"def testing_func_with_long_name():\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def bar():\n",
|
"def bar():\n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
"python_splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)"
|
"python_splitter = PythonCodeTextSplitter(chunk_size=40, chunk_overlap=0)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
"id": "6cdc55f3",
|
"id": "8cc33770",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
@ -62,15 +62,16 @@
|
|||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 4,
|
||||||
"id": "8cc33770",
|
"id": "f5f70775",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"[Document(page_content='Foo:\\n\\n def bar():', lookup_str='', metadata={}, lookup_index=0),\n",
|
"[Document(page_content='class Foo:\\n\\n def bar():', metadata={}),\n",
|
||||||
" Document(page_content='foo():\\n\\ndef testing_func():', lookup_str='', metadata={}, lookup_index=0),\n",
|
" Document(page_content='def foo():', metadata={}),\n",
|
||||||
" Document(page_content='bar():', lookup_str='', metadata={}, lookup_index=0)]"
|
" Document(page_content='def testing_func_with_long_name():', metadata={}),\n",
|
||||||
|
" Document(page_content='def bar():', metadata={})]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 4,
|
"execution_count": 4,
|
||||||
@ -82,33 +83,10 @@
|
|||||||
"docs"
|
"docs"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"id": "de625e08-c440-489d-beed-020b6c53bf69",
|
|
||||||
"metadata": {
|
|
||||||
"tags": []
|
|
||||||
},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"text/plain": [
|
|
||||||
"['Foo:\\n\\n def bar():', 'foo():\\n\\ndef testing_func():', 'bar():']"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"execution_count": 3,
|
|
||||||
"metadata": {},
|
|
||||||
"output_type": "execute_result"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"python_splitter.split_text(python_text)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "55aadd84-75ca-48ae-9b84-b39c368488ed",
|
"id": "6e096d42",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
@ -130,7 +108,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.6"
|
"version": "3.9.1"
|
||||||
},
|
},
|
||||||
"vscode": {
|
"vscode": {
|
||||||
"interpreter": {
|
"interpreter": {
|
||||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import (
|
from typing import (
|
||||||
AbstractSet,
|
AbstractSet,
|
||||||
@ -27,6 +28,23 @@ logger = logging.getLogger(__name__)
|
|||||||
TS = TypeVar("TS", bound="TextSplitter")
|
TS = TypeVar("TS", bound="TextSplitter")
|
||||||
|
|
||||||
|
|
||||||
|
def _split_text(text: str, separator: str, keep_separator: bool) -> List[str]:
|
||||||
|
# Now that we have the separator, split the text
|
||||||
|
if separator:
|
||||||
|
if keep_separator:
|
||||||
|
# The parentheses in the pattern keep the delimiters in the result.
|
||||||
|
_splits = re.split(f"({separator})", text)
|
||||||
|
splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
|
||||||
|
if len(_splits) % 2 == 0:
|
||||||
|
splits += _splits[-1:]
|
||||||
|
splits = [_splits[0]] + splits
|
||||||
|
else:
|
||||||
|
splits = text.split(separator)
|
||||||
|
else:
|
||||||
|
splits = list(text)
|
||||||
|
return [s for s in splits if s != ""]
|
||||||
|
|
||||||
|
|
||||||
class TextSplitter(BaseDocumentTransformer, ABC):
|
class TextSplitter(BaseDocumentTransformer, ABC):
|
||||||
"""Interface for splitting text into chunks."""
|
"""Interface for splitting text into chunks."""
|
||||||
|
|
||||||
@ -35,8 +53,16 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|||||||
chunk_size: int = 4000,
|
chunk_size: int = 4000,
|
||||||
chunk_overlap: int = 200,
|
chunk_overlap: int = 200,
|
||||||
length_function: Callable[[str], int] = len,
|
length_function: Callable[[str], int] = len,
|
||||||
|
keep_separator: bool = False,
|
||||||
):
|
):
|
||||||
"""Create a new TextSplitter."""
|
"""Create a new TextSplitter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_size: Maximum size of chunks to return
|
||||||
|
chunk_overlap: Overlap in characters between chunks
|
||||||
|
length_function: Function that measures the length of given chunks
|
||||||
|
keep_separator: Whether or not to keep the separator in the chunks
|
||||||
|
"""
|
||||||
if chunk_overlap > chunk_size:
|
if chunk_overlap > chunk_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
|
f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
|
||||||
@ -45,6 +71,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|||||||
self._chunk_size = chunk_size
|
self._chunk_size = chunk_size
|
||||||
self._chunk_overlap = chunk_overlap
|
self._chunk_overlap = chunk_overlap
|
||||||
self._length_function = length_function
|
self._length_function = length_function
|
||||||
|
self._keep_separator = keep_separator
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def split_text(self, text: str) -> List[str]:
|
def split_text(self, text: str) -> List[str]:
|
||||||
@ -211,11 +238,9 @@ class CharacterTextSplitter(TextSplitter):
|
|||||||
def split_text(self, text: str) -> List[str]:
|
def split_text(self, text: str) -> List[str]:
|
||||||
"""Split incoming text and return chunks."""
|
"""Split incoming text and return chunks."""
|
||||||
# First we naively split the large input into a bunch of smaller ones.
|
# First we naively split the large input into a bunch of smaller ones.
|
||||||
if self._separator:
|
splits = _split_text(text, self._separator, self._keep_separator)
|
||||||
splits = text.split(self._separator)
|
_separator = "" if self._keep_separator else self._separator
|
||||||
else:
|
return self._merge_splits(splits, _separator)
|
||||||
splits = list(text)
|
|
||||||
return self._merge_splits(splits, self._separator)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenTextSplitter(TextSplitter):
|
class TokenTextSplitter(TextSplitter):
|
||||||
@ -274,45 +299,56 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
|||||||
that works.
|
that works.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
|
def __init__(
|
||||||
|
self,
|
||||||
|
separators: Optional[List[str]] = None,
|
||||||
|
keep_separator: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
"""Create a new TextSplitter."""
|
"""Create a new TextSplitter."""
|
||||||
super().__init__(**kwargs)
|
super().__init__(keep_separator=keep_separator, **kwargs)
|
||||||
self._separators = separators or ["\n\n", "\n", " ", ""]
|
self._separators = separators or ["\n\n", "\n", " ", ""]
|
||||||
|
|
||||||
def split_text(self, text: str) -> List[str]:
|
def _split_text(self, text: str, separators: List[str]) -> List[str]:
|
||||||
"""Split incoming text and return chunks."""
|
"""Split incoming text and return chunks."""
|
||||||
final_chunks = []
|
final_chunks = []
|
||||||
# Get appropriate separator to use
|
# Get appropriate separator to use
|
||||||
separator = self._separators[-1]
|
separator = separators[-1]
|
||||||
for _s in self._separators:
|
new_separators = None
|
||||||
|
for i, _s in enumerate(separators):
|
||||||
if _s == "":
|
if _s == "":
|
||||||
separator = _s
|
separator = _s
|
||||||
break
|
break
|
||||||
if _s in text:
|
if _s in text:
|
||||||
separator = _s
|
separator = _s
|
||||||
|
new_separators = separators[i + 1 :]
|
||||||
break
|
break
|
||||||
# Now that we have the separator, split the text
|
|
||||||
if separator:
|
splits = _split_text(text, separator, self._keep_separator)
|
||||||
splits = text.split(separator)
|
|
||||||
else:
|
|
||||||
splits = list(text)
|
|
||||||
# Now go merging things, recursively splitting longer texts.
|
# Now go merging things, recursively splitting longer texts.
|
||||||
_good_splits = []
|
_good_splits = []
|
||||||
|
_separator = "" if self._keep_separator else separator
|
||||||
for s in splits:
|
for s in splits:
|
||||||
if self._length_function(s) < self._chunk_size:
|
if self._length_function(s) < self._chunk_size:
|
||||||
_good_splits.append(s)
|
_good_splits.append(s)
|
||||||
else:
|
else:
|
||||||
if _good_splits:
|
if _good_splits:
|
||||||
merged_text = self._merge_splits(_good_splits, separator)
|
merged_text = self._merge_splits(_good_splits, _separator)
|
||||||
final_chunks.extend(merged_text)
|
final_chunks.extend(merged_text)
|
||||||
_good_splits = []
|
_good_splits = []
|
||||||
other_info = self.split_text(s)
|
if new_separators is None:
|
||||||
final_chunks.extend(other_info)
|
final_chunks.append(s)
|
||||||
|
else:
|
||||||
|
other_info = self._split_text(s, new_separators)
|
||||||
|
final_chunks.extend(other_info)
|
||||||
if _good_splits:
|
if _good_splits:
|
||||||
merged_text = self._merge_splits(_good_splits, separator)
|
merged_text = self._merge_splits(_good_splits, _separator)
|
||||||
final_chunks.extend(merged_text)
|
final_chunks.extend(merged_text)
|
||||||
return final_chunks
|
return final_chunks
|
||||||
|
|
||||||
|
def split_text(self, text: str) -> List[str]:
|
||||||
|
return self._split_text(text, self._separators)
|
||||||
|
|
||||||
|
|
||||||
class NLTKTextSplitter(TextSplitter):
|
class NLTKTextSplitter(TextSplitter):
|
||||||
"""Implementation of splitting text that looks at sentences using NLTK."""
|
"""Implementation of splitting text that looks at sentences using NLTK."""
|
||||||
|
@ -4,9 +4,23 @@ import pytest
|
|||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.text_splitter import (
|
from langchain.text_splitter import (
|
||||||
CharacterTextSplitter,
|
CharacterTextSplitter,
|
||||||
|
PythonCodeTextSplitter,
|
||||||
RecursiveCharacterTextSplitter,
|
RecursiveCharacterTextSplitter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
FAKE_PYTHON_TEXT = """
|
||||||
|
class Foo:
|
||||||
|
|
||||||
|
def bar():
|
||||||
|
|
||||||
|
|
||||||
|
def foo():
|
||||||
|
|
||||||
|
def testing_func():
|
||||||
|
|
||||||
|
def bar():
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def test_character_text_splitter() -> None:
|
def test_character_text_splitter() -> None:
|
||||||
"""Test splitting by character count."""
|
"""Test splitting by character count."""
|
||||||
@ -135,15 +149,16 @@ Bye!\n\n-H."""
|
|||||||
"Okay then",
|
"Okay then",
|
||||||
"f f f f.",
|
"f f f f.",
|
||||||
"This is a",
|
"This is a",
|
||||||
"a weird",
|
"weird",
|
||||||
"text to",
|
"text to",
|
||||||
"write, but",
|
"write,",
|
||||||
"gotta test",
|
"but gotta",
|
||||||
"the",
|
"test the",
|
||||||
"splittingg",
|
"splitting",
|
||||||
"ggg",
|
"gggg",
|
||||||
"some how.",
|
"some how.",
|
||||||
"Bye!\n\n-H.",
|
"Bye!",
|
||||||
|
"-H.",
|
||||||
]
|
]
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
@ -168,3 +183,14 @@ def test_split_documents() -> None:
|
|||||||
Document(page_content="z", metadata={"source": "1"}),
|
Document(page_content="z", metadata={"source": "1"}),
|
||||||
]
|
]
|
||||||
assert splitter.split_documents(docs) == expected_output
|
assert splitter.split_documents(docs) == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_text_splitter() -> None:
|
||||||
|
splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)
|
||||||
|
splits = splitter.split_text(FAKE_PYTHON_TEXT)
|
||||||
|
split_0 = """class Foo:\n\n def bar():"""
|
||||||
|
split_1 = """def foo():"""
|
||||||
|
split_2 = """def testing_func():"""
|
||||||
|
split_3 = """def bar():"""
|
||||||
|
expected_splits = [split_0, split_1, split_2, split_3]
|
||||||
|
assert splits == expected_splits
|
||||||
|
Loading…
Reference in New Issue
Block a user