expose get_num_tokens method (#327)

Harrison Chase 2022-12-13 05:22:42 -08:00 committed by GitHub
parent 8fdcdf4c2f
commit 8861770bd0
2 changed files with 22 additions and 18 deletions

@@ -6,6 +6,27 @@ from typing import Any, List, Mapping, Optional
 class LLM(ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text."""
+        # TODO: this method may not be exact.
+        # TODO: this method may differ based on model (eg codex).
+        try:
+            from transformers import GPT2TokenizerFast
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "This is needed in order to calculate the number of tokens. "
+                "Please install it with `pip install transformers`."
+            )
+        # create a GPT-2 tokenizer instance (used as an approximation for GPT-3 models)
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+        # tokenize the text using the GPT-2 tokenizer
+        tokenized_text = tokenizer.tokenize(text)
+        # return the number of tokens in the tokenized text
+        return len(tokenized_text)
+
     @abstractmethod
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Run the LLM on the given prompt and input."""

@@ -179,24 +179,7 @@ class OpenAI(LLM, BaseModel):
                 max_tokens = openai.max_token_for_prompt("Tell me a joke.")
         """
-        # TODO: this method may not be exact.
-        # TODO: this method may differ based on model (eg codex).
-        try:
-            from transformers import GPT2TokenizerFast
-        except ImportError:
-            raise ValueError(
-                "Could not import transformers python package. "
-                "This is needed in order to calculate max_tokens_for_prompt. "
-                "Please it install it with `pip install transformers`."
-            )
-        # create a GPT-3 tokenizer instance
-        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-        # tokenize the text using the GPT-3 tokenizer
-        tokenized_text = tokenizer.tokenize(prompt)
-        # calculate the number of tokens in the tokenized text
-        num_tokens = len(tokenized_text)
+        num_tokens = self.get_num_tokens(prompt)
         # get max context size for model by name
         max_size = self.modelname_to_contextsize(self.model_name)
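
To illustrate the shape of the refactored max_tokens_for_prompt flow: the prompt's token count now comes from the shared get_num_tokens helper, the model's context window from modelname_to_contextsize, and the remaining budget is presumably the difference. The final subtraction is not visible in this hunk, so it is an assumption in the sketch below; the whitespace-based counter and the 4097 context size are illustrative stand-ins, not the library's code.

# Hypothetical sketch of the refactored flow, not the library's exact code.
def max_tokens_for_prompt_sketch(prompt, count_tokens, context_size):
    # token counting is now delegated to a shared helper (LLM.get_num_tokens)
    num_tokens = count_tokens(prompt)
    # assumed: remaining token budget is the context size minus the prompt length
    return context_size - num_tokens

# usage with a trivial whitespace counter standing in for get_num_tokens
print(max_tokens_for_prompt_sketch("Tell me a joke.", lambda t: len(t.split()), 4097))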