@@ -19,6 +19,8 @@ from transformers import (
    GPTJForCausalLM,
    GPTNeoForCausalLM,
    GPTNeoXForCausalLM,
    LlamaForCausalLM,
    LlamaTokenizer,
    OPTForCausalLM,
    PreTrainedModel,
    PreTrainedTokenizer,
@@ -48,6 +50,7 @@ MODEL_REGISTRY = {
    "bigscience/bloom-1b7": BloomForCausalLM,
    "bigscience/bloom-3b": BloomForCausalLM,
    "bigscience/bloom-7b1": BloomForCausalLM,
    "chainyo/alpaca-lora-7b": LlamaForCausalLM,
    "bigscience/bloom": AutoModelForCausalLM,
    "bigscience/T0pp": AutoModelForSeq2SeqLM,
    "bigscience/T0_3B": AutoModelForSeq2SeqLM,
@@ -65,6 +68,7 @@ MODEL_REGISTRY = {
MODEL_GENTYPE_REGISTRY = {
    "text-generation": AutoModelForCausalLM,
    "llama-text-generation": LlamaForCausalLM,
    "text2text-generation": AutoModelForSeq2SeqLM,
}
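# Usage sketch (not part of the diff): how the two registries above are meant
# to resolve a model class. MODEL_REGISTRY is keyed by checkpoint name, and
# MODEL_GENTYPE_REGISTRY serves as the per-generation-type fallback when a
# checkpoint is not explicitly registered. The name and type below are
# illustrative values only.
model_name, model_type = "bigscience/bloom-1b7", "text-generation"
model_cls = MODEL_REGISTRY.get(model_name, MODEL_GENTYPE_REGISTRY.get(model_type))
model = model_cls.from_pretrained(model_name)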
@@ -152,17 +156,20 @@ class GenerationPipeline:
            return_tensors="pt",
        )
        encoded_prompt = encoded_prompt.to(self.device)
        kwargs_to_pass = dict(
            temperature=kwargs.get("temperature"),
            top_k=kwargs.get("top_k"),
            top_p=kwargs.get("top_p"),
            repetition_penalty=kwargs.get("repetition_penalty"),
            num_return_sequences=kwargs.get("num_return_sequences"),
        )
        kwargs_to_pass = {k: v for k, v in kwargs_to_pass.items() if v is not None}
        output_dict = self.model.generate(  # type: ignore
            **encoded_prompt,
            **kwargs_to_pass,
            max_new_tokens=kwargs.get("max_new_tokens"),
            temperature=kwargs.get("temperature", None),
            top_k=kwargs.get("top_k", None),
            top_p=kwargs.get("top_p", None),
            repetition_penalty=kwargs.get("repetition_penalty", None),
            do_sample=kwargs.get("do_sample", None) if not self.bitsandbytes else False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            num_return_sequences=kwargs.get("num_return_sequences", None),
            output_scores=True,
            return_dict_in_generate=True,
        )
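# Sketch (not part of the diff) of the pattern the hunk above introduces:
# collect the optional sampling arguments, drop any the caller left unset, and
# let generate() fall back to the transformers defaults for everything omitted.
# `model`, `tokenizer`, and `prompt` are assumed to already exist.
encoded = tokenizer(prompt, return_tensors="pt")
maybe_kwargs = dict(temperature=0.7, top_k=None, top_p=None)
maybe_kwargs = {k: v for k, v in maybe_kwargs.items() if v is not None}
output = model.generate(**encoded, max_new_tokens=32, **maybe_kwargs)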
@@ -458,17 +465,25 @@ class TextGenerationModel(HuggingFaceModel):
            perc_max_gpu_mem_red,
            use_fp16,
        )
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                self.model_name, truncation_side="left", padding_side="left"
            )
        except ValueError:
            tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                truncation_side="left",
                padding_side="left",
                use_fast=False,
            )
        if (
            MODEL_REGISTRY.get(
                self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
            )
            == LlamaForCausalLM
        ):
            tokenizer = LlamaTokenizer.from_pretrained(self.model_name)
        else:
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name, truncation_side="left", padding_side="left"
                )
            except ValueError:
                tokenizer = AutoTokenizer.from_pretrained(
                    self.model_name,
                    truncation_side="left",
                    padding_side="left",
                    use_fast=False,
                )
        dtype = torch.float16 if use_fp16 else "auto"
        if use_bitsandbytes:
            print("WARNING!!! Cannot use sampling with bitsandbytes.")
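# Sketch (not part of the diff): how a dtype chosen as above is typically
# passed on to from_pretrained. torch_dtype="auto" lets transformers use the
# dtype recorded in the checkpoint config, while torch.float16 forces half
# precision. The checkpoint name is one of the registry entries above, used
# only for illustration.
import torch
from transformers import AutoModelForCausalLM

dtype = torch.float16  # or "auto" when use_fp16 is False
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b7", torch_dtype=dtype)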
@@ -501,11 +516,10 @@ class TextGenerationModel(HuggingFaceModel):
            )
        model.eval()
        print(f"Loaded Model DType {model.dtype}")

        self.is_encdec = model.config.is_encoder_decoder
        if not self.is_encdec:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        if not use_bitsandbytes:
            if use_accelerate:
                self._dispatch_accelerate_model(model, perc_max_gpu_mem_red)
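# Sketch (not part of the diff): decoder-only checkpoints (the non
# encoder-decoder branch above) usually ship without a pad token, so batched
# or padded generation needs one; reusing the EOS token, as the code above
# does, is the usual convention.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id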