|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from functools import lru_cache
|
|
|
|
|
from typing import (
|
|
|
|
|
TYPE_CHECKING,
|
|
|
|
|
Any,
|
|
|
|
@ -23,10 +24,8 @@ if TYPE_CHECKING:
|
|
|
|
|
from langchain.callbacks.manager import Callbacks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_token_ids_default_method(text: str) -> List[int]:
|
|
|
|
|
"""Encode the text into token IDs."""
|
|
|
|
|
# TODO: this method may not be exact.
|
|
|
|
|
# TODO: this method may differ based on model (eg codex).
|
|
|
|
|
@lru_cache(maxsize=None) # Cache the tokenizer
|
|
|
|
|
def get_tokenizer() -> Any:
|
|
|
|
|
try:
|
|
|
|
|
from transformers import GPT2TokenizerFast
|
|
|
|
|
except ImportError:
|
|
|
|
@ -36,7 +35,13 @@ def _get_token_ids_default_method(text: str) -> List[int]:
|
|
|
|
|
"Please install it with `pip install transformers`."
|
|
|
|
|
)
|
|
|
|
|
# create a GPT-2 tokenizer instance
|
|
|
|
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
|
|
|
|
return GPT2TokenizerFast.from_pretrained("gpt2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_token_ids_default_method(text: str) -> List[int]:
    """Return the token IDs produced by the cached tokenizer for ``text``.

    The tokenizer is not built here: it is obtained from ``get_tokenizer()``,
    and ``text`` is handed to that tokenizer's ``encode`` method. Whatever
    ``encode`` returns is passed back to the caller unchanged.
    """
    # Look up the cached tokenizer and encode in a single expression.
    return get_tokenizer().encode(text)
|
|
|
|
|