Add regex control over separators in character text splitter (#7933)

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure you're PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
#7854

Added the ability to use the `separator` ase a regex or a simple
character.
Fixed a bug where `start_index` was incorrectly counting from -1.

Who can review?
@eyurtsev
@hwchase17 
@mmz-001
pull/8738/head
Ilya 1 year ago committed by GitHub
parent e68a1d73d0
commit 6f0bccfeb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,6 +12,7 @@ text_splitter = CharacterTextSplitter(
chunk_size = 1000,
chunk_overlap = 200,
length_function = len,
is_separator_regex = False,
)
```

@ -16,6 +16,7 @@ text_splitter = RecursiveCharacterTextSplitter(
chunk_size = 100,
chunk_overlap = 20,
length_function = len,
is_separator_regex = False,
)
```

@ -281,15 +281,21 @@ class TextSplitter(BaseDocumentTransformer, ABC):
class CharacterTextSplitter(TextSplitter):
"""Splitting text that looks at characters."""
def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
def __init__(
self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
) -> None:
"""Create a new TextSplitter."""
super().__init__(**kwargs)
self._separator = separator
self._is_separator_regex = is_separator_regex
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
# First we naively split the large input into a bunch of smaller ones.
splits = _split_text_with_regex(text, self._separator, self._keep_separator)
separator = (
self._separator if self._is_separator_regex else re.escape(self._separator)
)
splits = _split_text_with_regex(text, separator, self._keep_separator)
_separator = "" if self._keep_separator else self._separator
return self._merge_splits(splits, _separator)
@ -629,11 +635,13 @@ class RecursiveCharacterTextSplitter(TextSplitter):
self,
separators: Optional[List[str]] = None,
keep_separator: bool = True,
is_separator_regex: bool = False,
**kwargs: Any,
) -> None:
"""Create a new TextSplitter."""
super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or ["\n\n", "\n", " ", ""]
self._is_separator_regex = is_separator_regex
def _split_text(self, text: str, separators: List[str]) -> List[str]:
"""Split incoming text and return chunks."""
@ -642,15 +650,18 @@ class RecursiveCharacterTextSplitter(TextSplitter):
separator = separators[-1]
new_separators = []
for i, _s in enumerate(separators):
_separator = _s if self._is_separator_regex else re.escape(_s)
if _s == "":
separator = _s
break
if re.search(_s, text):
if re.search(_separator, text):
separator = _s
new_separators = separators[i + 1 :]
break
splits = _split_text_with_regex(text, separator, self._keep_separator)
_separator = separator if self._is_separator_regex else re.escape(separator)
splits = _split_text_with_regex(text, _separator, self._keep_separator)
# Now go merging things, recursively splitting longer texts.
_good_splits = []
_separator = "" if self._keep_separator else separator
@ -680,7 +691,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
cls, language: Language, **kwargs: Any
) -> RecursiveCharacterTextSplitter:
separators = cls.get_separators_for_language(language)
return cls(separators=separators, **kwargs)
return cls(separators=separators, is_separator_regex=True, **kwargs)
@staticmethod
def get_separators_for_language(language: Language) -> List[str]:

@ -1,4 +1,5 @@
"""Test text splitting functionality."""
import re
from typing import List
import pytest
@ -80,25 +81,43 @@ def test_character_text_splitter_longer_words() -> None:
assert output == expected_output
def test_character_text_splitter_keep_separator_regex() -> None:
@pytest.mark.parametrize(
"separator, is_separator_regex", [(re.escape("."), True), (".", False)]
)
def test_character_text_splitter_keep_separator_regex(
separator: str, is_separator_regex: bool
) -> None:
"""Test splitting by characters while keeping the separator
that is a regex special character.
"""
text = "foo.bar.baz.123"
splitter = CharacterTextSplitter(
separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=True
separator=separator,
chunk_size=1,
chunk_overlap=0,
keep_separator=True,
is_separator_regex=is_separator_regex,
)
output = splitter.split_text(text)
expected_output = ["foo", ".bar", ".baz", ".123"]
assert output == expected_output
def test_character_text_splitter_discard_separator_regex() -> None:
@pytest.mark.parametrize(
"separator, is_separator_regex", [(re.escape("."), True), (".", False)]
)
def test_character_text_splitter_discard_separator_regex(
separator: str, is_separator_regex: bool
) -> None:
"""Test splitting by characters discarding the separator
that is a regex special character."""
text = "foo.bar.baz.123"
splitter = CharacterTextSplitter(
separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=False
separator=separator,
chunk_size=1,
chunk_overlap=0,
keep_separator=False,
is_separator_regex=is_separator_regex,
)
output = splitter.split_text(text)
expected_output = ["foo", "bar", "baz", "123"]

Loading…
Cancel
Save