fix: use model token limit not tokenizer ditto (#5939)

This fixes a token limit bug in the `SentenceTransformersTokenTextSplitter`. Previously, the token limit was taken from the tokenizer used by the model. However, for some models the token limit of the tokenizer (from `AutoTokenizer.from_pretrained`) does not equal the token limit of the model itself, so that assumption was wrong. The text splitter's token limit is therefore now taken from the sentence-transformers model's own limit (`max_seq_length`).
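
For illustration, a minimal sketch of the mismatch for the default model (the exact numbers depend on the model; for `all-mpnet-base-v2` the tokenizer typically reports a larger limit than the model actually truncates to):

```python
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

name = "sentence-transformers/all-mpnet-base-v2"

# Limit implied by the tokenizer alone (what the splitter used before this fix).
tokenizer_limit = AutoTokenizer.from_pretrained(name).max_len_single_sentence

# Limit the sentence-transformers model actually enforces (used after this fix).
model_limit = SentenceTransformer(name).max_seq_length

print(tokenizer_limit, model_limit)  # e.g. 510 vs. 384 for this model
```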

Twitter: @plasmajens

#### Before submitting

#### Who can review?

@hwchase17 and/or @dev2049

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
Jens Madsen 2023-06-11 01:36:03 +02:00 committed by GitHub
parent f8cf09a230
commit 1250cd4630


@@ -20,6 +20,7 @@ from typing import (
     Type,
     TypeVar,
     Union,
+    cast,
 )
 
 from langchain.docstore.document import Document
@@ -59,7 +60,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         length_function: Callable[[str], int] = len,
         keep_separator: bool = False,
         add_start_index: bool = False,
-    ):
+    ) -> None:
         """Create a new TextSplitter.
 
         Args:
@@ -240,7 +241,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
 class CharacterTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at characters."""
 
-    def __init__(self, separator: str = "\n\n", **kwargs: Any):
+    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
         self._separator = separator
@@ -265,7 +266,7 @@ class Tokenizer:
 
 def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
     """Split incoming text and return chunks."""
-    splits = []
+    splits: List[str] = []
     input_ids = tokenizer.encode(text)
     start_idx = 0
     cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
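
For reference, the loop above implements a sliding window over token ids: each chunk covers `tokens_per_chunk` tokens and the window advances by `tokens_per_chunk - chunk_overlap`. A minimal standalone sketch of the same logic, using whitespace splitting as a stand-in for the real encode/decode pair:

```python
from typing import Callable, List


def split_on_tokens(
    text: str,
    encode: Callable[[str], List[str]],
    decode: Callable[[List[str]], str],
    tokens_per_chunk: int,
    chunk_overlap: int,
) -> List[str]:
    """Slide a window of tokens_per_chunk tokens over the encoded text,
    stepping forward by tokens_per_chunk - chunk_overlap tokens each time."""
    splits: List[str] = []
    input_ids = encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(decode(chunk_ids))
        start_idx += tokens_per_chunk - chunk_overlap
        cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits


# Toy usage: "tokens" are whitespace-separated words.
print(split_on_tokens("a b c d e f g", str.split, " ".join, 3, 1))
# -> ['a b c', 'c d e', 'e f g', 'g']
```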
@@ -288,7 +289,7 @@ class TokenTextSplitter(TextSplitter):
         allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
-    ):
+    ) -> None:
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
         try:
@@ -335,19 +336,28 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         model_name: str = "sentence-transformers/all-mpnet-base-v2",
         tokens_per_chunk: Optional[int] = None,
         **kwargs: Any,
-    ):
+    ) -> None:
         """Create a new TextSplitter."""
         super().__init__(**kwargs, chunk_overlap=chunk_overlap)
-        from transformers import AutoTokenizer
+
+        try:
+            from sentence_transformers import SentenceTransformer
+        except ImportError:
+            raise ImportError(
+                "Could not import sentence_transformer python package. "
+                "This is needed in order to for SentenceTransformersTokenTextSplitter. "
+                "Please install it with `pip install sentence-transformers`."
+            )
 
         self.model_name = model_name
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self._model = SentenceTransformer(self.model_name)
+        self.tokenizer = self._model.tokenizer
         self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
 
     def _initialize_chunk_configuration(
         self, *, tokens_per_chunk: Optional[int]
     ) -> None:
-        self.maximum_tokens_per_chunk = self.tokenizer.max_len_single_sentence
+        self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
 
         if tokens_per_chunk is None:
             self.tokens_per_chunk = self.maximum_tokens_per_chunk
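
With this change, leaving `tokens_per_chunk` unset makes the splitter default to the model's own limit rather than the tokenizer's. A small usage sketch (the printed limit depends on the model you load):

```python
from langchain.text_splitter import SentenceTransformersTokenTextSplitter

splitter = SentenceTransformersTokenTextSplitter(
    model_name="sentence-transformers/all-mpnet-base-v2",
    chunk_overlap=50,
    # tokens_per_chunk is left unset, so it falls back to the model's max_seq_length
)

# Now mirrors the model limit, not the tokenizer's max_len_single_sentence.
print(splitter.maximum_tokens_per_chunk)

chunks = splitter.split_text("some long document text " * 500)
print(len(chunks), splitter.count_tokens(text=chunks[0]))
```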
@@ -419,7 +429,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         separators: Optional[List[str]] = None,
         keep_separator: bool = True,
         **kwargs: Any,
-    ):
+    ) -> None:
         """Create a new TextSplitter."""
         super().__init__(keep_separator=keep_separator, **kwargs)
         self._separators = separators or ["\n\n", "\n", " ", ""]
@@ -785,7 +795,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
 class NLTKTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at sentences using NLTK."""
 
-    def __init__(self, separator: str = "\n\n", **kwargs: Any):
+    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
         """Initialize the NLTK splitter."""
         super().__init__(**kwargs)
         try:
@@ -810,7 +820,7 @@ class SpacyTextSplitter(TextSplitter):
 
     def __init__(
         self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
-    ):
+    ) -> None:
         """Initialize the spacy text splitter."""
         super().__init__(**kwargs)
         try:
@@ -832,7 +842,7 @@ class SpacyTextSplitter(TextSplitter):
 class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
     """Attempts to split the text along Python syntax."""
 
-    def __init__(self, **kwargs: Any):
+    def __init__(self, **kwargs: Any) -> None:
         """Initialize a PythonCodeTextSplitter."""
         separators = self.get_separators_for_language(Language.PYTHON)
         super().__init__(separators=separators, **kwargs)
@@ -841,7 +851,7 @@ class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
 class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
     """Attempts to split the text along Markdown-formatted headings."""
 
-    def __init__(self, **kwargs: Any):
+    def __init__(self, **kwargs: Any) -> None:
         """Initialize a MarkdownTextSplitter."""
         separators = self.get_separators_for_language(Language.MARKDOWN)
         super().__init__(separators=separators, **kwargs)
@@ -850,7 +860,7 @@ class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
 class LatexTextSplitter(RecursiveCharacterTextSplitter):
     """Attempts to split the text along Latex-formatted layout elements."""
 
-    def __init__(self, **kwargs: Any):
+    def __init__(self, **kwargs: Any) -> None:
         """Initialize a LatexTextSplitter."""
         separators = self.get_separators_for_language(Language.LATEX)
         super().__init__(separators=separators, **kwargs)