From 72f99ff953d826a4a06d002c7c7acef8e37e0baf Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 29 May 2023 16:56:31 -0700 Subject: [PATCH] Harrison/text splitter (#5417) adds support for keeping separators around when using recursive text splitter --- .../text_splitters/examples/python.ipynb | 42 +++------- langchain/text_splitter.py | 76 ++++++++++++++----- tests/unit_tests/test_text_splitter.py | 40 ++++++++-- 3 files changed, 99 insertions(+), 59 deletions(-) diff --git a/docs/modules/indexes/text_splitters/examples/python.ipynb b/docs/modules/indexes/text_splitters/examples/python.ipynb index a184bcd5eb..ae48b7f191 100644 --- a/docs/modules/indexes/text_splitters/examples/python.ipynb +++ b/docs/modules/indexes/text_splitters/examples/python.ipynb @@ -42,17 +42,17 @@ " \n", "def foo():\n", "\n", - "def testing_func():\n", + "def testing_func_with_long_name():\n", "\n", "def bar():\n", "\"\"\"\n", - "python_splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)" + "python_splitter = PythonCodeTextSplitter(chunk_size=40, chunk_overlap=0)" ] }, { "cell_type": "code", "execution_count": 3, - "id": "6cdc55f3", + "id": "8cc33770", "metadata": {}, "outputs": [], "source": [ @@ -62,15 +62,16 @@ { "cell_type": "code", "execution_count": 4, - "id": "8cc33770", + "id": "f5f70775", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[Document(page_content='Foo:\\n\\n def bar():', lookup_str='', metadata={}, lookup_index=0),\n", - " Document(page_content='foo():\\n\\ndef testing_func():', lookup_str='', metadata={}, lookup_index=0),\n", - " Document(page_content='bar():', lookup_str='', metadata={}, lookup_index=0)]" + "[Document(page_content='class Foo:\\n\\n def bar():', metadata={}),\n", + " Document(page_content='def foo():', metadata={}),\n", + " Document(page_content='def testing_func_with_long_name():', metadata={}),\n", + " Document(page_content='def bar():', metadata={})]" ] }, "execution_count": 4, @@ -82,33 +83,10 @@ "docs" ] }, - { - "cell_type": "code", - "execution_count": 3, - "id": "de625e08-c440-489d-beed-020b6c53bf69", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "['Foo:\\n\\n def bar():', 'foo():\\n\\ndef testing_func():', 'bar():']" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "python_splitter.split_text(python_text)" - ] - }, { "cell_type": "code", "execution_count": null, - "id": "55aadd84-75ca-48ae-9b84-b39c368488ed", + "id": "6e096d42", "metadata": {}, "outputs": [], "source": [] @@ -130,7 +108,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.1" }, "vscode": { "interpreter": { diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 5a6a69ab10..38429f6c5a 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy import logging +import re from abc import ABC, abstractmethod from typing import ( AbstractSet, @@ -27,6 +28,23 @@ logger = logging.getLogger(__name__) TS = TypeVar("TS", bound="TextSplitter") +def _split_text(text: str, separator: str, keep_separator: bool) -> List[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] + if len(_splits) % 2 == 0: + splits += _splits[-1:] + splits = [_splits[0]] + splits + else: + splits = text.split(separator) + else: + splits = list(text) + return [s for s in splits if s != ""] + + class TextSplitter(BaseDocumentTransformer, ABC): """Interface for splitting text into chunks.""" @@ -35,8 +53,16 @@ class TextSplitter(BaseDocumentTransformer, ABC): chunk_size: int = 4000, chunk_overlap: int = 200, length_function: Callable[[str], int] = len, + keep_separator: bool = False, ): - """Create a new TextSplitter.""" + """Create a new TextSplitter. + + Args: + chunk_size: Maximum size of chunks to return + chunk_overlap: Overlap in characters between chunks + length_function: Function that measures the length of given chunks + keep_separator: Whether or not to keep the separator in the chunks + """ if chunk_overlap > chunk_size: raise ValueError( f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " @@ -45,6 +71,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap self._length_function = length_function + self._keep_separator = keep_separator @abstractmethod def split_text(self, text: str) -> List[str]: @@ -211,11 +238,9 @@ class CharacterTextSplitter(TextSplitter): def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" # First we naively split the large input into a bunch of smaller ones. - if self._separator: - splits = text.split(self._separator) - else: - splits = list(text) - return self._merge_splits(splits, self._separator) + splits = _split_text(text, self._separator, self._keep_separator) + _separator = "" if self._keep_separator else self._separator + return self._merge_splits(splits, _separator) class TokenTextSplitter(TextSplitter): @@ -274,45 +299,56 @@ class RecursiveCharacterTextSplitter(TextSplitter): that works. """ - def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any): + def __init__( + self, + separators: Optional[List[str]] = None, + keep_separator: bool = True, + **kwargs: Any, + ): """Create a new TextSplitter.""" - super().__init__(**kwargs) + super().__init__(keep_separator=keep_separator, **kwargs) self._separators = separators or ["\n\n", "\n", " ", ""] - def split_text(self, text: str) -> List[str]: + def _split_text(self, text: str, separators: List[str]) -> List[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use - separator = self._separators[-1] - for _s in self._separators: + separator = separators[-1] + new_separators = None + for i, _s in enumerate(separators): if _s == "": separator = _s break if _s in text: separator = _s + new_separators = separators[i + 1 :] break - # Now that we have the separator, split the text - if separator: - splits = text.split(separator) - else: - splits = list(text) + + splits = _split_text(text, separator, self._keep_separator) # Now go merging things, recursively splitting longer texts. _good_splits = [] + _separator = "" if self._keep_separator else separator for s in splits: if self._length_function(s) < self._chunk_size: _good_splits.append(s) else: if _good_splits: - merged_text = self._merge_splits(_good_splits, separator) + merged_text = self._merge_splits(_good_splits, _separator) final_chunks.extend(merged_text) _good_splits = [] - other_info = self.split_text(s) - final_chunks.extend(other_info) + if new_separators is None: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) if _good_splits: - merged_text = self._merge_splits(_good_splits, separator) + merged_text = self._merge_splits(_good_splits, _separator) final_chunks.extend(merged_text) return final_chunks + def split_text(self, text: str) -> List[str]: + return self._split_text(text, self._separators) + class NLTKTextSplitter(TextSplitter): """Implementation of splitting text that looks at sentences using NLTK.""" diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py index 31736a9155..75f243b914 100644 --- a/tests/unit_tests/test_text_splitter.py +++ b/tests/unit_tests/test_text_splitter.py @@ -4,9 +4,23 @@ import pytest from langchain.docstore.document import Document from langchain.text_splitter import ( CharacterTextSplitter, + PythonCodeTextSplitter, RecursiveCharacterTextSplitter, ) +FAKE_PYTHON_TEXT = """ +class Foo: + + def bar(): + + +def foo(): + +def testing_func(): + +def bar(): +""" + def test_character_text_splitter() -> None: """Test splitting by character count.""" @@ -135,15 +149,16 @@ Bye!\n\n-H.""" "Okay then", "f f f f.", "This is a", - "a weird", + "weird", "text to", - "write, but", - "gotta test", - "the", - "splittingg", - "ggg", + "write,", + "but gotta", + "test the", + "splitting", + "gggg", "some how.", - "Bye!\n\n-H.", + "Bye!", + "-H.", ] assert output == expected_output @@ -168,3 +183,14 @@ def test_split_documents() -> None: Document(page_content="z", metadata={"source": "1"}), ] assert splitter.split_documents(docs) == expected_output + + +def test_python_text_splitter() -> None: + splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0) + splits = splitter.split_text(FAKE_PYTHON_TEXT) + split_0 = """class Foo:\n\n def bar():""" + split_1 = """def foo():""" + split_2 = """def testing_func():""" + split_3 = """def bar():""" + expected_splits = [split_0, split_1, split_2, split_3] + assert splits == expected_splits