From 51894ddd98b3201e52a81091dcd88d01e7089db7 Mon Sep 17 00:00:00 2001 From: Tim Asp <707699+timothyasp@users.noreply.github.com> Date: Sun, 16 Apr 2023 08:33:47 -0700 Subject: [PATCH] allow tokentextsplitters to use model name to select encoder (#2963) Fixes a bug I was seeing when the `TokenTextSplitter` was correctly splitting text under the gpt-3.5-turbo token limit, but when firing the prompt off to OpenAI, it'd come back with an error that we were over the context limit. gpt-3.5-turbo and gpt-4 use the `cl100k_base` tokenizer, and so the counts are just always off with the default `gpt2` encoder. It's possible to pass along the encoding to the `TokenTextSplitter`, but it's much simpler to pass the model name of the LLM. No more concern about keeping the tokenizer and LLM model in sync :) --- langchain/text_splitter.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index d189296b..37834ce6 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -139,6 +139,7 @@ class TextSplitter(ABC): def from_tiktoken_encoder( cls, encoding_name: str = "gpt2", + model_name: Optional[str] = None, allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, @@ -153,8 +154,10 @@ class TextSplitter(ABC): "This is needed in order to calculate max_tokens_for_prompt. " "Please install it with `pip install tiktoken`."
) - # create a GPT-3 encoder instance - enc = tiktoken.get_encoding(encoding_name) + if model_name is not None: + enc = tiktoken.encoding_for_model(model_name) + else: + enc = tiktoken.get_encoding(encoding_name) def _tiktoken_encoder(text: str, **kwargs: Any) -> int: return len( @@ -193,6 +196,7 @@ class TokenTextSplitter(TextSplitter): def __init__( self, encoding_name: str = "gpt2", + model_name: Optional[str] = None, allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, @@ -207,8 +211,12 @@ class TokenTextSplitter(TextSplitter): "This is needed in order to for TokenTextSplitter. " "Please install it with `pip install tiktoken`." ) - # create a GPT-3 encoder instance - self._tokenizer = tiktoken.get_encoding(encoding_name) + + if model_name is not None: + enc = tiktoken.encoding_for_model(model_name) + else: + enc = tiktoken.get_encoding(encoding_name) + self._tokenizer = enc self._allowed_special = allowed_special self._disallowed_special = disallowed_special