mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
expose get_num_tokens method (#327)
This commit is contained in:
parent
8fdcdf4c2f
commit
8861770bd0
@ -6,6 +6,27 @@ from typing import Any, List, Mapping, Optional
|
||||
class LLM(ABC):
|
||||
"""LLM wrapper should take in a prompt and return a string."""
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text."""
|
||||
# TODO: this method may not be exact.
|
||||
# TODO: this method may differ based on model (eg codex).
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. "
|
||||
"This is needed in order to calculate max_tokens_for_prompt. "
|
||||
"Please it install it with `pip install transformers`."
|
||||
)
|
||||
# create a GPT-3 tokenizer instance
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
# tokenize the text using the GPT-3 tokenizer
|
||||
tokenized_text = tokenizer.tokenize(text)
|
||||
|
||||
# calculate the number of tokens in the tokenized text
|
||||
return len(tokenized_text)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
"""Run the LLM on the given prompt and input."""
|
||||
|
@ -179,24 +179,7 @@ class OpenAI(LLM, BaseModel):
|
||||
|
||||
max_tokens = openai.max_token_for_prompt("Tell me a joke.")
|
||||
"""
|
||||
# TODO: this method may not be exact.
|
||||
# TODO: this method may differ based on model (eg codex).
|
||||
try:
|
||||
from transformers import GPT2TokenizerFast
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import transformers python package. "
|
||||
"This is needed in order to calculate max_tokens_for_prompt. "
|
||||
"Please it install it with `pip install transformers`."
|
||||
)
|
||||
# create a GPT-3 tokenizer instance
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||
|
||||
# tokenize the text using the GPT-3 tokenizer
|
||||
tokenized_text = tokenizer.tokenize(prompt)
|
||||
|
||||
# calculate the number of tokens in the tokenized text
|
||||
num_tokens = len(tokenized_text)
|
||||
num_tokens = self.get_num_tokens(prompt)
|
||||
|
||||
# get max context size for model by name
|
||||
max_size = self.modelname_to_contextsize(self.model_name)
|
||||
|
Loading…
Reference in New Issue
Block a user