forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
43 lines
1.6 KiB
Python
43 lines
1.6 KiB
Python
"""Base interface for large language models to expose."""
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, List, Mapping, Optional
|
|
|
|
|
|
class LLM(ABC):
    """Abstract base class for large language model wrappers.

    A concrete subclass implements ``__call__``, which takes a prompt and
    returns a string.
    """

    def get_num_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text``.

        The count comes from the Hugging Face ``GPT2TokenizerFast`` tokenizer
        loaded from the ``"gpt2"`` checkpoint, so it may be inexact for models
        that tokenize differently (e.g. codex).

        Args:
            text: The string to tokenize.

        Returns:
            The number of tokens the tokenizer splits ``text`` into.

        Raises:
            ValueError: If the ``transformers`` package cannot be imported.
        """
        # TODO: this method may not be exact.
        # TODO: this method may differ based on model (eg codex).
        try:
            # Imported here rather than at module level so that the package is
            # only required when token counting is actually used.
            from transformers import GPT2TokenizerFast
        except ImportError as err:
            # ValueError is kept (rather than re-raising ImportError) so that
            # existing callers catching ValueError keep working; the original
            # ImportError stays attached as the explicit cause.
            raise ValueError(
                "Could not import transformers python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install transformers`."
            ) from err

        # NOTE(review): the tokenizer is re-created on every call; consider
        # caching it if this method turns out to be on a hot path.
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

        # The token count is simply the length of the tokenized sequence.
        return len(tokenizer.tokenize(text))

    @abstractmethod
    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Run the LLM on the given prompt and input.

        Args:
            prompt: The text to send to the model.
            stop: Optional list of strings; presumably sequences at which
                generation should stop -- the interpretation is left to the
                implementing subclass.

        Returns:
            The string produced by the model.
        """

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.

        The base implementation returns an empty mapping; subclasses can
        override this to supply their own parameters.
        """
        return {}

    def __str__(self) -> str:
        """Get a string representation of the object for printing.

        The class name is wrapped in ANSI escape codes (bold on / reset) and
        followed on the next line by the identifying parameters.
        """
        cls_name = f"\033[1m{self.__class__.__name__}\033[0m"
        return f"{cls_name}\nParams: {self._identifying_params}"
|