@@ -1,6 +1,7 @@
 """Huggingface model."""
 import json
 from pathlib import Path
+import numpy as np
 from typing import Any, Dict, List
 
 import torch
@@ -47,7 +48,11 @@ class Pipeline:
     ):
         """Initialize."""
         self.model = model
-        self.max_length = model.config.max_position_embeddings
+        # Used for GPT
+        self.max_length = getattr(model.config, "max_position_embeddings", None)
+        if self.max_length is None:
+            # Used for T0
+            self.max_length = model.config.d_model
         self.tokenizer = tokenizer
         self.device = (
             torch.device("cpu")
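
This `max_length` probe relies on how `transformers` config classes name their fields: GPT-style configs expose `max_position_embeddings` (for GPT-2 it aliases `n_positions`), while T5-family configs such as T0's do not, so the patch falls back to `d_model`. A minimal sketch of the same probe, using `AutoConfig`; the model name is illustrative:

```python
# Sketch of the max_length fallback above, assuming AutoConfig; "gpt2" is
# illustrative, not necessarily a model this repo supports.
from transformers import AutoConfig

def infer_max_length(model_name: str) -> int:
    config = AutoConfig.from_pretrained(model_name)
    # Decoder-only configs (e.g. gpt2) expose max_position_embeddings.
    max_length = getattr(config, "max_position_embeddings", None)
    if max_length is None:
        # T5-style configs do not, so fall back to d_model as the patch does.
        max_length = config.d_model
    return max_length

print(infer_max_length("gpt2"))  # 1024
```
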
@@ -68,7 +73,7 @@ class Pipeline:
         # the user indicated generation tokens is preserved.
         max_input_length = kwargs.get("max_input_length")
         encoded_prompt = self.tokenizer.encode(
-            text, max_length=max_input_length, return_tensors="pt"
+            text, max_length=max_input_length, truncation=True, return_tensors="pt"
         )
         encoded_prompt = encoded_prompt.to(self.device)
         output_sequences = self.model.generate(  # type: ignore
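
Adding `truncation=True` makes the tokenizer's behavior explicit: when `max_length` is passed without it, recent `transformers` versions emit a warning and pick a default truncation strategy. The two `encode` calls later in the diff get the same flag. A quick check, with an illustrative tokenizer:

```python
# Quick check of the explicit-truncation behavior; "gpt2" is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "word " * 5000  # far longer than the 1024-token context window

ids = tokenizer.encode(text, max_length=1024, truncation=True)
assert len(ids) <= 1024  # prompt is guaranteed to fit the model
```
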
@@ -104,7 +109,7 @@ class HuggingFaceModel(Model):
         device: int,
         use_accelerate: bool,
         perc_max_gpu_mem_red: float,
-        use_fp32: bool,
+        use_fp16: bool,
     ):
         """
         Initialize model.
@@ -117,7 +122,7 @@ class HuggingFaceModel(Model):
             device: device to use for model.
             use_accelerate: whether to use accelerate for multi-gpu inference.
             perc_max_gpu_mem_red: percent max memory reduction in accelerate
-            use_fp32: use fp32 for model weights.
+            use_fp16: use fp16 for model weights.
         """
         # Check if providing path
         self.model_path = model_name
@@ -133,16 +138,29 @@ class HuggingFaceModel(Model):
             tokenizer = AutoTokenizer.from_pretrained(model_name)
         except ValueError:
             tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
-        dtype = torch.float32 if use_fp32 else torch.float16
-        model = MODEL_REGISTRY[model_name].from_pretrained(  # type: ignore
-            self.model_path, cache_dir=cache_dir, torch_dtype=dtype
-        )
+        dtype = torch.float16 if use_fp16 else "auto"
+        try:
+            # Try to explicitly find an fp16 copy (gpt-j-6B, for example)
+            model = MODEL_REGISTRY[model_name].from_pretrained(  # type: ignore
+                self.model_path, cache_dir=cache_dir,
+                revision="float16", torch_dtype=torch.float16
+            )
+        except Exception:
+            model = MODEL_REGISTRY[model_name].from_pretrained(  # type: ignore
+                self.model_path, cache_dir=cache_dir, torch_dtype=dtype
+            )
+        model.eval()
+        print(f"Loaded Model DType {model.dtype}")
         if use_accelerate:
             self._dispatch_accelerate_model(model, perc_max_gpu_mem_red)
             device = 0
         else:
-            if device > -1:
-                model = model.to(device)  # type: ignore
+            torch_device = (
+                torch.device("cpu")
+                if (device == -1 or not torch.cuda.is_available())
+                else torch.device(f"cuda:{device}")
+            )
+            model = model.to(torch_device)  # type: ignore
         self.pipeline = Pipeline(  # type: ignore
             model=model, tokenizer=tokenizer, device=device
         )
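
The new load path first tries the `float16` revision — some model repos, gpt-j-6B among them, publish half-precision weights on a branch of that name — and only then falls back to the default branch with the requested dtype. A condensed sketch of the same pattern, assuming `AutoModelForCausalLM` stands in for the `MODEL_REGISTRY` lookup:

```python
# Sketch of the fp16-first load; AutoModelForCausalLM is assumed in place of
# the MODEL_REGISTRY entry, and the repo branch name "float16" is the one the
# patch's comment refers to (e.g. EleutherAI/gpt-j-6B).
import torch
from transformers import AutoModelForCausalLM

def load_model(model_path: str, use_fp16: bool):
    try:
        # Prefer a dedicated half-precision branch when the repo has one.
        model = AutoModelForCausalLM.from_pretrained(
            model_path, revision="float16", torch_dtype=torch.float16
        )
    except Exception:
        # Fall back to the default branch, honoring the requested dtype.
        dtype = torch.float16 if use_fp16 else "auto"
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype)
    return model.eval()
```

The same hunk also replaces the bare `model.to(device)` with an explicit `torch.device` that drops to CPU when CUDA is unavailable, so a GPU device index no longer crashes CPU-only hosts.
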
@@ -228,11 +246,11 @@ class HuggingFaceModel(Model):
         max_input_len = self.pipeline.max_length - kwargs.get("max_tokens")
         # Add tokens for length
         encoded_prompt_with_special = self.pipeline.tokenizer.encode(
-            prompt, max_length=max_input_len
+            prompt, max_length=max_input_len, truncation=True
         )
         # Remove tokens as the pipeline removes special tokens upon return
         encoded_prompt_without_special = self.pipeline.tokenizer.encode(
-            prompt, max_length=max_input_len, add_special_tokens=False
+            prompt, max_length=max_input_len, truncation=True, add_special_tokens=False
         )
         result = self.pipeline(
             prompt,
@@ -256,3 +274,43 @@ class HuggingFaceModel(Model):
         else:
             final_results = [r["generated_text"][start_idx:] for r in result]
         return final_results
+
+    def logits_scoring(self, prompt: str, gold_choices: List[str], **kwargs: Any) -> str:
+        """
+        Given the prompt and gold choices, choose the best choice with max logits.
+
+        Args:
+            prompt: prompt to generate from.
+            gold_choices: list of choices to choose from.
+
+        Returns:
+            the returned gold choice.
+        """
+        max_input_len = self.pipeline.max_length
+
+        # Encode prompt and choices
+        encoded_prompt = self.pipeline.tokenizer.encode(
+            prompt, max_length=max_input_len, truncation=True, return_tensors="pt"
+        ).to(self.pipeline.device)
+        encoded_choices = [
+            self.pipeline.tokenizer.encode(
+                choice, max_length=max_input_len, truncation=True, return_tensors="pt", add_special_tokens=False
+            ).to(self.pipeline.device)
+            for choice in gold_choices
+        ]
+        # Feed one choice at a time to the model
+        results = [
+            self.pipeline.model(input_ids=encoded_prompt, labels=encoded_choice).logits
+            for encoded_choice in encoded_choices
+        ]
+        # Choose the choice whose tokens get the highest summed logits
+        logit_sum = []
+        for res, encoded_choice in zip(results, encoded_choices):
+            # logits: (choice_len, vocab_size); choice_index: the choice's token ids
+            logits = res.cpu().detach().numpy()[0]
+            choice_index = encoded_choice.cpu().detach().numpy()[0]
+            # Gather the logit assigned to each gold token at its position
+            logits_to_sum = [logits[i][choice_index[i]] for i in range(len(choice_index))]
+            logit_sum.append(sum(logits_to_sum))
+        # Return choice with max logits
+        return gold_choices[np.argmax(logit_sum)]
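
To try the new `logits_scoring` path in isolation: score each candidate by the sum of the logits its tokens receive when fed as `labels`, then pick the argmax. The sketch below assumes a seq2seq model, where passing `labels` drives the decoder; the checkpoint name is illustrative.

```python
# Standalone sketch of sum-of-logits choice scoring, assuming a seq2seq model
# where passing labels drives the decoder; "t5-small" is illustrative.
import numpy as np
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()

prompt = "Is this review positive or negative? 'Great movie!'"
choices = ["positive", "negative"]

encoded_prompt = tokenizer.encode(prompt, return_tensors="pt")
scores = []
with torch.no_grad():
    for choice in choices:
        labels = tokenizer.encode(choice, add_special_tokens=False, return_tensors="pt")
        # logits[i] is the distribution over the i-th label token
        logits = model(input_ids=encoded_prompt, labels=labels).logits[0]
        scores.append(sum(logits[i, labels[0][i]].item() for i in range(labels.shape[1])))

print(choices[int(np.argmax(scores))])
```

Note that summing raw logits is a cheap proxy; summing log-softmax probabilities instead would give the true sequence log-likelihood.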