allow tokentextsplitters to use model name to select encoder (#2963)

Fixes a bug I was seeing where the `TokenTextSplitter` appeared to split
text under the gpt-3.5-turbo token limit, but when the prompt was fired
off to OpenAI, it came back with an error that we were over the context
limit.

gpt-3.5-turbo and gpt-4 use the `cl100k_base` tokenizer, so the counts
are always off with the default `gpt2` encoder.
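
A quick way to see the mismatch with tiktoken (an illustrative snippet, not part of the change; the sample text is arbitrary):

    import tiktoken

    text = "Splitting by token count only works if we count with the tokenizer the model actually uses."
    # The two encoders tokenize differently, so their counts diverge on real documents.
    print(len(tiktoken.get_encoding("gpt2").encode(text)))
    print(len(tiktoken.get_encoding("cl100k_base").encode(text)))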

It's possible to pass the encoding along to the `TokenTextSplitter`, but
it's much simpler to pass the model name of the LLM. No more worrying
about keeping the tokenizer and the LLM in sync :)
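
Roughly, the before/after usage looks like this (a sketch; the chunk sizes and `long_document` are placeholders):

    from langchain.text_splitter import TokenTextSplitter

    # Before: the encoding had to be chosen by hand to match the model.
    splitter = TokenTextSplitter(encoding_name="cl100k_base", chunk_size=1000, chunk_overlap=0)

    # After this change: name the model and let tiktoken pick the right encoding.
    splitter = TokenTextSplitter(model_name="gpt-3.5-turbo", chunk_size=1000, chunk_overlap=0)
    chunks = splitter.split_text(long_document)
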
Tim Asp 2023-04-16 08:33:47 -07:00 committed by GitHub
parent 706ebd8f9c
commit 51894ddd98

@@ -139,6 +139,7 @@ class TextSplitter(ABC):
     def from_tiktoken_encoder(
         cls,
         encoding_name: str = "gpt2",
+        model_name: Optional[str] = None,
         allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
@@ -153,7 +154,9 @@ class TextSplitter(ABC):
                 "Please install it with `pip install tiktoken`."
             )
-        # create a GPT-3 encoder instance
-        enc = tiktoken.get_encoding(encoding_name)
+        if model_name is not None:
+            enc = tiktoken.encoding_for_model(model_name)
+        else:
+            enc = tiktoken.get_encoding(encoding_name)
 
         def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
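
For context, `tiktoken.encoding_for_model` resolves a model name to its encoding, which is what the new branch above relies on (a quick illustration, separate from the diff):

    import tiktoken

    # Model names map to their encodings, so the counts match what the API will see.
    print(tiktoken.encoding_for_model("gpt-3.5-turbo").name)  # cl100k_base
    print(tiktoken.get_encoding("gpt2").name)                 # gpt2
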
@@ -193,6 +196,7 @@ class TokenTextSplitter(TextSplitter):
     def __init__(
         self,
         encoding_name: str = "gpt2",
+        model_name: Optional[str] = None,
         allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
         **kwargs: Any,
@@ -207,8 +211,12 @@ class TokenTextSplitter(TextSplitter):
                 "This is needed in order to for TokenTextSplitter. "
                 "Please install it with `pip install tiktoken`."
             )
-        # create a GPT-3 encoder instance
-        self._tokenizer = tiktoken.get_encoding(encoding_name)
+        if model_name is not None:
+            enc = tiktoken.encoding_for_model(model_name)
+        else:
+            enc = tiktoken.get_encoding(encoding_name)
+
+        self._tokenizer = enc
         self._allowed_special = allowed_special
         self._disallowed_special = disallowed_special
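
The same option flows through `from_tiktoken_encoder`, so other splitters can measure length with the model's tokenizer too; a minimal sketch (`some_long_text` is a placeholder):

    from langchain.text_splitter import CharacterTextSplitter

    # Chunk lengths are measured with gpt-3.5-turbo's tokenizer instead of the gpt2 default.
    splitter = CharacterTextSplitter.from_tiktoken_encoder(
        model_name="gpt-3.5-turbo",
        chunk_size=1000,
        chunk_overlap=0,
    )
    docs = splitter.split_text(some_long_text)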