Add alternative token-based text splitter (#816)

This does not involve a separator; it naively chunks the input text at
fixed boundaries in token space.

This is helpful when we have strict token length limits, need to
strictly follow the specified chunk size, and can't use aggressive
separators like spaces to guarantee the absence of long strings.

CharacterTextSplitter will let these strings through without splitting
them, which could cause overflow errors downstream.

Splitting at arbitrary token boundaries is not ideal, but this is
hopefully mitigated by a decent overlap quantity. It also yields chunks
with exactly the desired number of tokens, instead of the occasional
overcounting that happens when shorter strings are concatenated.

Potentially also helps with #528.
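
For illustration, a minimal usage sketch (tiktoken must be installed; the
exact chunks depend on the chosen encoding):

    from langchain.text_splitter import TokenTextSplitter

    # chunk_size and chunk_overlap are counted in tokens of the chosen
    # encoding, so every chunk except possibly the last contains exactly
    # chunk_size tokens
    splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=5, chunk_overlap=1)
    chunks = splitter.split_text("abcdef" * 5)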

@@ -146,6 +146,38 @@ class CharacterTextSplitter(TextSplitter):
        return self._merge_splits(splits, self._separator)


class TokenTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at tokens."""

    def __init__(self, encoding_name: str = "gpt2", **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to use TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # create a tokenizer for the given encoding (defaults to GPT-2's BPE)
        self._tokenizer = tiktoken.get_encoding(encoding_name)

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        splits = []
        input_ids = self._tokenizer.encode(text)
        # slide a chunk_size window over the token ids, advancing by
        # chunk_size - chunk_overlap so consecutive chunks share tokens
        start_idx = 0
        cur_idx = min(start_idx + self._chunk_size, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
        while start_idx < len(input_ids):
            splits.append(self._tokenizer.decode(chunk_ids))
            start_idx += self._chunk_size - self._chunk_overlap
            cur_idx = min(start_idx + self._chunk_size, len(input_ids))
            chunk_ids = input_ids[start_idx:cur_idx]
        return splits
class RecursiveCharacterTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters.

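To make the stride arithmetic in split_text concrete, here is a standalone
sketch of the same windowing logic over a list of token ids (window_ids is
an illustrative helper, not part of the change):

    from typing import List

    def window_ids(input_ids: List[int], chunk_size: int, chunk_overlap: int) -> List[List[int]]:
        """Collect chunk_size-token windows, advancing by chunk_size - chunk_overlap."""
        windows = []
        start_idx = 0
        while start_idx < len(input_ids):
            windows.append(input_ids[start_idx : start_idx + chunk_size])
            start_idx += chunk_size - chunk_overlap
        return windows

    # window_ids(list(range(10)), chunk_size=5, chunk_overlap=1)
    # -> [[0, 1, 2, 3, 4], [4, 5, 6, 7, 8], [8, 9]]
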
@@ -2,7 +2,7 @@
import pytest
from langchain.text_splitter import CharacterTextSplitter
from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
def test_huggingface_type_check() -> None:
@@ -21,3 +21,21 @@ def test_huggingface_tokenizer() -> None:
    )
    output = text_splitter.split_text("foo bar")
    assert output == ["foo", "bar"]


class TestTokenTextSplitter:
    """Test token text splitter."""

    def test_basic(self) -> None:
        """Test no overlap."""
        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
        output = splitter.split_text("abcdef" * 5)  # 10 token string
        expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
        assert output == expected_output

    def test_overlap(self) -> None:
        """Test with overlap."""
        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
        output = splitter.split_text("abcdef" * 5)  # 10 token string
        expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
        assert output == expected_output
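
The expected outputs above rely on the gpt2 encoding mapping "abcdef" * 5
to 10 tokens of 3 characters each, as the inline comments note; a quick
check of that assumption:

    import tiktoken

    enc = tiktoken.get_encoding("gpt2")
    assert len(enc.encode("abcdef" * 5)) == 10  # 3 characters per token here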
