Fix IndexError in RecursiveCharacterTextSplitter (#5902)

Fixes a previously unreported error that can occur in some cases in
`RecursiveCharacterTextSplitter`.

An empty `new_separators` list (`[]`) would take the `else` branch of
the condition below and be passed to a function that expects it to be
non-empty.

```python
if new_separators is None:
    ...
else:
   # _split_text() expects this array to be non-empty!
   other_info = self._split_text(s, new_separators)

```
This results in an `IndexError`:

```python
def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
>       separator = separators[-1]
E       IndexError: list index out of range

langchain/text_splitter.py:425: IndexError
```

#### Who can review?
@hwchase17 @eyurtsev

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Thomas B 2023-06-11 01:48:53 +02:00 committed by GitHub
parent d2270a2261
commit ac3e6e3944
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 46 additions and 2 deletions

View File

@ -439,7 +439,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
final_chunks = [] final_chunks = []
# Get appropriate separator to use # Get appropriate separator to use
separator = separators[-1] separator = separators[-1]
new_separators = None new_separators = []
for i, _s in enumerate(separators): for i, _s in enumerate(separators):
if _s == "": if _s == "":
separator = _s separator = _s
@ -461,7 +461,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
merged_text = self._merge_splits(_good_splits, _separator) merged_text = self._merge_splits(_good_splits, _separator)
final_chunks.extend(merged_text) final_chunks.extend(merged_text)
_good_splits = [] _good_splits = []
if new_separators is None: if not new_separators:
final_chunks.append(s) final_chunks.append(s)
else: else:
other_info = self._split_text(s, new_separators) other_info = self._split_text(s, new_separators)

View File

@ -1,4 +1,6 @@
"""Test text splitting functionality.""" """Test text splitting functionality."""
from typing import List
import pytest import pytest
from langchain.docstore.document import Document from langchain.docstore.document import Document
@ -148,6 +150,48 @@ def test_metadata_not_shallow() -> None:
assert docs[1].metadata == {"source": "1"} assert docs[1].metadata == {"source": "1"}
def test_iterative_text_splitter_keep_separator() -> None:
    """Check the chunks produced when separators are kept in the output."""
    # With keep_separator=True each "X"/"Y" is expected at the start of the
    # chunk that follows it.
    expected = [
        "....5",
        "X..3",
        "Y...4",
        "X....5",
        "Y...",
    ]
    output = __test_iterative_text_splitter(chunk_size=5, keep_separator=True)
    assert output == expected
def test_iterative_text_splitter_discard_separator() -> None:
    """Check the chunks produced when separators are dropped from the output."""
    # With keep_separator=False no "X"/"Y" is expected in any chunk.
    expected = [
        "....5",
        "..3",
        "...4",
        "....5",
        "...",
    ]
    output = __test_iterative_text_splitter(chunk_size=5, keep_separator=False)
    assert output == expected
def __test_iterative_text_splitter(chunk_size: int, keep_separator: bool) -> List[str]:
    """Split a fixed sample text on "X" and "Y" and return the chunks.

    Args:
        chunk_size: Base chunk size limit handed to the splitter.
        keep_separator: Forwarded to the splitter; when true the limit is
            raised by one (presumably to leave room for the kept separator).

    Returns:
        The chunks produced by ``split_text``, after asserting that none of
        them is longer than the effective limit.
    """
    if keep_separator:
        chunk_size += 1

    sample = "....5X..3Y...4X....5Y..."
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=0,
        separators=["X", "Y"],
        keep_separator=keep_separator,
    )
    chunks = splitter.split_text(sample)

    # Every chunk must respect the (possibly adjusted) size limit.
    for piece in chunks:
        assert len(piece) <= chunk_size, f"Chunk is larger than {chunk_size}"
    return chunks
def test_iterative_text_splitter() -> None: def test_iterative_text_splitter() -> None:
"""Test iterative text splitter.""" """Test iterative text splitter."""
text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f. text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.