From 02ebb15c4a92a23818c2c17486bdaf9f590dc6a5 Mon Sep 17 00:00:00 2001
From: Davis Chase <130488702+dev2049@users.noreply.github.com>
Date: Mon, 8 May 2023 16:36:38 -0700
Subject: [PATCH] Fix TextSplitter.from_tiktoken (#4361)

Thanks to @danb27 for the fix! Minor update

Fixes https://github.com/hwchase17/langchain/issues/4357

---------

Co-authored-by: Dan Bianchini <42096328+danb27@users.noreply.github.com>
---
 langchain/text_splitter.py                    | 20 ++++++++--
 tests/integration_tests/test_text_splitter.py | 37 +++++++++++--------
 2 files changed, 37 insertions(+), 20 deletions(-)

diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py
index 1a09bf05..06e1fc2a 100644
--- a/langchain/text_splitter.py
+++ b/langchain/text_splitter.py
@@ -14,6 +14,8 @@ from typing import (
     Literal,
     Optional,
     Sequence,
+    Type,
+    TypeVar,
     Union,
 )
 
@@ -22,6 +24,8 @@ from langchain.schema import BaseDocumentTransformer
 
 logger = logging.getLogger(__name__)
 
+TS = TypeVar("TS", bound="TextSplitter")
+
 
 class TextSplitter(BaseDocumentTransformer, ABC):
     """Interface for splitting text into chunks."""
@@ -139,13 +143,13 @@ class TextSplitter(BaseDocumentTransformer, ABC):
 
     @classmethod
     def from_tiktoken_encoder(
-        cls,
+        cls: Type[TS],
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
         allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
-    ) -> TextSplitter:
+    ) -> TS:
         """Text splitter that uses tiktoken encoder to count length."""
         try:
             import tiktoken
@@ -161,16 +165,24 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         else:
             enc = tiktoken.get_encoding(encoding_name)
 
-        def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
+        def _tiktoken_encoder(text: str) -> int:
             return len(
                 enc.encode(
                     text,
                     allowed_special=allowed_special,
                     disallowed_special=disallowed_special,
-                    **kwargs,
                 )
             )
 
+        if issubclass(cls, TokenTextSplitter):
+            extra_kwargs = {
+                "encoding_name": encoding_name,
+                "model_name": model_name,
+                "allowed_special": allowed_special,
+                "disallowed_special": disallowed_special,
+            }
+            kwargs = {**kwargs, **extra_kwargs}
+
         return cls(length_function=_tiktoken_encoder, **kwargs)
 
     def transform_documents(
diff --git a/tests/integration_tests/test_text_splitter.py b/tests/integration_tests/test_text_splitter.py
index 367899aa..d19a58d5 100644
--- a/tests/integration_tests/test_text_splitter.py
+++ b/tests/integration_tests/test_text_splitter.py
@@ -23,19 +23,24 @@ def test_huggingface_tokenizer() -> None:
     assert output == ["foo", "bar"]
 
 
-class TestTokenTextSplitter:
-    """Test token text splitter."""
-
-    def test_basic(self) -> None:
-        """Test no overlap."""
-        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
-        output = splitter.split_text("abcdef" * 5)  # 10 token string
-        expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
-        assert output == expected_output
-
-    def test_overlap(self) -> None:
-        """Test with overlap."""
-        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
-        output = splitter.split_text("abcdef" * 5)  # 10 token string
-        expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
-        assert output == expected_output
+def test_token_text_splitter() -> None:
+    """Test no overlap."""
+    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
+    output = splitter.split_text("abcdef" * 5)  # 10 token string
+    expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
+    assert output == expected_output
+
+
+def test_token_text_splitter_overlap() -> None:
+    """Test with overlap."""
+    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
+    output = splitter.split_text("abcdef" * 5)  # 10 token string
+    expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
+    assert output == expected_output
+
+
+def test_token_text_splitter_from_tiktoken() -> None:
+    splitter = TokenTextSplitter.from_tiktoken_encoder(model_name="gpt-3.5-turbo")
+    expected_tokenizer = "cl100k_base"
+    actual_tokenizer = splitter._tokenizer.name
+    assert expected_tokenizer == actual_tokenizer