wip: lore huggingface eval (#91)

Laurel Orr 1 year ago committed by GitHub
parent 97f3ec557b
commit 5ad4b017b5

@@ -19,6 +19,8 @@ from transformers import (
     GPTJForCausalLM,
     GPTNeoForCausalLM,
     GPTNeoXForCausalLM,
+    LlamaForCausalLM,
+    LlamaTokenizer,
     OPTForCausalLM,
     PreTrainedModel,
     PreTrainedTokenizer,
@@ -48,6 +50,7 @@ MODEL_REGISTRY = {
     "bigscience/bloom-1b7": BloomForCausalLM,
     "bigscience/bloom-3b": BloomForCausalLM,
     "bigscience/bloom-7b1": BloomForCausalLM,
+    "chainyo/alpaca-lora-7b": LlamaForCausalLM,
     "bigscience/bloom": AutoModelForCausalLM,
     "bigscience/T0pp": AutoModelForSeq2SeqLM,
     "bigscience/T0_3B": AutoModelForSeq2SeqLM,
@@ -65,6 +68,7 @@ MODEL_REGISTRY = {
 MODEL_GENTYPE_REGISTRY = {
     "text-generation": AutoModelForCausalLM,
+    "llama-text-generation": LlamaForCausalLM,
     "text2text-generation": AutoModelForSeq2SeqLM,
 }
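
Taken together, the two registries resolve the model class at load time: a per-checkpoint entry in MODEL_REGISTRY wins, and the generation-type entry is the fallback. This is the same lookup the tokenizer branch further down repeats. A standalone sketch with trimmed-down stand-in registries:

# Standalone sketch of the registry resolution; these small registries
# are stand-ins for the module-level ones in the hunks above.
from transformers import AutoModelForCausalLM, LlamaForCausalLM

MODEL_REGISTRY = {"chainyo/alpaca-lora-7b": LlamaForCausalLM}
MODEL_GENTYPE_REGISTRY = {
    "text-generation": AutoModelForCausalLM,
    "llama-text-generation": LlamaForCausalLM,
}

def resolve_model_cls(model_name, model_type):
    # Per-checkpoint entry takes precedence; generation type is the fallback.
    return MODEL_REGISTRY.get(
        model_name, MODEL_GENTYPE_REGISTRY.get(model_type, None)
    )

assert resolve_model_cls("chainyo/alpaca-lora-7b", None) is LlamaForCausalLM
assert resolve_model_cls("some/other-model", "text-generation") is AutoModelForCausalLM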
@@ -152,17 +156,20 @@ class GenerationPipeline:
             return_tensors="pt",
         )
         encoded_prompt = encoded_prompt.to(self.device)
+        kwargs_to_pass = dict(
+            temperature=kwargs.get("temperature"),
+            top_k=kwargs.get("top_k"),
+            top_p=kwargs.get("top_p"),
+            repetition_penalty=kwargs.get("repetition_penalty"),
+            num_return_sequences=kwargs.get("num_return_sequences"),
+        )
+        kwargs_to_pass = {k: v for k, v in kwargs_to_pass.items() if v is not None}
         output_dict = self.model.generate(  # type: ignore
             **encoded_prompt,
+            **kwargs_to_pass,
             max_new_tokens=kwargs.get("max_new_tokens"),
-            temperature=kwargs.get("temperature", None),
-            top_k=kwargs.get("top_k", None),
-            top_p=kwargs.get("top_p", None),
-            repetition_penalty=kwargs.get("repetition_penalty", None),
             do_sample=kwargs.get("do_sample", None) if not self.bitsandbytes else False,
             eos_token_id=self.tokenizer.eos_token_id,
             pad_token_id=self.tokenizer.pad_token_id,
-            num_return_sequences=kwargs.get("num_return_sequences", None),
             output_scores=True,
             return_dict_in_generate=True,
         )
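
One reason for the refactor above: with the old explicit keyword style, unset options reached generate() as literal None values, whereas the pre-filtered dict drops them so the model's own generation defaults apply. A minimal sketch of the filtering pattern (values illustrative):

# Only user-supplied sampling options survive the filter.
raw = dict(
    temperature=0.7,
    top_k=None,
    top_p=None,
    repetition_penalty=None,
    num_return_sequences=None,
)
kwargs_to_pass = {k: v for k, v in raw.items() if v is not None}
assert kwargs_to_pass == {"temperature": 0.7}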
@@ -458,17 +465,25 @@ class TextGenerationModel(HuggingFaceModel):
             perc_max_gpu_mem_red,
             use_fp16,
         )
-        try:
-            tokenizer = AutoTokenizer.from_pretrained(
-                self.model_name, truncation_side="left", padding_side="left"
-            )
-        except ValueError:
-            tokenizer = AutoTokenizer.from_pretrained(
-                self.model_name,
-                truncation_side="left",
-                padding_side="left",
-                use_fast=False,
-            )
+        if (
+            MODEL_REGISTRY.get(
+                self.model_name, MODEL_GENTYPE_REGISTRY.get(self.model_type, None)
+            )
+            == LlamaForCausalLM
+        ):
+            tokenizer = LlamaTokenizer.from_pretrained(self.model_name)
+        else:
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    self.model_name, truncation_side="left", padding_side="left"
+                )
+            except ValueError:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    self.model_name,
+                    truncation_side="left",
+                    padding_side="left",
+                    use_fast=False,
+                )
         dtype = torch.float16 if use_fp16 else "auto"
         if use_bitsandbytes:
             print("WARNING!!! Cannot use sampling with bitsandbytes.")
@@ -501,11 +516,10 @@ class TextGenerationModel(HuggingFaceModel):
         )
         model.eval()
         print(f"Loaded Model DType {model.dtype}")
         self.is_encdec = model.config.is_encoder_decoder
         if not self.is_encdec:
             tokenizer.pad_token = tokenizer.eos_token
             tokenizer.pad_token_id = tokenizer.eos_token_id
         if not use_bitsandbytes:
             if use_accelerate:
                 self._dispatch_accelerate_model(model, perc_max_gpu_mem_red)
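
For context on the EOS-as-PAD assignment above: decoder-only checkpoints such as gpt2 ship without a pad token, so batched, padded encoding fails until one is assigned. A small usage sketch (gpt2 as a stand-in checkpoint):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tok.pad_token = tok.eos_token  # same trick as the hunk above
batch = tok(["short", "a longer prompt"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # both rows padded to the longer length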

@@ -48,7 +48,8 @@ EXTRAS = {
        "Flask>=2.1.2",
        "sentence_transformers>=2.2.0",
        "torch>=1.8.0",
-        "transformers>=4.20.0,<4.26.0",
+        "transformers>=4.29.0,<4.31.0",
+        "tokenizers>=0.13.3",
    ],
    "app": [
        "fastapi>=0.70.0",
