diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 6012b838..7b3bc326 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -71,12 +71,17 @@ class TextSplitter(ABC): def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: # We now want to combine these smaller pieces into medium size # chunks to send to the LLM. + separator_len = self._length_function(separator) + docs = [] current_doc: List[str] = [] total = 0 for d in splits: _len = self._length_function(d) - if total + _len >= self._chunk_size: + if ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + ): if total > self._chunk_size: logger.warning( f"Created a chunk of size {total}, " @@ -90,12 +95,16 @@ class TextSplitter(ABC): # - we have a larger chunk than in the chunk overlap # - or if we still have any chunks and the length is long while total > self._chunk_overlap or ( - total + _len > self._chunk_size and total > 0 + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + and total > 0 ): - total -= self._length_function(current_doc[0]) + total -= self._length_function(current_doc[0]) + ( + separator_len if len(current_doc) > 1 else 0 + ) current_doc = current_doc[1:] current_doc.append(d) - total += _len + total += _len + (separator_len if len(current_doc) > 1 else 0) doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py index 89eaa3f0..90c21372 100644 --- a/tests/unit_tests/test_text_splitter.py +++ b/tests/unit_tests/test_text_splitter.py @@ -26,6 +26,15 @@ def test_character_text_splitter_empty_doc() -> None: assert output == expected_output +def test_character_text_splitter_separtor_empty_doc() -> None: + """Test edge cases are separators.""" + text = "f b" + splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0) + output = splitter.split_text(text) + expected_output = ["f", "b"] + assert output == expected_output + + def test_character_text_splitter_long() -> None: """Test splitting by character count on long words.""" text = "foo bar baz a a" @@ -99,7 +108,7 @@ Bye!\n\n-H.""" "Harrison.", "How? Are?", "You?", - "Okay then f", + "Okay then", "f f f f.", "This is a", "a weird", @@ -107,8 +116,8 @@ Bye!\n\n-H.""" "write, but", "gotta test", "the", - "splitting", - "gggg", + "splittingg", + "ggg", "some how.", "Bye!\n\n-H.", ]