diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index daf720e1..51ba41df 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -1,9 +1,12 @@ """Functionality for splitting text.""" from __future__ import annotations +import logging from abc import ABC, abstractmethod from typing import Any, Callable, Iterable, List +logger = logging.getLogger() + class TextSplitter(ABC): """Interface for splitting text into chunks.""" @@ -37,13 +40,20 @@ class TextSplitter(ABC): current_doc: List[str] = [] total = 0 for d in splits: - if total >= self._chunk_size: - docs.append(self._separator.join(current_doc)) - while total > self._chunk_overlap: - total -= self._length_function(current_doc[0]) - current_doc = current_doc[1:] + _len = self._length_function(d) + if total + _len >= self._chunk_size: + if total > self._chunk_size: + logger.warning( + f"Created a chunk of size {total}, " + f"which is longer than the specified {self._chunk_size}" + ) + if len(current_doc) > 0: + docs.append(self._separator.join(current_doc)) + while total > self._chunk_overlap: + total -= self._length_function(current_doc[0]) + current_doc = current_doc[1:] current_doc.append(d) - total += self._length_function(d) + total += _len docs.append(self._separator.join(current_doc)) return docs diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py index d3f90748..ffdf56b8 100644 --- a/tests/unit_tests/test_text_splitter.py +++ b/tests/unit_tests/test_text_splitter.py @@ -7,12 +7,21 @@ from langchain.text_splitter import CharacterTextSplitter def test_character_text_splitter() -> None: """Test splitting by character count.""" text = "foo bar baz 123" - splitter = CharacterTextSplitter(separator=" ", chunk_size=5, chunk_overlap=3) + splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3) output = splitter.split_text(text) expected_output = ["foo bar", "bar baz", "baz 123"] assert output == expected_output +def test_character_text_splitter_long() -> None: + """Test splitting by character count on long words.""" + text = "foo bar baz a a" + splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1) + output = splitter.split_text(text) + expected_output = ["foo", "bar", "baz", "a a"] + assert output == expected_output + + def test_character_text_splitter_longer_words() -> None: """Test splitting by characters when splits not found easily.""" text = "foo bar baz 123"