You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

71 lines
2.4 KiB

"""Base class for all language models."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence
from pydantic import BaseModel
from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult, PromptValue, get_buffer_string
def _get_num_tokens_default_method(text: str) -> int:
    """Get the number of tokens present in the text.

    Tokenizes ``text`` with the pretrained ``"gpt2"`` fast tokenizer from the
    ``transformers`` package and returns the length of the token list.

    Args:
        text: The string to tokenize.

    Returns:
        The number of tokens the GPT-2 tokenizer produces for ``text``.

    Raises:
        ValueError: If the ``transformers`` package cannot be imported.
    """
    # TODO: this method may not be exact.
    # TODO: this method may differ based on model (eg codex).
    try:
        # Imported inside the function so transformers is only needed when
        # this default token-counting path is actually used.
        from transformers import GPT2TokenizerFast
    except ImportError as err:
        # ValueError (not ImportError) is kept for backward compatibility with
        # callers; the original import failure is chained as the cause.
        raise ValueError(
            "Could not import transformers python package. "
            "This is needed in order to calculate get_num_tokens. "
            "Please install it with `pip install transformers`."
        ) from err

    # create a GPT-2 tokenizer instance
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    # tokenize the text using the GPT-2 tokenizer
    tokenized_text = tokenizer.tokenize(text)

    # calculate the number of tokens in the tokenized text
    return len(tokenized_text)
class BaseLanguageModel(BaseModel, ABC):
    """Abstract interface shared by all language models.

    Subclasses must implement prompt generation (sync and async) and the
    ``predict`` / ``predict_messages`` methods. Token counting has a default
    implementation that subclasses may override.
    """

    @abstractmethod
    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """Take in a list of prompt values and return an LLMResult."""

    @abstractmethod
    async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """Take in a list of prompt values and return an LLMResult."""

    @abstractmethod
    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
        """Predict text from text."""

    @abstractmethod
    def predict_messages(
        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
    ) -> BaseMessage:
        """Predict message from messages."""

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens present in the text.

        Delegates to the module-level default token-counting helper.
        """
        return _get_num_tokens_default_method(text)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Each message is converted to a string on its own with
        ``get_buffer_string`` and counted with ``get_num_tokens``; the
        per-message counts are summed. An empty list yields 0.
        """
        return sum(self.get_num_tokens(get_buffer_string([m])) for m in messages)