mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
allow tokentextsplitters to use model name to select encoder (#2963)
Fixes a bug I was seeing when the `TokenTextSplitter` was correctly splitting text under the gpt3.5-turbo token limit, but when firing the prompt off too openai, it'd come back with an error that we were over the context limit. gpt3.5-turbo and gpt-4 use `cl100k_base` tokenizer, and so the counts are just always off with the default `gpt-2` encoder. It's possible to pass along the encoding to the `TokenTextSplitter`, but it's much simpler to pass the model name of the LLM. No more concern about keeping the tokenizer and llm model in sync :)
This commit is contained in:
parent
706ebd8f9c
commit
51894ddd98
@ -139,6 +139,7 @@ class TextSplitter(ABC):
|
||||
def from_tiktoken_encoder(
|
||||
cls,
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: Optional[str] = None,
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
@ -153,7 +154,9 @@ class TextSplitter(ABC):
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
# create a GPT-3 encoder instance
|
||||
if model_name is not None:
|
||||
enc = tiktoken.encoding_for_model(model_name)
|
||||
else:
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
|
||||
@ -193,6 +196,7 @@ class TokenTextSplitter(TextSplitter):
|
||||
def __init__(
|
||||
self,
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: Optional[str] = None,
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
@ -207,8 +211,12 @@ class TokenTextSplitter(TextSplitter):
|
||||
"This is needed in order to for TokenTextSplitter. "
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
# create a GPT-3 encoder instance
|
||||
self._tokenizer = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
if model_name is not None:
|
||||
enc = tiktoken.encoding_for_model(model_name)
|
||||
else:
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
self._tokenizer = enc
|
||||
self._allowed_special = allowed_special
|
||||
self._disallowed_special = disallowed_special
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user