Add 'get_token_ids' method (#4784)

Let users inspect the token IDs in addition to getting the number of tokens.

---------

Co-authored-by: Zach Schillaci <40636930+zachschillaci27@users.noreply.github.com>
Zander Chase, committed by GitHub
parent ef7d015be5
commit 785502edb3
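In practice, the new method sits alongside get_num_tokens on every model. A minimal usage sketch (assuming tiktoken is installed and an OpenAI API key is set in the environment; tokenization itself runs locally, and the sample string is the one used in the new integration tests):

from langchain.llms.openai import OpenAI

llm = OpenAI(model="ada")
text = "表情符号是\n🦜🔗"
token_ids = llm.get_token_ids(text)  # individual token IDs (17 of them for "ada", per the tests below)
assert llm.get_num_tokens(text) == len(token_ids)  # the count is now derived from the IDs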

@@ -10,8 +10,8 @@ from langchain.callbacks.manager import Callbacks
 from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
 
 
-def _get_num_tokens_default_method(text: str) -> int:
-    """Get the number of tokens present in the text."""
+def _get_token_ids_default_method(text: str) -> List[int]:
+    """Encode the text into token IDs."""
     # TODO: this method may not be exact.
     # TODO: this method may differ based on model (eg codex).
     try:
@@ -19,17 +19,14 @@ def _get_num_tokens_default_method(text: str) -> int:
     except ImportError:
         raise ValueError(
             "Could not import transformers python package. "
-            "This is needed in order to calculate get_num_tokens. "
+            "This is needed in order to calculate get_token_ids. "
            "Please install it with `pip install transformers`."
         )
     # create a GPT-2 tokenizer instance
     tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
     # tokenize the text using the GPT-2 tokenizer
-    tokenized_text = tokenizer.tokenize(text)
-    # calculate the number of tokens in the tokenized text
-    return len(tokenized_text)
+    return tokenizer.encode(text)
 
 
 class BaseLanguageModel(BaseModel, ABC):
@@ -61,9 +58,13 @@ class BaseLanguageModel(BaseModel, ABC):
     ) -> BaseMessage:
         """Predict message from messages."""
 
+    def get_token_ids(self, text: str) -> List[int]:
+        """Get the token IDs present in the text."""
+        return _get_token_ids_default_method(text)
+
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text."""
-        return _get_num_tokens_default_method(text)
+        return len(self.get_token_ids(text))
 
     def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         """Get the number of tokens in the message."""

@@ -3,7 +3,17 @@ from __future__ import annotations
 
 import logging
 import sys
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)
 
 from pydantic import Extra, Field, root_validator
 from tenacity import (
@@ -30,9 +40,24 @@ from langchain.schema import (
 )
 from langchain.utils import get_from_dict_or_env
 
+if TYPE_CHECKING:
+    import tiktoken
+
 logger = logging.getLogger(__name__)
 
 
+def _import_tiktoken() -> Any:
+    try:
+        import tiktoken
+    except ImportError:
+        raise ValueError(
+            "Could not import tiktoken python package. "
+            "This is needed in order to calculate get_token_ids. "
+            "Please install it with `pip install tiktoken`."
+        )
+    return tiktoken
+
+
 def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
     import openai
@@ -354,42 +379,8 @@ class ChatOpenAI(BaseChatModel):
         """Return type of chat model."""
         return "openai-chat"
 
-    def get_num_tokens(self, text: str) -> int:
-        """Calculate num tokens with tiktoken package."""
-        # tiktoken NOT supported for Python 3.7 or below
-        if sys.version_info[1] <= 7:
-            return super().get_num_tokens(text)
-        try:
-            import tiktoken
-        except ImportError:
-            raise ValueError(
-                "Could not import tiktoken python package. "
-                "This is needed in order to calculate get_num_tokens. "
-                "Please install it with `pip install tiktoken`."
-            )
-        # create a GPT-3.5-Turbo encoder instance
-        enc = tiktoken.encoding_for_model(self.model_name)
-        # encode the text using the GPT-3.5-Turbo encoder
-        tokenized_text = enc.encode(text)
-        # calculate the number of tokens in the encoded text
-        return len(tokenized_text)
-
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
-
-        Official documentation: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        try:
-            import tiktoken
-        except ImportError:
-            raise ValueError(
-                "Could not import tiktoken python package. "
-                "This is needed in order to calculate get_num_tokens. "
-                "Please install it with `pip install tiktoken`."
-            )
+    def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
+        tiktoken_ = _import_tiktoken()
         model = self.model_name
         if model == "gpt-3.5-turbo":
             # gpt-3.5-turbo may change over time.
@@ -399,14 +390,31 @@ class ChatOpenAI(BaseChatModel):
             # gpt-4 may change over time.
             # Returning num tokens assuming gpt-4-0314.
             model = "gpt-4-0314"
         # Returns the number of tokens used by a list of messages.
         try:
-            encoding = tiktoken.encoding_for_model(model)
+            encoding = tiktoken_.encoding_for_model(model)
         except KeyError:
             logger.warning("Warning: model not found. Using cl100k_base encoding.")
-            encoding = tiktoken.get_encoding("cl100k_base")
+            model = "cl100k_base"
+            encoding = tiktoken_.get_encoding(model)
+        return model, encoding
+
+    def get_token_ids(self, text: str) -> List[int]:
+        """Get the tokens present in the text with tiktoken package."""
+        # tiktoken NOT supported for Python 3.7 or below
+        if sys.version_info[1] <= 7:
+            return super().get_token_ids(text)
+        _, encoding_model = self._get_encoding_model()
+        return encoding_model.encode(text)
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        if sys.version_info[1] <= 7:
+            return super().get_num_tokens_from_messages(messages)
+        model, encoding = self._get_encoding_model()
         if model == "gpt-3.5-turbo-0301":
             # every message follows <im_start>{role/name}\n{content}<im_end>\n
             tokens_per_message = 4
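To make the refactor concrete, this is roughly what the shared helper resolves to for the default chat model (a sketch, assuming tiktoken is installed; the pinning of "gpt-3.5-turbo" to the dated snapshot follows the diff above):

import tiktoken

# the helper maps "gpt-3.5-turbo" to "gpt-3.5-turbo-0301" before the encoding lookup
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo-0301")
ids = encoding.encode("表情符号是\n🦜🔗")
print(len(ids))  # 12, matching _EXPECTED_NUM_TOKENS in the new integration tests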

@@ -454,8 +454,8 @@ class BaseOpenAI(BaseLLM):
         """Return type of llm."""
         return "openai"
 
-    def get_num_tokens(self, text: str) -> int:
-        """Calculate num tokens with tiktoken package."""
+    def get_token_ids(self, text: str) -> List[int]:
+        """Get the token IDs using the tiktoken package."""
         # tiktoken NOT supported for Python < 3.8
         if sys.version_info[1] < 8:
             return super().get_num_tokens(text)

@@ -470,15 +470,12 @@ class BaseOpenAI(BaseLLM):
         enc = tiktoken.encoding_for_model(self.model_name)
 
-        tokenized_text = enc.encode(
+        return enc.encode(
             text,
             allowed_special=self.allowed_special,
             disallowed_special=self.disallowed_special,
         )
-        # calculate the number of tokens in the encoded text
-        return len(tokenized_text)
 
     def modelname_to_contextsize(self, modelname: str) -> int:
         """Calculate the maximum number of tokens possible to generate for a model.
@@ -802,11 +799,11 @@ class OpenAIChat(BaseLLM):
         """Return type of llm."""
         return "openai-chat"
 
-    def get_num_tokens(self, text: str) -> int:
-        """Calculate num tokens with tiktoken package."""
+    def get_token_ids(self, text: str) -> List[int]:
+        """Get the token IDs using the tiktoken package."""
         # tiktoken NOT supported for Python < 3.8
         if sys.version_info[1] < 8:
-            return super().get_num_tokens(text)
+            return super().get_token_ids(text)
         try:
             import tiktoken
         except ImportError:

@@ -815,15 +812,10 @@ class OpenAIChat(BaseLLM):
                 "This is needed in order to calculate get_num_tokens. "
                 "Please install it with `pip install tiktoken`."
             )
-        # create a GPT-3.5-Turbo encoder instance
-        enc = tiktoken.encoding_for_model("gpt-3.5-turbo")
-        # encode the text using the GPT-3.5-Turbo encoder
-        tokenized_text = enc.encode(
+        enc = tiktoken.encoding_for_model(self.model_name)
+        return enc.encode(
             text,
             allowed_special=self.allowed_special,
             disallowed_special=self.disallowed_special,
         )
-        # calculate the number of tokens in the encoded text
-        return len(tokenized_text)
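The allowed_special / disallowed_special arguments forwarded above control how tiktoken treats special-token markup in the input. A small illustration, separate from the diff (tiktoken disallows all special tokens by default):

import tiktoken

enc = tiktoken.get_encoding("gpt2")
try:
    enc.encode("hello <|endoftext|>")  # ValueError: special tokens are disallowed by default
except ValueError:
    pass
ids = enc.encode("hello <|endoftext|>", allowed_special={"<|endoftext|>"})
# the marker is encoded as the single special token ID 50256 in the gpt2 encoding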

@@ -6,6 +6,7 @@ from typing import Generator
 import pytest
 
 from langchain.callbacks.manager import CallbackManager
+from langchain.chat_models.openai import ChatOpenAI
 from langchain.llms.loading import load_llm
 from langchain.llms.openai import OpenAI, OpenAIChat
 from langchain.schema import LLMResult
@@ -237,3 +238,40 @@ def test_openai_modelname_to_contextsize_invalid() -> None:
     """Test model name to context size on an invalid model."""
     with pytest.raises(ValueError):
         OpenAI().modelname_to_contextsize("foobar")
+
+
+_EXPECTED_NUM_TOKENS = {
+    "ada": 17,
+    "babbage": 17,
+    "curie": 17,
+    "davinci": 17,
+    "gpt-4": 12,
+    "gpt-4-32k": 12,
+    "gpt-3.5-turbo": 12,
+}
+
+_MODELS = [
+    "ada",
+    "babbage",
+    "curie",
+    "davinci",
+]
+
+_CHAT_MODELS = [
+    "gpt-4",
+    "gpt-4-32k",
+    "gpt-3.5-turbo",
+]
+
+
+@pytest.mark.parametrize("model", _MODELS)
+def test_openai_get_num_tokens(model: str) -> None:
+    """Test get_num_tokens."""
+    llm = OpenAI(model=model)
+    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]
+
+
+@pytest.mark.parametrize("model", _CHAT_MODELS)
+def test_chat_openai_get_num_tokens(model: str) -> None:
+    """Test get_num_tokens."""
+    llm = ChatOpenAI(model=model)
+    assert llm.get_num_tokens("表情符号是\n🦜🔗") == _EXPECTED_NUM_TOKENS[model]

@@ -1,15 +1,19 @@
 """Test formatting functionality."""
-from langchain.base_language import _get_num_tokens_default_method
+from langchain.base_language import _get_token_ids_default_method
 
 
 class TestTokenCountingWithGPT2Tokenizer:
+    def test_tokenization(self) -> None:
+        # Check that the tokenization is consistent with the GPT-2 tokenizer
+        assert _get_token_ids_default_method("This is a test") == [1212, 318, 257, 1332]
+
     def test_empty_token(self) -> None:
-        assert _get_num_tokens_default_method("") == 0
+        assert len(_get_token_ids_default_method("")) == 0
 
     def test_multiple_tokens(self) -> None:
-        assert _get_num_tokens_default_method("a b c") == 3
+        assert len(_get_token_ids_default_method("a b c")) == 3
 
     def test_special_tokens(self) -> None:
         # test for consistency when the default tokenizer is changed
-        assert _get_num_tokens_default_method("a:b_c d") == 6
+        assert len(_get_token_ids_default_method("a:b_c d")) == 6
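A natural extra check, not included in this diff, is that the pinned IDs round-trip through the tokenizer's decode (again assuming transformers is installed):

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
assert tokenizer.decode([1212, 318, 257, 1332]) == "This is a test"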
