Fix TextSplitter.from_tiktoken_encoder (#4361)

Thanks to @danb27 for the fix! Minor update

Fixes https://github.com/hwchase17/langchain/issues/4357
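
For illustration, a minimal sketch of the fixed behaviour, assuming `tiktoken` is installed; it mirrors the unit test added in this commit:

```python
from langchain.text_splitter import TokenTextSplitter

# With this fix, calling the classmethod on a subclass returns an instance of
# that subclass, and the tiktoken arguments are forwarded to its constructor.
splitter = TokenTextSplitter.from_tiktoken_encoder(model_name="gpt-3.5-turbo")
assert isinstance(splitter, TokenTextSplitter)
assert splitter._tokenizer.name == "cl100k_base"  # encoding behind gpt-3.5-turbo
```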

---------

Co-authored-by: Dan Bianchini <42096328+danb27@users.noreply.github.com>
commit 02ebb15c4a (parent 782df1db10)
Davis Chase, committed by GitHub

langchain/text_splitter.py

@@ -14,6 +14,8 @@ from typing import (
     Literal,
     Optional,
     Sequence,
+    Type,
+    TypeVar,
     Union,
 )
 
@@ -22,6 +24,8 @@ from langchain.schema import BaseDocumentTransformer
 
 logger = logging.getLogger(__name__)
 
+TS = TypeVar("TS", bound="TextSplitter")
+
 
 class TextSplitter(BaseDocumentTransformer, ABC):
     """Interface for splitting text into chunks."""
@@ -139,13 +143,13 @@ class TextSplitter(BaseDocumentTransformer, ABC):
 
     @classmethod
     def from_tiktoken_encoder(
-        cls,
+        cls: Type[TS],
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
         allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
-    ) -> TextSplitter:
+    ) -> TS:
         """Text splitter that uses tiktoken encoder to count length."""
         try:
             import tiktoken
@@ -161,16 +165,24 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         else:
             enc = tiktoken.get_encoding(encoding_name)
 
-        def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
+        def _tiktoken_encoder(text: str) -> int:
             return len(
                 enc.encode(
                     text,
                     allowed_special=allowed_special,
                     disallowed_special=disallowed_special,
-                    **kwargs,
                 )
             )
 
+        if issubclass(cls, TokenTextSplitter):
+            extra_kwargs = {
+                "encoding_name": encoding_name,
+                "model_name": model_name,
+                "allowed_special": allowed_special,
+                "disallowed_special": disallowed_special,
+            }
+            kwargs = {**kwargs, **extra_kwargs}
+
         return cls(length_function=_tiktoken_encoder, **kwargs)
 
     def transform_documents(
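
A rough sketch of what the `issubclass` branch above changes in practice, assuming `tiktoken` is installed (class names are from this repo; the keyword arguments are the usual `TextSplitter` constructor kwargs):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter

# Character-based splitters only use tiktoken to *count* tokens; chunk
# boundaries are still chosen on separators.
char_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=100, chunk_overlap=0
)

# For TokenTextSplitter, the encoding arguments are now also forwarded to the
# constructor, so the splitter splits on the same tokens it counts with.
token_splitter = TokenTextSplitter.from_tiktoken_encoder(
    encoding_name="gpt2", chunk_size=100, chunk_overlap=0
)
```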

tests/unit_tests/test_text_splitter.py

@@ -23,19 +23,24 @@ def test_huggingface_tokenizer() -> None:
     assert output == ["foo", "bar"]
 
 
-class TestTokenTextSplitter:
-    """Test token text splitter."""
-
-    def test_basic(self) -> None:
-        """Test no overlap."""
-        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
-        output = splitter.split_text("abcdef" * 5)  # 10 token string
-        expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
-        assert output == expected_output
-
-    def test_overlap(self) -> None:
-        """Test with overlap."""
-        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
-        output = splitter.split_text("abcdef" * 5)  # 10 token string
-        expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
-        assert output == expected_output
+def test_token_text_splitter() -> None:
+    """Test no overlap."""
+    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
+    output = splitter.split_text("abcdef" * 5)  # 10 token string
+    expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
+    assert output == expected_output
+
+
+def test_token_text_splitter_overlap() -> None:
+    """Test with overlap."""
+    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
+    output = splitter.split_text("abcdef" * 5)  # 10 token string
+    expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
+    assert output == expected_output
+
+
+def test_token_text_splitter_from_tiktoken() -> None:
+    splitter = TokenTextSplitter.from_tiktoken_encoder(model_name="gpt-3.5-turbo")
+    expected_tokenizer = "cl100k_base"
+    actual_tokenizer = splitter._tokenizer.name
+    assert expected_tokenizer == actual_tokenizer
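
For reference, the expected encoding name in the new test comes from tiktoken's model-to-encoding mapping; a quick check, assuming `tiktoken` is installed:

```python
import tiktoken

# gpt-3.5-turbo maps to the cl100k_base encoding, which is what
# test_token_text_splitter_from_tiktoken asserts via splitter._tokenizer.name.
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
print(enc.name)  # cl100k_base
```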
