Fix TextSplitter.from_tiktoken (#4361)

Thanks to @danb27 for the fix! Minor update.

Fixes https://github.com/hwchase17/langchain/issues/4357

---------

Co-authored-by: Dan Bianchini <42096328+danb27@users.noreply.github.com>
parent 782df1db10
commit 02ebb15c4a
@@ -14,6 +14,8 @@ from typing import (
     Literal,
     Optional,
     Sequence,
+    Type,
+    TypeVar,
     Union,
 )
 
@@ -22,6 +24,8 @@ from langchain.schema import BaseDocumentTransformer
 
 logger = logging.getLogger(__name__)
 
+TS = TypeVar("TS", bound="TextSplitter")
+
 
 class TextSplitter(BaseDocumentTransformer, ABC):
     """Interface for splitting text into chunks."""
@@ -139,13 +143,13 @@ class TextSplitter(BaseDocumentTransformer, ABC):
 
     @classmethod
     def from_tiktoken_encoder(
-        cls,
+        cls: Type[TS],
         encoding_name: str = "gpt2",
         model_name: Optional[str] = None,
         allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
-    ) -> TextSplitter:
+    ) -> TS:
         """Text splitter that uses tiktoken encoder to count length."""
         try:
             import tiktoken
@@ -161,16 +165,24 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         else:
             enc = tiktoken.get_encoding(encoding_name)
 
-        def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
+        def _tiktoken_encoder(text: str) -> int:
             return len(
                 enc.encode(
                     text,
                     allowed_special=allowed_special,
                     disallowed_special=disallowed_special,
-                    **kwargs,
                )
             )
 
+        if issubclass(cls, TokenTextSplitter):
+            extra_kwargs = {
+                "encoding_name": encoding_name,
+                "model_name": model_name,
+                "allowed_special": allowed_special,
+                "disallowed_special": disallowed_special,
+            }
+            kwargs = {**kwargs, **extra_kwargs}
+
         return cls(length_function=_tiktoken_encoder, **kwargs)
 
     def transform_documents(
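
For context, a minimal usage sketch of the fixed classmethod (not part of the commit; it assumes langchain and tiktoken are installed). With the change above, TokenTextSplitter.from_tiktoken_encoder forwards the tiktoken settings to the TokenTextSplitter constructor instead of only setting length_function, so the returned splitter is configured with the encoding tiktoken maps to the requested model rather than the constructor's default "gpt2"; the cls: Type[TS] / -> TS annotations also let type checkers see that the subclass is returned.

from langchain.text_splitter import TokenTextSplitter

# With the fix, encoding_name / model_name / allowed_special / disallowed_special
# are passed through to TokenTextSplitter, so this splitter splits using the
# encoding tiktoken associates with gpt-3.5-turbo (cl100k_base).
# chunk_size / chunk_overlap are illustrative values forwarded via **kwargs.
splitter = TokenTextSplitter.from_tiktoken_encoder(
    model_name="gpt-3.5-turbo",
    chunk_size=100,
    chunk_overlap=0,
)
chunks = splitter.split_text("some long document text ...")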
@@ -23,19 +23,24 @@ def test_huggingface_tokenizer() -> None:
     assert output == ["foo", "bar"]
 
 
-def test_token_text_splitter() -> None:
-    """Test no overlap."""
-    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
-    output = splitter.split_text("abcdef" * 5)  # 10 token string
-    expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
-    assert output == expected_output
-
-
-def test_token_text_splitter_overlap() -> None:
-    """Test with overlap."""
-    splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
-    output = splitter.split_text("abcdef" * 5)  # 10 token string
-    expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
-    assert output == expected_output
+class TestTokenTextSplitter:
+    """Test token text splitter."""
+
+    def test_basic(self) -> None:
+        """Test no overlap."""
+        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=0)
+        output = splitter.split_text("abcdef" * 5)  # 10 token string
+        expected_output = ["abcdefabcdefabc", "defabcdefabcdef"]
+        assert output == expected_output
+
+    def test_overlap(self) -> None:
+        """Test with overlap."""
+        splitter = TokenTextSplitter(chunk_size=5, chunk_overlap=1)
+        output = splitter.split_text("abcdef" * 5)  # 10 token string
+        expected_output = ["abcdefabcdefabc", "abcdefabcdefabc", "abcdef"]
+        assert output == expected_output
+
+
+def test_token_text_splitter_from_tiktoken() -> None:
+    splitter = TokenTextSplitter.from_tiktoken_encoder(model_name="gpt-3.5-turbo")
+    expected_tokenizer = "cl100k_base"
+    actual_tokenizer = splitter._tokenizer.name
+    assert expected_tokenizer == actual_tokenizer
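
The expected value in the new test comes from tiktoken's model-to-encoding mapping; as a small illustrative check (assuming tiktoken is installed), the mapping can be queried directly:

import tiktoken

# tiktoken resolves the gpt-3.5-turbo model name to the cl100k_base encoding,
# which is why the test asserts splitter._tokenizer.name == "cl100k_base".
enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
assert enc.name == "cl100k_base"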