diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index df81242f63..60728841be 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -6,6 +6,27 @@ from typing import Any, List, Mapping, Optional
 class LLM(ABC):
     """LLM wrapper should take in a prompt and return a string."""
 
+    def get_num_tokens(self, text: str) -> int:
+        """Get the number of tokens present in the text."""
+        # TODO: this method may not be exact.
+        # TODO: this method may differ based on model (e.g. codex).
+        try:
+            from transformers import GPT2TokenizerFast
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "This is needed in order to calculate get_num_tokens. "
+                "Please install it with `pip install transformers`."
+            )
+        # create a GPT-2 tokenizer instance (GPT-3 shares the same BPE vocabulary)
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+
+        # tokenize the text using the GPT-2 tokenizer
+        tokenized_text = tokenizer.tokenize(text)
+
+        # calculate the number of tokens in the tokenized text
+        return len(tokenized_text)
+
     @abstractmethod
     def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         """Run the LLM on the given prompt and input."""
diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index 2ea83ee77c..9de56f0ce8 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -179,24 +179,7 @@ class OpenAI(LLM, BaseModel):
 
                 max_tokens = openai.max_token_for_prompt("Tell me a joke.")
         """
-        # TODO: this method may not be exact.
-        # TODO: this method may differ based on model (eg codex).
-        try:
-            from transformers import GPT2TokenizerFast
-        except ImportError:
-            raise ValueError(
-                "Could not import transformers python package. "
-                "This is needed in order to calculate max_tokens_for_prompt. "
-                "Please it install it with `pip install transformers`."
-            )
-        # create a GPT-3 tokenizer instance
-        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-
-        # tokenize the text using the GPT-3 tokenizer
-        tokenized_text = tokenizer.tokenize(prompt)
-
-        # calculate the number of tokens in the tokenized text
-        num_tokens = len(tokenized_text)
+        num_tokens = self.get_num_tokens(prompt)
 
         # get max context size for model by name
         max_size = self.modelname_to_contextsize(self.model_name)
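For reference, a minimal sketch of exercising the relocated helper. The `EchoLLM` subclass below is hypothetical and assumes `__call__` is the only abstract member, as the diff context suggests; token counts come from the GPT-2 BPE vocabulary and are approximate for other model families, per the TODOs:

```python
from typing import List, Optional

from langchain.llms.base import LLM


class EchoLLM(LLM):
    """Hypothetical stub used only to exercise the shared get_num_tokens."""

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # Trivial implementation; real subclasses (e.g. OpenAI) call an API here.
        return prompt


llm = EchoLLM()
# Downloads the GPT-2 tokenizer on first use; requires `pip install transformers`.
print(llm.get_num_tokens("Tell me a joke."))
```

Since `max_tokens_for_prompt` now delegates to `get_num_tokens`, a subclass can override just the token-counting logic (e.g. with a model-specific tokenizer) without touching the context-size arithmetic.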