fix text splitter (#375)

harrison/sequential_chain_from_prompts
Harrison Chase 1 year ago committed by GitHub
parent 3474f39e21
commit e7b625fe03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -1,9 +1,12 @@
"""Functionality for splitting text."""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List
logger = logging.getLogger()
class TextSplitter(ABC):
"""Interface for splitting text into chunks."""
@@ -37,13 +40,20 @@ class TextSplitter(ABC):
current_doc: List[str] = []
total = 0
for d in splits:
if total >= self._chunk_size:
docs.append(self._separator.join(current_doc))
while total > self._chunk_overlap:
total -= self._length_function(current_doc[0])
current_doc = current_doc[1:]
_len = self._length_function(d)
if total + _len >= self._chunk_size:
if total > self._chunk_size:
logger.warning(
f"Created a chunk of size {total}, "
f"which is longer than the specified {self._chunk_size}"
)
if len(current_doc) > 0:
docs.append(self._separator.join(current_doc))
while total > self._chunk_overlap:
total -= self._length_function(current_doc[0])
current_doc = current_doc[1:]
current_doc.append(d)
total += self._length_function(d)
total += _len
docs.append(self._separator.join(current_doc))
return docs

@@ -7,12 +7,21 @@ from langchain.text_splitter import CharacterTextSplitter
def test_character_text_splitter() -> None:
"""Test splitting by character count."""
text = "foo bar baz 123"
splitter = CharacterTextSplitter(separator=" ", chunk_size=5, chunk_overlap=3)
splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3)
output = splitter.split_text(text)
expected_output = ["foo bar", "bar baz", "baz 123"]
assert output == expected_output
def test_character_text_splitter_long() -> None:
"""Test splitting by character count on long words."""
text = "foo bar baz a a"
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
output = splitter.split_text(text)
expected_output = ["foo", "bar", "baz", "a a"]
assert output == expected_output
def test_character_text_splitter_longer_words() -> None:
"""Test splitting by characters when splits not found easily."""
text = "foo bar baz 123"

Loading…
Cancel
Save