|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
"""Huggingface model."""
|
|
|
|
|
import json
|
|
|
|
|
from functools import partial
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
|
|
|
|
|
@ -9,11 +10,56 @@ from transformers import (
|
|
|
|
|
GPT2LMHeadModel,
|
|
|
|
|
GPTJForCausalLM,
|
|
|
|
|
GPTNeoForCausalLM,
|
|
|
|
|
PreTrainedModel,
|
|
|
|
|
PreTrainedTokenizer,
|
|
|
|
|
pipeline,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from manifest.api.models.model import Model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GPT2Pipeline:
    """Custom GPT2 Pipeline.

    Minimal text-generation wrapper around a causal language model and its
    tokenizer. Calling the instance encodes the prompt, runs
    ``model.generate`` on it and decodes every returned sequence.
    """

    def __init__(
        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, device: int
    ):
        """Initialize.

        Args:
            model: model used for generation; it is moved to ``device``.
            tokenizer: tokenizer used to encode prompts and decode outputs.
            device: device index. A negative index selects the CPU, which
                mirrors the ``transformers.pipeline`` convention so the same
                value can be handed to either kind of pipeline.
        """
        self.model = model
        self.tokenizer = tokenizer
        # A negative index is not a valid ``.to()`` target, so translate it to
        # the CPU. Every other value is stored and forwarded unchanged.
        if isinstance(device, int) and device < 0:
            self.device = "cpu"
        else:
            self.device = device
        self.model.to(self.device)  # type: ignore

    def __call__(self, text: str, **kwargs: Any) -> List[Dict[str, str]]:
        """Generate from text.

        Args:
            text: text to generate.
            **kwargs: optional generation settings forwarded to
                ``model.generate``: ``max_length``, ``temperature``,
                ``top_k``, ``top_p``, ``repetition_penalty``, ``do_sample``
                and ``num_return_sequences``. A missing setting is forwarded
                as ``None``.

        Returns:
            generated text, one ``{"generated_text": ...}`` dict per sequence
            returned by the model.
        """
        encoded_prompt = self.tokenizer.encode(text, return_tensors="pt")
        # The prompt has to live on the same device as the model.
        encoded_prompt = encoded_prompt.to(self.device)

        eos_token_id = self.tokenizer.eos_token_id
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Tokenizers without a padding token (GPT-2's, for example) report
            # None here; pad with EOS explicitly instead of forwarding None.
            pad_token_id = eos_token_id

        output_sequences = self.model.generate(  # type: ignore
            encoded_prompt,
            max_length=kwargs.get("max_length"),
            temperature=kwargs.get("temperature"),
            top_k=kwargs.get("top_k"),
            top_p=kwargs.get("top_p"),
            repetition_penalty=kwargs.get("repetition_penalty"),
            do_sample=kwargs.get("do_sample"),
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            num_return_sequences=kwargs.get("num_return_sequences"),
        )
        generated_sequences = [
            {"generated_text": self.tokenizer.decode(output_seq)}
            for output_seq in output_sequences
        ]
        return generated_sequences
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MODEL_REGISTRY = {
|
|
|
|
|
"EleutherAI/gpt-j-6B": GPTJForCausalLM,
|
|
|
|
|
"EleutherAI/gpt-neo-125M": GPTNeoForCausalLM,
|
|
|
|
@ -25,13 +71,13 @@ MODEL_REGISTRY = {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Maps each supported model name to the callable that builds its generation
# pipeline. The callable is invoked with ``model=``, ``tokenizer=`` and
# ``device=`` keyword arguments. Each key is listed exactly once: a repeated
# key in a dict literal is silently overridden by its last occurrence.
MODEL_PIPELINE = {
    "EleutherAI/gpt-j-6B": GPT2Pipeline,
    "EleutherAI/gpt-neo-125M": GPT2Pipeline,
    "EleutherAI/gpt-neo-1.3B": GPT2Pipeline,
    "EleutherAI/gpt-neo-2.7B": GPT2Pipeline,
    "gpt2": GPT2Pipeline,
    "bigscience/T0pp": partial(pipeline, "text2text-generation"),
    "bigscience/T0_3B": partial(pipeline, "text2text-generation"),
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -56,14 +102,14 @@ class HuggingFaceModel(Model):
|
|
|
|
|
model_name = config["_name_or_path"]
|
|
|
|
|
self.model_name = model_name
|
|
|
|
|
print("Model Name:", self.model_name, "Model Path:", self.model_path)
|
|
|
|
|
model = MODEL_REGISTRY[model_name].from_pretrained(
|
|
|
|
|
model = MODEL_REGISTRY[model_name].from_pretrained( # type: ignore
|
|
|
|
|
self.model_path, cache_dir=cache_dir
|
|
|
|
|
)
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
|
self.pipeline = pipeline(
|
|
|
|
|
MODEL_PIPELINE[model_name], model=model, tokenizer=tokenizer, device=device
|
|
|
|
|
self.pipeline = MODEL_PIPELINE[model_name](
|
|
|
|
|
model=model, tokenizer=tokenizer, device=device
|
|
|
|
|
)
|
|
|
|
|
self.returns_input = MODEL_PIPELINE[model_name] == "text-generation"
|
|
|
|
|
self.returns_input = "gpt" in model_name
|
|
|
|
|
|
|
|
|
|
def get_init_params(self) -> Dict:
|
|
|
|
|
"""Return init params to determine what model is being used."""
|
|
|
|
@ -93,6 +139,7 @@ class HuggingFaceModel(Model):
|
|
|
|
|
repetition_penalty=kwargs.get("repetition_penalty"),
|
|
|
|
|
top_k=kwargs.get("top_k"),
|
|
|
|
|
top_p=kwargs.get("top_p"),
|
|
|
|
|
do_sample=kwargs.get("do_sample"),
|
|
|
|
|
num_return_sequences=num_return,
|
|
|
|
|
)
|
|
|
|
|
if self.returns_input:
|
|
|
|
|