From ac3e6e394406ff95dcd47c603db62a2ccbd00b81 Mon Sep 17 00:00:00 2001 From: Thomas B Date: Sun, 11 Jun 2023 01:48:53 +0200 Subject: [PATCH] Fix IndexError in RecursiveCharacterTextSplitter (#5902) Fixes (not reported) an error that may occur in some cases in the RecursiveCharacterTextSplitter. An empty `new_separators` array ([]) would end up in the else path of the condition below and used in a function where it is expected to be non empty. ```python if new_separators is None: ... else: # _split_text() expects this array to be non-empty! other_info = self._split_text(s, new_separators) ``` resulting in an `IndexError` ```python def _split_text(self, text: str, separators: List[str]) -> List[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use > separator = separators[-1] E IndexError: list index out of range langchain/text_splitter.py:425: IndexError ``` #### Who can review? @hwchase17 @eyurtsev --------- Co-authored-by: Harrison Chase --- langchain/text_splitter.py | 4 +-- tests/unit_tests/test_text_splitter.py | 44 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 7c875682..89559505 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -439,7 +439,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): final_chunks = [] # Get appropriate separator to use separator = separators[-1] - new_separators = None + new_separators = [] for i, _s in enumerate(separators): if _s == "": separator = _s @@ -461,7 +461,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): merged_text = self._merge_splits(_good_splits, _separator) final_chunks.extend(merged_text) _good_splits = [] - if new_separators is None: + if not new_separators: final_chunks.append(s) else: other_info = self._split_text(s, new_separators) diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py index 2da634cd..91730b03 100644 --- a/tests/unit_tests/test_text_splitter.py +++ b/tests/unit_tests/test_text_splitter.py @@ -1,4 +1,6 @@ """Test text splitting functionality.""" +from typing import List + import pytest from langchain.docstore.document import Document @@ -148,6 +150,48 @@ def test_metadata_not_shallow() -> None: assert docs[1].metadata == {"source": "1"} +def test_iterative_text_splitter_keep_separator() -> None: + chunk_size = 5 + output = __test_iterative_text_splitter(chunk_size=chunk_size, keep_separator=True) + + assert output == [ + "....5", + "X..3", + "Y...4", + "X....5", + "Y...", + ] + + +def test_iterative_text_splitter_discard_separator() -> None: + chunk_size = 5 + output = __test_iterative_text_splitter(chunk_size=chunk_size, keep_separator=False) + + assert output == [ + "....5", + "..3", + "...4", + "....5", + "...", + ] + + +def __test_iterative_text_splitter(chunk_size: int, keep_separator: bool) -> List[str]: + chunk_size += 1 if keep_separator else 0 + + splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=0, + separators=["X", "Y"], + keep_separator=keep_separator, + ) + text = "....5X..3Y...4X....5Y..." + output = splitter.split_text(text) + for chunk in output: + assert len(chunk) <= chunk_size, f"Chunk is larger than {chunk_size}" + return output + + def test_iterative_text_splitter() -> None: """Test iterative text splitter.""" text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.