diff --git a/docs/modules/indexes/text_splitters/examples/sentence_transformer_token_splitter.ipynb b/docs/modules/indexes/text_splitters/examples/sentence_transformer_token_splitter.ipynb
new file mode 100644
index 00000000..5b64c053
--- /dev/null
+++ b/docs/modules/indexes/text_splitters/examples/sentence_transformer_token_splitter.ipynb
@@ -0,0 +1,131 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "73dbcdb9",
+   "metadata": {},
+   "source": [
+    "# SentenceTransformersTokenTextSplitter\n",
+    "\n",
+    "This notebook demonstrates how to use the `SentenceTransformersTokenTextSplitter` text splitter.\n",
+    "\n",
+    "Language models have a token limit that you should not exceed, so it is a good idea to count tokens when splitting your text into chunks. There are many tokenizers; when counting the tokens in your text, use the same tokenizer that the language model uses.\n",
+    "\n",
+    "The `SentenceTransformersTokenTextSplitter` is a specialized text splitter for sentence-transformer models. By default it splits the text into chunks that fit the token window of the sentence-transformer model you want to use."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "9dd5419e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.text_splitter import SentenceTransformersTokenTextSplitter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "b43e5d54",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)\n",
+    "text = \"Lorem \""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "1df84cb4",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2\n"
+     ]
+    }
+   ],
+   "source": [
+    "count_start_and_stop_tokens = 2\n",
+    "text_token_count = splitter.count_tokens(text=text) - count_start_and_stop_tokens\n",
+    "print(text_token_count)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "d7ad2213",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "tokens in text to split: 514\n"
+     ]
+    }
+   ],
+   "source": [
+    "token_multiplier = splitter.maximum_tokens_per_chunk // text_token_count + 1\n",
+    "\n",
+    "# `text_to_split` does not fit in a single chunk\n",
+    "text_to_split = text * token_multiplier\n",
+    "\n",
+    "print(f\"tokens in text to split: {splitter.count_tokens(text=text_to_split)}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "818aea04",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "lorem\n"
+     ]
+    }
+   ],
+   "source": [
+    "text_chunks = splitter.split_text(text=text_to_split)\n",
+    "\n",
+    "print(text_chunks[1])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e9ba4f23",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py
index 560b2c5a..77f34d06 100644
--- a/langchain/text_splitter.py
+++ b/langchain/text_splitter.py
@@ -5,6 +5,7 @@ import copy
 import logging
 import re
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from enum import Enum
 from typing import (
     AbstractSet,
@@ -244,6 +245,31 @@ class CharacterTextSplitter(TextSplitter):
         return self._merge_splits(splits, _separator)
 
 
+# On Python 3.10+ this could be made stricter with:
+# @dataclass(frozen=True, kw_only=True, slots=True)
+@dataclass(frozen=True)
+class Tokenizer:
+    chunk_overlap: int
+    tokens_per_chunk: int
+    decode: Callable[[List[int]], str]
+    encode: Callable[[str], List[int]]
+
+
+def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
+    """Split incoming text and return chunks."""
+    splits = []
+    input_ids = tokenizer.encode(text)
+    start_idx = 0
+    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+    chunk_ids = input_ids[start_idx:cur_idx]
+    while start_idx < len(input_ids):
+        splits.append(tokenizer.decode(chunk_ids))
+        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
+        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+        chunk_ids = input_ids[start_idx:cur_idx]
+    return splits
+
+
 class TokenTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at tokens."""
 
@@ -275,22 +301,84 @@ class TokenTextSplitter(TextSplitter):
         self._disallowed_special = disallowed_special
 
     def split_text(self, text: str) -> List[str]:
-        """Split incoming text and return chunks."""
-        splits = []
-        input_ids = self._tokenizer.encode(
+        def _encode(_text: str) -> List[int]:
+            return self._tokenizer.encode(
+                _text,
+                allowed_special=self._allowed_special,
+                disallowed_special=self._disallowed_special,
+            )
+
+        tokenizer = Tokenizer(
+            chunk_overlap=self._chunk_overlap,
+            tokens_per_chunk=self._chunk_size,
+            decode=self._tokenizer.decode,
+            encode=_encode,
+        )
+
+        return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
+
+class SentenceTransformersTokenTextSplitter(TextSplitter):
+    """Splitting text using the tokenizer of a sentence-transformers model."""
+
+    def __init__(
+        self,
+        chunk_overlap: int = 50,
+        model_name: str = "sentence-transformers/all-mpnet-base-v2",
+        tokens_per_chunk: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        """Create a new TextSplitter."""
+        super().__init__(**kwargs, chunk_overlap=chunk_overlap)
+        from transformers import AutoTokenizer
+
+        self.model_name = model_name
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
+
+    def _initialize_chunk_configuration(
+        self, *, tokens_per_chunk: Optional[int]
+    ) -> None:
+        self.maximum_tokens_per_chunk = self.tokenizer.max_len_single_sentence
+
+        if tokens_per_chunk is None:
+            self.tokens_per_chunk = self.maximum_tokens_per_chunk
+        else:
+            self.tokens_per_chunk = tokens_per_chunk
+
+        if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
+            raise ValueError(
+                f"The token limit of the model '{self.model_name}'"
+                f" is: {self.maximum_tokens_per_chunk}."
+                f" Argument tokens_per_chunk={self.tokens_per_chunk}"
+                f" > maximum token limit."
+            )
+
+    def split_text(self, text: str) -> List[str]:
+        def encode_strip_start_and_stop_token_ids(text: str) -> List[int]:
+            return self._encode(text)[1:-1]
+
+        tokenizer = Tokenizer(
+            chunk_overlap=self._chunk_overlap,
+            tokens_per_chunk=self.tokens_per_chunk,
+            decode=self.tokenizer.decode,
+            encode=encode_strip_start_and_stop_token_ids,
+        )
+
+        return split_text_on_tokens(text=text, tokenizer=tokenizer)
+
+    def count_tokens(self, *, text: str) -> int:
+        return len(self._encode(text))
+
+    _max_length_equal_32_bit_integer = 2**32
+
+    def _encode(self, text: str) -> List[int]:
+        token_ids_with_start_and_end_token_ids = self.tokenizer.encode(
             text,
-            allowed_special=self._allowed_special,
-            disallowed_special=self._disallowed_special,
+            max_length=self._max_length_equal_32_bit_integer,
+            truncation="do_not_truncate",
         )
-        start_idx = 0
-        cur_idx = min(start_idx + self._chunk_size, len(input_ids))
-        chunk_ids = input_ids[start_idx:cur_idx]
-        while start_idx < len(input_ids):
-            splits.append(self._tokenizer.decode(chunk_ids))
-            start_idx += self._chunk_size - self._chunk_overlap
-            cur_idx = min(start_idx + self._chunk_size, len(input_ids))
-            chunk_ids = input_ids[start_idx:cur_idx]
-        return splits
+        return token_ids_with_start_and_end_token_ids
 
 
 class Language(str, Enum):
diff --git a/tests/integration_tests/test_text_splitter.py b/tests/integration_tests/test_text_splitter.py
index d19a58d5..3cf78c71 100644
--- a/tests/integration_tests/test_text_splitter.py
+++ b/tests/integration_tests/test_text_splitter.py
@@ -2,7 +2,11 @@
 
 import pytest
 
-from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
+from langchain.text_splitter import (
+    CharacterTextSplitter,
+    SentenceTransformersTokenTextSplitter,
+    TokenTextSplitter,
+)
 
 
 def test_huggingface_type_check() -> None:
@@ -44,3 +48,45 @@ def test_token_text_splitter_from_tiktoken() -> None:
     expected_tokenizer = "cl100k_base"
     actual_tokenizer = splitter._tokenizer.name
     assert expected_tokenizer == actual_tokenizer
+
+
+def test_sentence_transformers_count_tokens() -> None:
+    splitter = SentenceTransformersTokenTextSplitter(
+        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+    )
+    text = "Lorem ipsum"
+
+    token_count = splitter.count_tokens(text=text)
+
+    expected_start_stop_token_count = 2
+    expected_text_token_count = 2
+    expected_token_count = expected_start_stop_token_count + expected_text_token_count
+
+    assert expected_token_count == token_count
+
+
+def test_sentence_transformers_split_text() -> None:
+    splitter = SentenceTransformersTokenTextSplitter(
+        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+    )
+    text = "Lorem ipsum"
+    text_chunks = splitter.split_text(text=text)
+    expected_text_chunks = [text]
+    assert expected_text_chunks == text_chunks
+
+
+def test_sentence_transformers_multiple_tokens() -> None:
+    splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
+    text = "Lorem "
+
+    count_start_and_end_tokens = 2
+    text_token_count = splitter.count_tokens(text=text) - count_start_and_end_tokens
+    token_multiplier = splitter.maximum_tokens_per_chunk // text_token_count + 1
+    text_chunks = splitter.split_text(text=text * token_multiplier)
+
+    expected_number_of_chunks = 2
+
+    assert expected_number_of_chunks == len(text_chunks)
+    actual = splitter.count_tokens(text=text_chunks[1]) - count_start_and_end_tokens
+    expected = token_multiplier * text_token_count - splitter.maximum_tokens_per_chunk
+    assert expected == actual
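For reference, a minimal sketch of how the new `Tokenizer` dataclass and `split_text_on_tokens` helper introduced in this diff compose. The whitespace encode/decode pair below is an illustrative assumption only (real callers pass a tiktoken or Hugging Face tokenizer, as the splitter classes above do); it assumes the two module-level helpers are importable from `langchain.text_splitter` as added here.

from typing import List

from langchain.text_splitter import Tokenizer, split_text_on_tokens

# Toy whitespace "tokenizer" backed by an on-the-fly vocabulary
# (illustrative assumption, not part of the library).
vocab: List[str] = []


def encode(text: str) -> List[int]:
    ids: List[int] = []
    for word in text.split():
        if word not in vocab:
            vocab.append(word)
        ids.append(vocab.index(word))
    return ids


def decode(ids: List[int]) -> str:
    return " ".join(vocab[i] for i in ids)


tokenizer = Tokenizer(chunk_overlap=2, tokens_per_chunk=5, decode=decode, encode=encode)

text = " ".join(f"word{i}" for i in range(12))
chunks = split_text_on_tokens(text=text, tokenizer=tokenizer)

# 12 tokens with 5 tokens per chunk and a stride of 3 (5 - 2) gives chunks
# starting at tokens 0, 3, 6 and 9, so consecutive chunks overlap by two words.
print(chunks)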