diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 7c875682..89559505 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -439,7 +439,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): final_chunks = [] # Get appropriate separator to use separator = separators[-1] - new_separators = None + new_separators = [] for i, _s in enumerate(separators): if _s == "": separator = _s @@ -461,7 +461,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): merged_text = self._merge_splits(_good_splits, _separator) final_chunks.extend(merged_text) _good_splits = [] - if new_separators is None: + if not new_separators: final_chunks.append(s) else: other_info = self._split_text(s, new_separators) diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py index 2da634cd..91730b03 100644 --- a/tests/unit_tests/test_text_splitter.py +++ b/tests/unit_tests/test_text_splitter.py @@ -1,4 +1,6 @@ """Test text splitting functionality.""" +from typing import List + import pytest from langchain.docstore.document import Document @@ -148,6 +150,48 @@ def test_metadata_not_shallow() -> None: assert docs[1].metadata == {"source": "1"} +def test_iterative_text_splitter_keep_separator() -> None: + chunk_size = 5 + output = __test_iterative_text_splitter(chunk_size=chunk_size, keep_separator=True) + + assert output == [ + "....5", + "X..3", + "Y...4", + "X....5", + "Y...", + ] + + +def test_iterative_text_splitter_discard_separator() -> None: + chunk_size = 5 + output = __test_iterative_text_splitter(chunk_size=chunk_size, keep_separator=False) + + assert output == [ + "....5", + "..3", + "...4", + "....5", + "...", + ] + + +def __test_iterative_text_splitter(chunk_size: int, keep_separator: bool) -> List[str]: + chunk_size += 1 if keep_separator else 0 + + splitter = RecursiveCharacterTextSplitter( + chunk_size=chunk_size, + chunk_overlap=0, + separators=["X", "Y"], + keep_separator=keep_separator, + ) + text = "....5X..3Y...4X....5Y..." + output = splitter.split_text(text) + for chunk in output: + assert len(chunk) <= chunk_size, f"Chunk is larger than {chunk_size}" + return output + + def test_iterative_text_splitter() -> None: """Test iterative text splitter.""" text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.