mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
fix: use model token limit not tokenizer ditto (#5939)
This fixes a token-limit bug in `SentenceTransformersTokenTextSplitter`. Previously, the splitter took its token limit from the tokenizer loaded for the model via `AutoTokenizer.from_pretrained`. However, for some models the tokenizer's token limit does not equal the model's token limit, so that assumption was wrong. The splitter's token limit is now taken from the sentence-transformers model itself.

Twitter: @plasmajens

#### Before submitting

#### Who can review?

@hwchase17 and/or @dev2049

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
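A minimal sketch of the mismatch being fixed, assuming the `transformers` and `sentence-transformers` packages are installed (the printed values depend on the model; for some models the two limits differ):

```python
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer

model_name = "sentence-transformers/all-mpnet-base-v2"

# Limit reported by the tokenizer (the old, incorrect source of truth).
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("tokenizer limit:", tokenizer.max_len_single_sentence)

# Limit reported by the model itself (the new source of truth).
model = SentenceTransformer(model_name)
print("model limit:", model.max_seq_length)
```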
This commit is contained in:
parent
f8cf09a230
commit
1250cd4630
```diff
@@ -20,6 +20,7 @@ from typing import (
     Type,
     TypeVar,
     Union,
+    cast,
 )
 
 from langchain.docstore.document import Document
@@ -59,7 +60,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         length_function: Callable[[str], int] = len,
         keep_separator: bool = False,
         add_start_index: bool = False,
-    ):
+    ) -> None:
         """Create a new TextSplitter.
 
         Args:
@@ -240,7 +241,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
 class CharacterTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at characters."""
 
-    def __init__(self, separator: str = "\n\n", **kwargs: Any):
+    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
         self._separator = separator
@@ -265,7 +266,7 @@ class Tokenizer:
 
 def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
     """Split incoming text and return chunks."""
-    splits = []
+    splits: List[str] = []
     input_ids = tokenizer.encode(text)
     start_idx = 0
     cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
```
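The hunk above shows only the top of `split_text_on_tokens`. As a self-contained sketch of the sliding-window pattern it implements (the helper name and standalone shape are illustrative, not the library code):

```python
from typing import List


def sliding_window_chunks(
    input_ids: List[int], tokens_per_chunk: int, chunk_overlap: int
) -> List[List[int]]:
    """Illustrative sliding window over token ids: each chunk holds up to
    tokens_per_chunk ids and starts tokens_per_chunk - chunk_overlap ids
    after the previous chunk."""
    chunks: List[List[int]] = []
    start_idx = 0
    while start_idx < len(input_ids):
        cur_idx = min(start_idx + tokens_per_chunk, len(input_ids))
        chunks.append(input_ids[start_idx:cur_idx])
        start_idx += tokens_per_chunk - chunk_overlap
    return chunks


# 10 ids, chunks of 4, overlap of 2 -> windows starting at 0, 2, 4, 6, 8
print(sliding_window_chunks(list(range(10)), tokens_per_chunk=4, chunk_overlap=2))
```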
```diff
@@ -288,7 +289,7 @@ class TokenTextSplitter(TextSplitter):
         allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
-    ):
+    ) -> None:
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
         try:
@@ -335,19 +336,28 @@ class SentenceTransformersTokenTextSplitter(TextSplitter):
         model_name: str = "sentence-transformers/all-mpnet-base-v2",
         tokens_per_chunk: Optional[int] = None,
         **kwargs: Any,
-    ):
+    ) -> None:
         """Create a new TextSplitter."""
         super().__init__(**kwargs, chunk_overlap=chunk_overlap)
-        from transformers import AutoTokenizer
+
+        try:
+            from sentence_transformers import SentenceTransformer
+        except ImportError:
+            raise ImportError(
+                "Could not import sentence_transformer python package. "
+                "This is needed in order to use SentenceTransformersTokenTextSplitter. "
+                "Please install it with `pip install sentence-transformers`."
+            )
 
         self.model_name = model_name
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self._model = SentenceTransformer(self.model_name)
+        self.tokenizer = self._model.tokenizer
         self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)
 
     def _initialize_chunk_configuration(
         self, *, tokens_per_chunk: Optional[int]
     ) -> None:
-        self.maximum_tokens_per_chunk = self.tokenizer.max_len_single_sentence
+        self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)
 
         if tokens_per_chunk is None:
             self.tokens_per_chunk = self.maximum_tokens_per_chunk
```
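After this change the splitter sizes its chunks from the model's own limit. A short usage sketch (output depends on the downloaded model):

```python
from langchain.text_splitter import SentenceTransformersTokenTextSplitter

splitter = SentenceTransformersTokenTextSplitter(
    model_name="sentence-transformers/all-mpnet-base-v2",
    chunk_overlap=50,
)

# Now backed by the model's max_seq_length rather than the tokenizer's
# limit, so chunks fit within what the model actually accepts.
print(splitter.maximum_tokens_per_chunk)

chunks = splitter.split_text("some long document text " * 1000)
print(len(chunks))
```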
```diff
@@ -419,7 +429,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         separators: Optional[List[str]] = None,
         keep_separator: bool = True,
         **kwargs: Any,
-    ):
+    ) -> None:
         """Create a new TextSplitter."""
         super().__init__(keep_separator=keep_separator, **kwargs)
         self._separators = separators or ["\n\n", "\n", " ", ""]
@@ -785,7 +795,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
 class NLTKTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at sentences using NLTK."""
 
-    def __init__(self, separator: str = "\n\n", **kwargs: Any):
+    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
         """Initialize the NLTK splitter."""
         super().__init__(**kwargs)
         try:
@@ -810,7 +820,7 @@ class SpacyTextSplitter(TextSplitter):
 
     def __init__(
         self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
-    ):
+    ) -> None:
         """Initialize the spacy text splitter."""
         super().__init__(**kwargs)
         try:
@@ -832,7 +842,7 @@ class SpacyTextSplitter(TextSplitter):
 class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
     """Attempts to split the text along Python syntax."""
 
-    def __init__(self, **kwargs: Any):
+    def __init__(self, **kwargs: Any) -> None:
         """Initialize a PythonCodeTextSplitter."""
         separators = self.get_separators_for_language(Language.PYTHON)
         super().__init__(separators=separators, **kwargs)
@@ -841,7 +851,7 @@ class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
 class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
     """Attempts to split the text along Markdown-formatted headings."""
 
-    def __init__(self, **kwargs: Any):
+    def __init__(self, **kwargs: Any) -> None:
         """Initialize a MarkdownTextSplitter."""
         separators = self.get_separators_for_language(Language.MARKDOWN)
         super().__init__(separators=separators, **kwargs)
@@ -850,7 +860,7 @@ class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
 class LatexTextSplitter(RecursiveCharacterTextSplitter):
     """Attempts to split the text along Latex-formatted layout elements."""
 
-    def __init__(self, **kwargs: Any):
+    def __init__(self, **kwargs: Any) -> None:
         """Initialize a LatexTextSplitter."""
         separators = self.get_separators_for_language(Language.LATEX)
         super().__init__(separators=separators, **kwargs)
```