refactor: extract token text splitter function (#5179)

# Token text splitter for sentence transformers

The current `TokenTextSplitter` only works with OpenAI models via the
`tiktoken` package. This is not clear from the name `TokenTextSplitter`.
In this (first) PR, a token-based text splitter for sentence-transformer
models is added. In the future I think we should work towards injecting
a tokenizer into the `TokenTextSplitter` to make it more flexible.
Could perhaps be reviewed by @dev2049
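
A minimal usage sketch of the splitter added in this PR (mirroring the notebook below; the sample text is arbitrary):

```python
from langchain.text_splitter import SentenceTransformersTokenTextSplitter

# Defaults to sentence-transformers/all-mpnet-base-v2 and that model's token window.
splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)

text = "Lorem ipsum dolor sit amet. " * 200  # long enough to require several chunks
chunks = splitter.split_text(text=text)
print(len(chunks), splitter.count_tokens(text=chunks[0]))
```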

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Jens Madsen 12 months ago committed by GitHub
parent 26ec845921
commit 8d9e9e013c

@@ -0,0 +1,131 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "73dbcdb9",
"metadata": {},
"source": [
"# SentenceTransformersTokenTextSplitter\n",
"\n",
"This notebook demonstrates how to use the `SentenceTransformersTokenTextSplitter` text splitter.\n",
"\n",
"Language models have a token limit. You should not exceed the token limit. When you split your text into chunks it is therefore a good idea to count the number of tokens. There are many tokenizers. When you count tokens in your text you should use the same tokenizer as used in the language model. \n",
"\n",
"The `SentenceTransformersTokenTextSplitter` is a specialized text splitter for use with the sentence-transformer models. The default behaviour is to split the text into chunks that fit the token window of the sentence transformer model that you would like to use."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9dd5419e",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import SentenceTransformersTokenTextSplitter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b43e5d54",
"metadata": {},
"outputs": [],
"source": [
"splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)\n",
"text = \"Lorem \""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1df84cb4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n"
]
}
],
"source": [
"count_start_and_stop_tokens = 2\n",
"text_token_count = splitter.count_tokens(text=text) - count_start_and_stop_tokens\n",
"print(text_token_count)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d7ad2213",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tokens in text to split: 514\n"
]
}
],
"source": [
"token_multiplier = splitter.maximum_tokens_per_chunk // text_token_count + 1\n",
"\n",
"# `text_to_split` does not fit in a single chunk\n",
"text_to_split = text * token_multiplier\n",
"\n",
"print(f\"tokens in text to split: {splitter.count_tokens(text=text_to_split)}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "818aea04",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"lorem\n"
]
}
],
"source": [
"text_chunks = splitter.split_text(text=text_to_split)\n",
"\n",
"print(text_chunks[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9ba4f23",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@@ -5,6 +5,7 @@ import copy
import logging
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import (
AbstractSet,
@@ -244,6 +245,31 @@ class CharacterTextSplitter(TextSplitter):
return self._merge_splits(splits, _separator)
# should be in newer Python versions (3.10+)
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
chunk_overlap: int
tokens_per_chunk: int
decode: Callable[[List[int]], str]
encode: Callable[[str], List[int]]
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
"""Split incoming text and return chunks."""
splits = []
input_ids = tokenizer.encode(text)
start_idx = 0
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
while start_idx < len(input_ids):
splits.append(tokenizer.decode(chunk_ids))
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]
return splits
class TokenTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at tokens."""
@@ -275,22 +301,84 @@ class TokenTextSplitter(TextSplitter):
self._disallowed_special = disallowed_special
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
def _encode(_text: str) -> List[int]:
return self._tokenizer.encode(
_text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
tokenizer = Tokenizer(
chunk_overlap=self._chunk_overlap,
tokens_per_chunk=self._chunk_size,
decode=self._tokenizer.decode,
encode=_encode,
)
return split_text_on_tokens(text=text, tokenizer=tokenizer)
class SentenceTransformersTokenTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at tokens."""
def __init__(
self,
chunk_overlap: int = 50,
model_name: str = "sentence-transformers/all-mpnet-base-v2",
tokens_per_chunk: Optional[int] = None,
**kwargs: Any,
):
"""Create a new TextSplitter."""
super().__init__(**kwargs, chunk_overlap=chunk_overlap)
from transformers import AutoTokenizer
self.model_name = model_name
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
def _initialize_chunk_configuration(
self, *, tokens_per_chunk: Optional[int]
) -> None:
self.maximum_tokens_per_chunk = self.tokenizer.max_len_single_sentence
if tokens_per_chunk is None:
self.tokens_per_chunk = self.maximum_tokens_per_chunk
else:
self.tokens_per_chunk = tokens_per_chunk
if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
raise ValueError(
f"The token limit of the models '{self.model_name}'"
f" is: {self.maximum_tokens_per_chunk}."
f" Argument tokens_per_chunk={self.tokens_per_chunk}"
f" > maximum token limit."
)
def split_text(self, text: str) -> List[str]:
def encode_strip_start_and_stop_token_ids(text: str) -> List[int]:
return self._encode(text)[1:-1]
tokenizer = Tokenizer(
chunk_overlap=self._chunk_overlap,
tokens_per_chunk=self.tokens_per_chunk,
decode=self.tokenizer.decode,
encode=encode_strip_start_and_stop_token_ids,
)
return split_text_on_tokens(text=text, tokenizer=tokenizer)
def count_tokens(self, *, text: str) -> int:
return len(self._encode(text))
_max_length_equal_32_bit_integer = 2**32
def _encode(self, text: str) -> List[int]:
token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
text,
max_length=self._max_length_equal_32_bit_integer,
truncation="do_not_truncate",
)
return token_ids_with_start_and_end_token_ids
class Language(str, Enum):
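
To make the new `Tokenizer`/`split_text_on_tokens` extraction concrete, here is a small, hypothetical example of injecting a tokenizer that is neither tiktoken nor a HuggingFace tokenizer. The whitespace encode/decode pair is invented purely for illustration; only `Tokenizer` and `split_text_on_tokens` come from the code above (assuming they can be imported from `langchain.text_splitter`, since they are defined at module level):

```python
from typing import Dict, List

from langchain.text_splitter import Tokenizer, split_text_on_tokens

# Toy whitespace "tokenizer": each distinct word gets an integer id.
vocab: Dict[str, int] = {}
words: List[str] = []


def encode(text: str) -> List[int]:
    ids = []
    for word in text.split():
        if word not in vocab:
            vocab[word] = len(words)
            words.append(word)
        ids.append(vocab[word])
    return ids


def decode(ids: List[int]) -> str:
    return " ".join(words[i] for i in ids)


tokenizer = Tokenizer(
    chunk_overlap=2,
    tokens_per_chunk=10,
    decode=decode,
    encode=encode,
)

chunks = split_text_on_tokens(text="word " * 25, tokenizer=tokenizer)
print(len(chunks))  # 25 tokens, window of 10, step of 8 -> 4 chunks
```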

@ -2,7 +2,11 @@
import pytest
from langchain.text_splitter import (
CharacterTextSplitter,
SentenceTransformersTokenTextSplitter,
TokenTextSplitter,
)
def test_huggingface_type_check() -> None:
@@ -44,3 +48,45 @@ def test_token_text_splitter_from_tiktoken() -> None:
expected_tokenizer = "cl100k_base"
actual_tokenizer = splitter._tokenizer.name
assert expected_tokenizer == actual_tokenizer
def test_sentence_transformers_count_tokens() -> None:
splitter = SentenceTransformersTokenTextSplitter(
model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
text = "Lorem ipsum"
token_count = splitter.count_tokens(text=text)
expected_start_stop_token_count = 2
expected_text_token_count = 2
expected_token_count = expected_start_stop_token_count + expected_text_token_count
assert expected_token_count == token_count
def test_sentence_transformers_split_text() -> None:
splitter = SentenceTransformersTokenTextSplitter(
model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
text = "Lorem ipsum"
text_chunks = splitter.split_text(text=text)
expected_text_chunks = [text]
assert expected_text_chunks == text_chunks
def test_sentence_transformers_multiple_tokens() -> None:
splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
text = "Lorem "
count_start_and_end_tokens = 2
text_token_count = splitter.count_tokens(text=text) - count_start_and_end_tokens
token_multiplier = splitter.maximum_tokens_per_chunk // text_token_count + 1
text_chunks = splitter.split_text(text=text * token_multiplier)
expected_number_of_chunks = 2
assert expected_number_of_chunks == len(text_chunks)
actual = splitter.count_tokens(text=text_chunks[1]) - count_start_and_end_tokens
expected = token_multiplier * text_token_count - splitter.maximum_tokens_per_chunk
assert expected == actual
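
The arithmetic behind that last assertion, as a sketch (the helper name is made up; actual token counts depend on the model's tokenizer): with `chunk_overlap=0` the window advances by a full `maximum_tokens_per_chunk` each step, so the second chunk receives everything past the first window.

```python
def tokens_in_second_chunk(total_content_tokens: int, tokens_per_chunk: int) -> int:
    # With chunk_overlap=0 the first chunk is filled with tokens_per_chunk tokens,
    # so the remainder spills into the second chunk.
    return total_content_tokens - tokens_per_chunk


# Mirrors the test, where the total content tokens are taken to be
# token_multiplier * text_token_count and tokens_per_chunk is
# splitter.maximum_tokens_per_chunk.
```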
