Fix TextSplitter.from_tiktoken_encoder (#4361)

Thanks to @danb27 for the fix! Minor update

Fixes https://github.com/hwchase17/langchain/issues/4357

---------

Co-authored-by: Dan Bianchini <42096328+danb27@users.noreply.github.com>
This commit is contained in:
Davis Chase 2023-05-08 16:36:38 -07:00 committed by GitHub
parent 782df1db10
commit 02ebb15c4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 18 deletions

View File

@ -14,6 +14,8 @@ from typing import (
Literal, Literal,
Optional, Optional,
Sequence, Sequence,
Type,
TypeVar,
Union, Union,
) )
@ -22,6 +24,8 @@ from langchain.schema import BaseDocumentTransformer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TS = TypeVar("TS", bound="TextSplitter")
class TextSplitter(BaseDocumentTransformer, ABC): class TextSplitter(BaseDocumentTransformer, ABC):
"""Interface for splitting text into chunks.""" """Interface for splitting text into chunks."""
@ -139,13 +143,13 @@ class TextSplitter(BaseDocumentTransformer, ABC):
@classmethod @classmethod
def from_tiktoken_encoder( def from_tiktoken_encoder(
cls, cls: Type[TS],
encoding_name: str = "gpt2", encoding_name: str = "gpt2",
model_name: Optional[str] = None, model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all", disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any, **kwargs: Any,
) -> TextSplitter: ) -> TS:
"""Text splitter that uses tiktoken encoder to count length.""" """Text splitter that uses tiktoken encoder to count length."""
try: try:
import tiktoken import tiktoken
@ -161,16 +165,24 @@ class TextSplitter(BaseDocumentTransformer, ABC):
else: else:
enc = tiktoken.get_encoding(encoding_name) enc = tiktoken.get_encoding(encoding_name)
def _tiktoken_encoder(text: str, **kwargs: Any) -> int: def _tiktoken_encoder(text: str) -> int:
return len( return len(
enc.encode( enc.encode(
text, text,
allowed_special=allowed_special, allowed_special=allowed_special,
disallowed_special=disallowed_special, disallowed_special=disallowed_special,
**kwargs,
) )
) )
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
"encoding_name": encoding_name,
"model_name": model_name,
"allowed_special": allowed_special,
"disallowed_special": disallowed_special,
}
kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_tiktoken_encoder, **kwargs) return cls(length_function=_tiktoken_encoder, **kwargs)
def transform_documents( def transform_documents(

View File

@ -23,19 +23,24 @@ def test_huggingface_tokenizer() -> None:
assert output == ["foo", "bar"] assert output == ["foo", "bar"]
class TestTokenTextSplitter: def test_token_text_splitter() -> None:
"""Test token text splitter."""
def test_basic(self) -> None:
"""Test no overlap.""" """Test no overlap."""
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0) splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
output = splitter.split_text("abcdef" * 5) # 10 token string output = splitter.split_text("abcdef" * 5) # 10 token string
expected_output = ["abcdefabcdefabc", "defabcdefabcdef"] expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
assert output == expected_output assert output == expected_output
def test_overlap(self) -> None:
def test_token_text_splitter_overlap() -> None:
"""Test with overlap.""" """Test with overlap."""
splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1) splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
output = splitter.split_text("abcdef" * 5) # 10 token string output = splitter.split_text("abcdef" * 5) # 10 token string
expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"] expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
assert output == expected_output assert output == expected_output
def test_token_text_splitter_from_tiktoken() -> None:
splitter = TokenTextSplitter.from_tiktoken_encoder(model_name="gpt-3.5-turbo")
expected_tokenizer = "cl100k_base"
actual_tokenizer = splitter._tokenizer.name
assert expected_tokenizer == actual_tokenizer