From 6f0bccfeb584decae6651905a688409e56f65863 Mon Sep 17 00:00:00 2001 From: Ilya Date: Fri, 4 Aug 2023 06:25:23 +0300 Subject: [PATCH] Add regex control over separators in character text splitter (#7933) #7854 Added the ability to use the `separator` as a regex or a simple character. Fixed a bug where `start_index` was incorrectly counting from -1. Who can review? @eyurtsev @hwchase17 @mmz-001 --- .../character_text_splitter.mdx | 1 + .../recursive_text_splitter.mdx | 1 + libs/langchain/langchain/text_splitter.py | 21 +++++++++++---- .../tests/unit_tests/test_text_splitter.py | 27 ++++++++++++++++--- 4 files changed, 41 insertions(+), 9 deletions(-) diff --git a/docs/snippets/modules/data_connection/document_transformers/text_splitters/character_text_splitter.mdx b/docs/snippets/modules/data_connection/document_transformers/text_splitters/character_text_splitter.mdx index e85f389845..419eb8be2f 100644 --- a/docs/snippets/modules/data_connection/document_transformers/text_splitters/character_text_splitter.mdx +++ b/docs/snippets/modules/data_connection/document_transformers/text_splitters/character_text_splitter.mdx @@ -12,6 +12,7 @@ text_splitter = CharacterTextSplitter( chunk_size = 1000, chunk_overlap = 200, length_function = len, + is_separator_regex = False, ) ``` diff --git a/docs/snippets/modules/data_connection/document_transformers/text_splitters/recursive_text_splitter.mdx b/docs/snippets/modules/data_connection/document_transformers/text_splitters/recursive_text_splitter.mdx index b7a3b41665..4e90fcffe2 100644 --- a/docs/snippets/modules/data_connection/document_transformers/text_splitters/recursive_text_splitter.mdx +++ b/docs/snippets/modules/data_connection/document_transformers/text_splitters/recursive_text_splitter.mdx @@ -16,6 +16,7 @@ text_splitter = RecursiveCharacterTextSplitter( chunk_size = 100, chunk_overlap = 20, length_function = len, + is_separator_regex = False, ) ``` diff --git a/libs/langchain/langchain/text_splitter.py 
b/libs/langchain/langchain/text_splitter.py index be57522f0f..bd0c560411 100644 --- a/libs/langchain/langchain/text_splitter.py +++ b/libs/langchain/langchain/text_splitter.py @@ -281,15 +281,21 @@ class TextSplitter(BaseDocumentTransformer, ABC): class CharacterTextSplitter(TextSplitter): """Splitting text that looks at characters.""" - def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None: + def __init__( + self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any + ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) self._separator = separator + self._is_separator_regex = is_separator_regex def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" # First we naively split the large input into a bunch of smaller ones. - splits = _split_text_with_regex(text, self._separator, self._keep_separator) + separator = ( + self._separator if self._is_separator_regex else re.escape(self._separator) + ) + splits = _split_text_with_regex(text, separator, self._keep_separator) _separator = "" if self._keep_separator else self._separator return self._merge_splits(splits, _separator) @@ -629,11 +635,13 @@ class RecursiveCharacterTextSplitter(TextSplitter): self, separators: Optional[List[str]] = None, keep_separator: bool = True, + is_separator_regex: bool = False, **kwargs: Any, ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) self._separators = separators or ["\n\n", "\n", " ", ""] + self._is_separator_regex = is_separator_regex def _split_text(self, text: str, separators: List[str]) -> List[str]: """Split incoming text and return chunks.""" @@ -642,15 +650,18 @@ class RecursiveCharacterTextSplitter(TextSplitter): separator = separators[-1] new_separators = [] for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) if _s == "": separator = _s break - if re.search(_s, text): + if 
re.search(_separator, text): separator = _s new_separators = separators[i + 1 :] break - splits = _split_text_with_regex(text, separator, self._keep_separator) + _separator = separator if self._is_separator_regex else re.escape(separator) + splits = _split_text_with_regex(text, _separator, self._keep_separator) + # Now go merging things, recursively splitting longer texts. _good_splits = [] _separator = "" if self._keep_separator else separator @@ -680,7 +691,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): cls, language: Language, **kwargs: Any ) -> RecursiveCharacterTextSplitter: separators = cls.get_separators_for_language(language) - return cls(separators=separators, **kwargs) + return cls(separators=separators, is_separator_regex=True, **kwargs) @staticmethod def get_separators_for_language(language: Language) -> List[str]: diff --git a/libs/langchain/tests/unit_tests/test_text_splitter.py b/libs/langchain/tests/unit_tests/test_text_splitter.py index e8dee47ffb..59a34d63c7 100644 --- a/libs/langchain/tests/unit_tests/test_text_splitter.py +++ b/libs/langchain/tests/unit_tests/test_text_splitter.py @@ -1,4 +1,5 @@ """Test text splitting functionality.""" +import re from typing import List import pytest @@ -80,25 +81,43 @@ def test_character_text_splitter_longer_words() -> None: assert output == expected_output -def test_character_text_splitter_keep_separator_regex() -> None: +@pytest.mark.parametrize( + "separator, is_separator_regex", [(re.escape("."), True), (".", False)] +) +def test_character_text_splitter_keep_separator_regex( + separator: str, is_separator_regex: bool +) -> None: """Test splitting by characters while keeping the separator that is a regex special character. 
""" text = "foo.bar.baz.123" splitter = CharacterTextSplitter( - separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=True + separator=separator, + chunk_size=1, + chunk_overlap=0, + keep_separator=True, + is_separator_regex=is_separator_regex, ) output = splitter.split_text(text) expected_output = ["foo", ".bar", ".baz", ".123"] assert output == expected_output -def test_character_text_splitter_discard_separator_regex() -> None: +@pytest.mark.parametrize( + "separator, is_separator_regex", [(re.escape("."), True), (".", False)] +) +def test_character_text_splitter_discard_separator_regex( + separator: str, is_separator_regex: bool +) -> None: """Test splitting by characters discarding the separator that is a regex special character.""" text = "foo.bar.baz.123" splitter = CharacterTextSplitter( - separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=False + separator=separator, + chunk_size=1, + chunk_overlap=0, + keep_separator=False, + is_separator_regex=is_separator_regex, ) output = splitter.split_text(text) expected_output = ["foo", "bar", "baz", "123"]