Harrison/tiktoken spec (#964)

Co-authored-by: James Briggs <35938317+jamescalam@users.noreply.github.com>
Co-authored-by: Harrison Chase <harrisonchase@Harrisons-MBP.attlocal.net>
makefile-update-1
Harrison Chase 1 year ago committed by GitHub
parent 5f8082bdd7
commit ba54d36787
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,7 +3,17 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, List, Optional
from typing import (
AbstractSet,
Any,
Callable,
Collection,
Iterable,
List,
Literal,
Optional,
Union,
)
from langchain.docstore.document import Document
@ -114,7 +124,11 @@ class TextSplitter(ABC):
@classmethod
def from_tiktoken_encoder(
cls, encoding_name: str = "gpt2", **kwargs: Any
cls,
encoding_name: str = "gpt2",
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
) -> TextSplitter:
"""Text splitter that uses tiktoken encoder to count length."""
try:
@ -125,11 +139,19 @@ class TextSplitter(ABC):
"This is needed in order to calculate max_tokens_for_prompt. "
"Please it install it with `pip install tiktoken`."
)
# create a GPT-3 encoder instance
enc = tiktoken.get_encoding(encoding_name)
def _tiktoken_encoder(text: str) -> int:
return len(enc.encode(text))
def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
return len(
enc.encode(
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
**kwargs,
)
)
return cls(length_function=_tiktoken_encoder, **kwargs)
@ -155,7 +177,13 @@ class CharacterTextSplitter(TextSplitter):
class TokenTextSplitter(TextSplitter):
"""Implementation of splitting text that looks at tokens."""
def __init__(self, encoding_name: str = "gpt2", **kwargs: Any):
def __init__(
self,
encoding_name: str = "gpt2",
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
try:
@ -168,11 +196,17 @@ class TokenTextSplitter(TextSplitter):
)
# create a GPT-3 encoder instance
self._tokenizer = tiktoken.get_encoding(encoding_name)
self._allowed_special = allowed_special
self._disallowed_special = disallowed_special
def split_text(self, text: str) -> List[str]:
"""Split incoming text and return chunks."""
splits = []
input_ids = self._tokenizer.encode(text)
input_ids = self._tokenizer.encode(
text,
allowed_special=self._allowed_special,
disallowed_special=self._disallowed_special,
)
start_idx = 0
cur_idx = min(start_idx + self._chunk_size, len(input_ids))
chunk_ids = input_ids[start_idx:cur_idx]

Loading…
Cancel
Save