mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add regex control over separators in character text splitter (#7933)
<!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> #7854 Added the ability to use the `separator` as a regex or a simple character. Fixed a bug where `start_index` was incorrectly counting from -1. Who can review? @eyurtsev @hwchase17 @mmz-001
This commit is contained in:
parent
e68a1d73d0
commit
6f0bccfeb5
@ -12,6 +12,7 @@ text_splitter = CharacterTextSplitter(
|
||||
chunk_size = 1000,
|
||||
chunk_overlap = 200,
|
||||
length_function = len,
|
||||
is_separator_regex = False,
|
||||
)
|
||||
```
|
||||
|
||||
|
@ -16,6 +16,7 @@ text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size = 100,
|
||||
chunk_overlap = 20,
|
||||
length_function = len,
|
||||
is_separator_regex = False,
|
||||
)
|
||||
```
|
||||
|
||||
|
@ -281,15 +281,21 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
class CharacterTextSplitter(TextSplitter):
|
||||
"""Splitting text that looks at characters."""
|
||||
|
||||
def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
|
||||
def __init__(
|
||||
self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
|
||||
) -> None:
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
self._separator = separator
|
||||
self._is_separator_regex = is_separator_regex
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
# First we naively split the large input into a bunch of smaller ones.
|
||||
splits = _split_text_with_regex(text, self._separator, self._keep_separator)
|
||||
separator = (
|
||||
self._separator if self._is_separator_regex else re.escape(self._separator)
|
||||
)
|
||||
splits = _split_text_with_regex(text, separator, self._keep_separator)
|
||||
_separator = "" if self._keep_separator else self._separator
|
||||
return self._merge_splits(splits, _separator)
|
||||
|
||||
@ -629,11 +635,13 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
self,
|
||||
separators: Optional[List[str]] = None,
|
||||
keep_separator: bool = True,
|
||||
is_separator_regex: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(keep_separator=keep_separator, **kwargs)
|
||||
self._separators = separators or ["\n\n", "\n", " ", ""]
|
||||
self._is_separator_regex = is_separator_regex
|
||||
|
||||
def _split_text(self, text: str, separators: List[str]) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
@ -642,15 +650,18 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
separator = separators[-1]
|
||||
new_separators = []
|
||||
for i, _s in enumerate(separators):
|
||||
_separator = _s if self._is_separator_regex else re.escape(_s)
|
||||
if _s == "":
|
||||
separator = _s
|
||||
break
|
||||
if re.search(_s, text):
|
||||
if re.search(_separator, text):
|
||||
separator = _s
|
||||
new_separators = separators[i + 1 :]
|
||||
break
|
||||
|
||||
splits = _split_text_with_regex(text, separator, self._keep_separator)
|
||||
_separator = separator if self._is_separator_regex else re.escape(separator)
|
||||
splits = _split_text_with_regex(text, _separator, self._keep_separator)
|
||||
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
_good_splits = []
|
||||
_separator = "" if self._keep_separator else separator
|
||||
@ -680,7 +691,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
||||
cls, language: Language, **kwargs: Any
|
||||
) -> RecursiveCharacterTextSplitter:
|
||||
separators = cls.get_separators_for_language(language)
|
||||
return cls(separators=separators, **kwargs)
|
||||
return cls(separators=separators, is_separator_regex=True, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_separators_for_language(language: Language) -> List[str]:
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Test text splitting functionality."""
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
@ -80,25 +81,43 @@ def test_character_text_splitter_longer_words() -> None:
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_character_text_splitter_keep_separator_regex() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"separator, is_separator_regex", [(re.escape("."), True), (".", False)]
|
||||
)
|
||||
def test_character_text_splitter_keep_separator_regex(
|
||||
separator: str, is_separator_regex: bool
|
||||
) -> None:
|
||||
"""Test splitting by characters while keeping the separator
|
||||
that is a regex special character.
|
||||
"""
|
||||
text = "foo.bar.baz.123"
|
||||
splitter = CharacterTextSplitter(
|
||||
separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=True
|
||||
separator=separator,
|
||||
chunk_size=1,
|
||||
chunk_overlap=0,
|
||||
keep_separator=True,
|
||||
is_separator_regex=is_separator_regex,
|
||||
)
|
||||
output = splitter.split_text(text)
|
||||
expected_output = ["foo", ".bar", ".baz", ".123"]
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_character_text_splitter_discard_separator_regex() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"separator, is_separator_regex", [(re.escape("."), True), (".", False)]
|
||||
)
|
||||
def test_character_text_splitter_discard_separator_regex(
|
||||
separator: str, is_separator_regex: bool
|
||||
) -> None:
|
||||
"""Test splitting by characters discarding the separator
|
||||
that is a regex special character."""
|
||||
text = "foo.bar.baz.123"
|
||||
splitter = CharacterTextSplitter(
|
||||
separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=False
|
||||
separator=separator,
|
||||
chunk_size=1,
|
||||
chunk_overlap=0,
|
||||
keep_separator=False,
|
||||
is_separator_regex=is_separator_regex,
|
||||
)
|
||||
output = splitter.split_text(text)
|
||||
expected_output = ["foo", "bar", "baz", "123"]
|
||||
|
Loading…
Reference in New Issue
Block a user