|
|
@ -6,6 +6,27 @@ from typing import Any, List, Mapping, Optional
|
|
|
|
class LLM(ABC):
    """LLM wrapper should take in a prompt and return a string."""
|
def get_num_tokens(self, text: str) -> int:
|
|
|
|
|
|
|
|
"""Get the number of tokens present in the text."""
|
|
|
|
|
|
|
|
# TODO: this method may not be exact.
|
|
|
|
|
|
|
|
# TODO: this method may differ based on model (eg codex).
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
from transformers import GPT2TokenizerFast
|
|
|
|
|
|
|
|
except ImportError:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"Could not import transformers python package. "
|
|
|
|
|
|
|
|
"This is needed in order to calculate max_tokens_for_prompt. "
|
|
|
|
|
|
|
|
"Please it install it with `pip install transformers`."
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
# create a GPT-3 tokenizer instance
|
|
|
|
|
|
|
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# tokenize the text using the GPT-3 tokenizer
|
|
|
|
|
|
|
|
tokenized_text = tokenizer.tokenize(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# calculate the number of tokens in the tokenized text
|
|
|
|
|
|
|
|
return len(tokenized_text)
|
|
|
|
|
|
|
|
|
|
|
|
    @abstractmethod
    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the LLM on the given prompt and input.

        Args:
            prompt: The prompt string to pass to the model.
            stop: Optional list of strings; presumably stop sequences for
                generation — behavior is defined by each subclass.

        Returns:
            The string produced by the model.
        """