mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add regex control over separators in character text splitter (#7933)
<!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> #7854 Added the ability to use the `separator` as a regex or a simple character. Fixed a bug where `start_index` was incorrectly counting from -1. Who can review? @eyurtsev @hwchase17 @mmz-001
This commit is contained in:
parent
e68a1d73d0
commit
6f0bccfeb5
@ -12,6 +12,7 @@ text_splitter = CharacterTextSplitter(
|
|||||||
chunk_size = 1000,
|
chunk_size = 1000,
|
||||||
chunk_overlap = 200,
|
chunk_overlap = 200,
|
||||||
length_function = len,
|
length_function = len,
|
||||||
|
is_separator_regex = False,
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ text_splitter = RecursiveCharacterTextSplitter(
|
|||||||
chunk_size = 100,
|
chunk_size = 100,
|
||||||
chunk_overlap = 20,
|
chunk_overlap = 20,
|
||||||
length_function = len,
|
length_function = len,
|
||||||
|
is_separator_regex = False,
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -281,15 +281,21 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|||||||
class CharacterTextSplitter(TextSplitter):
|
class CharacterTextSplitter(TextSplitter):
|
||||||
"""Splitting text that looks at characters."""
|
"""Splitting text that looks at characters."""
|
||||||
|
|
||||||
def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
|
def __init__(
|
||||||
|
self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
"""Create a new TextSplitter."""
|
"""Create a new TextSplitter."""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._separator = separator
|
self._separator = separator
|
||||||
|
self._is_separator_regex = is_separator_regex
|
||||||
|
|
||||||
def split_text(self, text: str) -> List[str]:
|
def split_text(self, text: str) -> List[str]:
|
||||||
"""Split incoming text and return chunks."""
|
"""Split incoming text and return chunks."""
|
||||||
# First we naively split the large input into a bunch of smaller ones.
|
# First we naively split the large input into a bunch of smaller ones.
|
||||||
splits = _split_text_with_regex(text, self._separator, self._keep_separator)
|
separator = (
|
||||||
|
self._separator if self._is_separator_regex else re.escape(self._separator)
|
||||||
|
)
|
||||||
|
splits = _split_text_with_regex(text, separator, self._keep_separator)
|
||||||
_separator = "" if self._keep_separator else self._separator
|
_separator = "" if self._keep_separator else self._separator
|
||||||
return self._merge_splits(splits, _separator)
|
return self._merge_splits(splits, _separator)
|
||||||
|
|
||||||
@ -629,11 +635,13 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
|||||||
self,
|
self,
|
||||||
separators: Optional[List[str]] = None,
|
separators: Optional[List[str]] = None,
|
||||||
keep_separator: bool = True,
|
keep_separator: bool = True,
|
||||||
|
is_separator_regex: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create a new TextSplitter."""
|
"""Create a new TextSplitter."""
|
||||||
super().__init__(keep_separator=keep_separator, **kwargs)
|
super().__init__(keep_separator=keep_separator, **kwargs)
|
||||||
self._separators = separators or ["\n\n", "\n", " ", ""]
|
self._separators = separators or ["\n\n", "\n", " ", ""]
|
||||||
|
self._is_separator_regex = is_separator_regex
|
||||||
|
|
||||||
def _split_text(self, text: str, separators: List[str]) -> List[str]:
|
def _split_text(self, text: str, separators: List[str]) -> List[str]:
|
||||||
"""Split incoming text and return chunks."""
|
"""Split incoming text and return chunks."""
|
||||||
@ -642,15 +650,18 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
|||||||
separator = separators[-1]
|
separator = separators[-1]
|
||||||
new_separators = []
|
new_separators = []
|
||||||
for i, _s in enumerate(separators):
|
for i, _s in enumerate(separators):
|
||||||
|
_separator = _s if self._is_separator_regex else re.escape(_s)
|
||||||
if _s == "":
|
if _s == "":
|
||||||
separator = _s
|
separator = _s
|
||||||
break
|
break
|
||||||
if re.search(_s, text):
|
if re.search(_separator, text):
|
||||||
separator = _s
|
separator = _s
|
||||||
new_separators = separators[i + 1 :]
|
new_separators = separators[i + 1 :]
|
||||||
break
|
break
|
||||||
|
|
||||||
splits = _split_text_with_regex(text, separator, self._keep_separator)
|
_separator = separator if self._is_separator_regex else re.escape(separator)
|
||||||
|
splits = _split_text_with_regex(text, _separator, self._keep_separator)
|
||||||
|
|
||||||
# Now go merging things, recursively splitting longer texts.
|
# Now go merging things, recursively splitting longer texts.
|
||||||
_good_splits = []
|
_good_splits = []
|
||||||
_separator = "" if self._keep_separator else separator
|
_separator = "" if self._keep_separator else separator
|
||||||
@ -680,7 +691,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
|
|||||||
cls, language: Language, **kwargs: Any
|
cls, language: Language, **kwargs: Any
|
||||||
) -> RecursiveCharacterTextSplitter:
|
) -> RecursiveCharacterTextSplitter:
|
||||||
separators = cls.get_separators_for_language(language)
|
separators = cls.get_separators_for_language(language)
|
||||||
return cls(separators=separators, **kwargs)
|
return cls(separators=separators, is_separator_regex=True, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_separators_for_language(language: Language) -> List[str]:
|
def get_separators_for_language(language: Language) -> List[str]:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Test text splitting functionality."""
|
"""Test text splitting functionality."""
|
||||||
|
import re
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -80,25 +81,43 @@ def test_character_text_splitter_longer_words() -> None:
|
|||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test_character_text_splitter_keep_separator_regex() -> None:
|
@pytest.mark.parametrize(
|
||||||
|
"separator, is_separator_regex", [(re.escape("."), True), (".", False)]
|
||||||
|
)
|
||||||
|
def test_character_text_splitter_keep_separator_regex(
|
||||||
|
separator: str, is_separator_regex: bool
|
||||||
|
) -> None:
|
||||||
"""Test splitting by characters while keeping the separator
|
"""Test splitting by characters while keeping the separator
|
||||||
that is a regex special character.
|
that is a regex special character.
|
||||||
"""
|
"""
|
||||||
text = "foo.bar.baz.123"
|
text = "foo.bar.baz.123"
|
||||||
splitter = CharacterTextSplitter(
|
splitter = CharacterTextSplitter(
|
||||||
separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=True
|
separator=separator,
|
||||||
|
chunk_size=1,
|
||||||
|
chunk_overlap=0,
|
||||||
|
keep_separator=True,
|
||||||
|
is_separator_regex=is_separator_regex,
|
||||||
)
|
)
|
||||||
output = splitter.split_text(text)
|
output = splitter.split_text(text)
|
||||||
expected_output = ["foo", ".bar", ".baz", ".123"]
|
expected_output = ["foo", ".bar", ".baz", ".123"]
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test_character_text_splitter_discard_separator_regex() -> None:
|
@pytest.mark.parametrize(
|
||||||
|
"separator, is_separator_regex", [(re.escape("."), True), (".", False)]
|
||||||
|
)
|
||||||
|
def test_character_text_splitter_discard_separator_regex(
|
||||||
|
separator: str, is_separator_regex: bool
|
||||||
|
) -> None:
|
||||||
"""Test splitting by characters discarding the separator
|
"""Test splitting by characters discarding the separator
|
||||||
that is a regex special character."""
|
that is a regex special character."""
|
||||||
text = "foo.bar.baz.123"
|
text = "foo.bar.baz.123"
|
||||||
splitter = CharacterTextSplitter(
|
splitter = CharacterTextSplitter(
|
||||||
separator=r"\.", chunk_size=1, chunk_overlap=0, keep_separator=False
|
separator=separator,
|
||||||
|
chunk_size=1,
|
||||||
|
chunk_overlap=0,
|
||||||
|
keep_separator=False,
|
||||||
|
is_separator_regex=is_separator_regex,
|
||||||
)
|
)
|
||||||
output = splitter.split_text(text)
|
output = splitter.split_text(text)
|
||||||
expected_output = ["foo", "bar", "baz", "123"]
|
expected_output = ["foo", "bar", "baz", "123"]
|
||||||
|
Loading…
Reference in New Issue
Block a user