diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index d189296b..37834ce6 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -139,6 +139,7 @@ class TextSplitter(ABC): def from_tiktoken_encoder( cls, encoding_name: str = "gpt2", + model_name: Optional[str] = None, allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, @@ -153,8 +154,10 @@ class TextSplitter(ABC): "Please install it with `pip install tiktoken`." ) - # create a GPT-3 encoder instance - enc = tiktoken.get_encoding(encoding_name) + if model_name is not None: + enc = tiktoken.encoding_for_model(model_name) + else: + enc = tiktoken.get_encoding(encoding_name) def _tiktoken_encoder(text: str, **kwargs: Any) -> int: return len( @@ -193,6 +196,7 @@ class TokenTextSplitter(TextSplitter): def __init__( self, encoding_name: str = "gpt2", + model_name: Optional[str] = None, allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, @@ -207,8 +211,12 @@ class TokenTextSplitter(TextSplitter): "This is needed in order to for TokenTextSplitter. " "Please install it with `pip install tiktoken`." ) - # create a GPT-3 encoder instance - self._tokenizer = tiktoken.get_encoding(encoding_name) + + if model_name is not None: + enc = tiktoken.encoding_for_model(model_name) + else: + enc = tiktoken.get_encoding(encoding_name) + self._tokenizer = enc self._allowed_special = allowed_special self._disallowed_special = disallowed_special