Deepspeed (#38)

* [feat] deepspeed and batch support huggingface api

* [chore] add deepspeed to readme

* [chore] fix hf api test
Laurel Orr 2 years ago committed by GitHub
parent 8b423d6962
commit a4cd201b8e
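For orientation: the batch support means `HuggingFaceModel.generate` and `logits_scoring` now accept either a single prompt or a list of prompts and return one entry per prompt. A minimal sketch mirroring the new HuggingFace API tests at the bottom of this diff (gpt2 on CPU, all multi-GPU flags off; prompts are illustrative):

```python
from manifest.api.models.huggingface import HuggingFaceModel

# Small CPU-only model, as in the new tests below.
model = HuggingFaceModel(
    model_name_or_path="gpt2",
    use_deepspeed=False,
    use_fp16=False,
    device=-1,
)
# One (generated_text, total_logprob) tuple per prompt in the batch.
# Note: with a list of prompts, n (num return sequences) must stay 1.
results = model.generate(["Why is the sky green?", "Cats are"], max_tokens=5)
```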

@@ -1,10 +1,13 @@
dev:
dev: deepspeed
pip install -e .[all]
pre-commit install
test: dev check
pytest tests
deepspeed:
pip install -e git+https://github.com/microsoft/DeepSpeed.git#egg=deepspeed
format:
isort --atomic manifest/ tests/
black manifest/ tests/

@@ -166,7 +166,7 @@ result = manifest.run(prompt, "Laurel", max_tokens=50)
```
# Local Huggingface Models
To use a HuggingFace generative model, in `manifest/api` we have a Falsk application that hosts the models for you.
To use a HuggingFace generative model, in `manifest/api` we have a Flask application that hosts the models for you.
In a separate terminal or Tmux/Screen session, to load 6B parameter models, run
```bash
@@ -186,7 +186,7 @@ manifest = Manifest(
If you have a custom model you trained, pass the model path to `--model_name_or_path`.
To help load larger models, we also support using `parallelize()` from HF, [accelerate](https://huggingface.co/docs/accelerate/index), and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). You will need to install these packages first. We list the commands to load larger models below.
To help load larger models, we also support using `parallelize()` from HF, [accelerate](https://huggingface.co/docs/accelerate/index), [bitsandbytes](https://github.com/TimDettmers/bitsandbytes), and [deepspeed](https://github.com/microsoft/DeepSpeed). You will need to install these packages first via `pip install manifest-ml[api]`. We list the commands to load larger models below.
* T0pp
```bash

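The server's new `--use_deepspeed` flag is plumbed through to `HuggingFaceModel(use_deepspeed=True)` further down in this diff. A hedged in-process sketch, mirroring `test_gpt_deepspeed_generate` in the new tests (requires a CUDA GPU and the deepspeed install added to the Makefile above):

```python
from manifest.api.models.huggingface import HuggingFaceModel

# DeepSpeed kernel-injection path; needs a GPU (device 0 here).
model = HuggingFaceModel(
    model_name_or_path="gpt2",
    use_deepspeed=True,
    use_fp16=False,
    device=0,
)
result = model.generate("Why is the sky green?", max_tokens=10)
```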
@@ -56,7 +56,7 @@ def parse_args() -> argparse.Namespace:
"--cache_dir", default=None, type=str, help="Cache directory for models."
)
parser.add_argument(
"--device", type=int, default=-1, help="Model device. -1 for CPU."
"--device", type=int, default=0, help="Model device. -1 for CPU."
)
parser.add_argument(
"--fp16", action="store_true", help="Force use fp16 for model params."
@@ -88,6 +88,11 @@ def parse_args() -> argparse.Namespace:
"This will override --device parameter."
),
)
parser.add_argument(
"--use_deepspeed",
action="store_true",
help=("Use deepspeed. This will override --device parameter."),
)
args = parser.parse_args()
return args
@@ -109,14 +114,28 @@ def main() -> None:
model_config = kwargs.model_config
if not model_name_or_path and not model_config:
raise ValueError("Must provide model_name_or_path or model_config.")
use_accelerate = kwargs.use_accelerate_multigpu
if use_accelerate:
if kwargs.use_accelerate_multigpu:
logger.info("Using accelerate. Overriding --device argument.")
if (
kwargs.percent_max_gpu_mem_reduction <= 0
or kwargs.percent_max_gpu_mem_reduction > 1
):
raise ValueError("percent_max_gpu_mem_reduction must be in (0, 1].")
if (
sum(
[
kwargs.use_accelerate_multigpu,
kwargs.use_hf_parallelize,
kwargs.use_bitsandbytes,
kwargs.use_deepspeed,
]
)
> 1
):
raise ValueError(
"Only one of use_accelerate_multigpu, use_hf_parallelize, "
"use_bitsandbytes, and use_deepspeed can be set."
)
# Global model
global model
model = MODEL_CONSTRUCTORS[model_type](
@@ -124,9 +143,10 @@ def main() -> None:
model_config=model_config,
cache_dir=kwargs.cache_dir,
device=kwargs.device,
use_accelerate=use_accelerate,
use_accelerate=kwargs.use_accelerate_multigpu,
use_parallelize=kwargs.use_hf_parallelize,
use_bitsandbytes=kwargs.use_bitsandbytes,
use_deepspeed=kwargs.use_deepspeed,
perc_max_gpu_mem_red=kwargs.percent_max_gpu_mem_reduction,
use_fp16=kwargs.fp16,
)
@@ -166,8 +186,8 @@ def choice_logits() -> Dict:
if not isinstance(gold_choices, list):
raise ValueError("Gold choices must be a list of string choices")
result, score = model.logits_scoring(prompt, gold_choices, **generation_args)
results = [{"text": result, "text_logprob": score}]
choice_score_list = model.logits_scoring(prompt, gold_choices, **generation_args)
results = [{"text": r[0], "text_logprob": r[1]} for r in choice_score_list]
# transform the result into the openai format
return Response(results, response_type="choice_selection").__dict__()
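Since `logits_scoring` now returns one `(choice, logprob)` pair per prompt, the endpoint builds one response entry per prompt. A small sketch of the new shape (model as in the earlier gpt2 sketch; prompts and gold choices are illustrative):

```python
# One (best_choice, total_logprob) tuple per prompt in the batch...
choice_score_list = model.logits_scoring(
    ["Why is the sky green?", "Cats are"],
    gold_choices=[" blue sky", " green sky"],
)
# ...which the endpoint turns into [{"text": ..., "text_logprob": ...}, ...].
results = [{"text": r[0], "text_logprob": r[1]} for r in choice_score_list]
```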

@@ -1,9 +1,11 @@
"""Huggingface model."""
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Tuple, Union, cast
import torch
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils.modeling import get_max_memory as acc_get_max_memory
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
@@ -18,6 +20,7 @@ from transformers import (
PreTrainedTokenizer,
)
import deepspeed
from manifest.api.models.model import Model
MODEL_REGISTRY = {
@@ -43,6 +46,7 @@ MODEL_REGISTRY = {
"bigscience/bloom": AutoModelForCausalLM,
"bigscience/T0pp": AutoModelForSeq2SeqLM,
"bigscience/T0_3B": AutoModelForSeq2SeqLM,
"google/t5-small-lm-adapt": AutoModelForSeq2SeqLM, # 220M
"google/t5-l-lm-adapt": AutoModelForSeq2SeqLM, # 800M
"google/t5-xl-lm-adapt": AutoModelForSeq2SeqLM, # 3B
"google/t5-xxl-lm-adapt": AutoModelForSeq2SeqLM, # 11B
@@ -75,16 +79,18 @@ class Pipeline:
def __init__(
self,
model: PreTrainedModel,
model: Union[PreTrainedModel, deepspeed.InferenceEngine],
tokenizer: PreTrainedTokenizer,
device: int = None,
bitsandbytes: bool = False,
is_encdec: bool = False,
):
"""Initialize."""
# Used to turn off sampling
# https://github.com/TimDettmers/bitsandbytes/issues/42
self.bitsandbytes = bitsandbytes
self.model = model
self.is_encdec = is_encdec
config = model.config # type: ignore
# Used for GPT
self.max_length = getattr(config, "max_position_embeddings", None)
@@ -109,10 +115,9 @@ class Pipeline:
if (device == -1 or not torch.cuda.is_available())
else torch.device(f"cuda:{device}")
)
print("HERE", self.device)
def __call__(
self, text: str, **kwargs: Any
self, text: Union[str, List[str]], **kwargs: Any
) -> List[Dict[str, Union[str, List[float]]]]:
"""Generate from text.
@@ -124,22 +129,30 @@ class Pipeline:
"""
# If text is longer than the max model length, we reduce the max input length
# to ensure the user-requested number of generation tokens is preserved.
max_input_length = kwargs.get("max_input_length")
max_input_len = (
self.max_length - kwargs.get("max_new_tokens")
if not self.is_encdec
else self.max_length
)
encoded_prompt = self.tokenizer(
text, max_length=max_input_length, truncation=True, return_tensors="pt"
text,
max_length=max_input_len,
truncation=True,
padding=True,
return_tensors="pt",
)
encoded_prompt = encoded_prompt.to(self.device)
output_dict = self.model.generate( # type: ignore
**encoded_prompt,
max_length=kwargs.get("max_length"),
temperature=kwargs.get("temperature"),
top_k=kwargs.get("top_k"),
top_p=kwargs.get("top_p"),
repetition_penalty=kwargs.get("repetition_penalty"),
do_sample=kwargs.get("do_sample") if not self.bitsandbytes else False,
max_new_tokens=kwargs.get("max_new_tokens"),
temperature=kwargs.get("temperature", None),
top_k=kwargs.get("top_k", None),
top_p=kwargs.get("top_p", None),
repetition_penalty=kwargs.get("repetition_penalty", None),
do_sample=kwargs.get("do_sample", None) if not self.bitsandbytes else False,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
num_return_sequences=kwargs.get("num_return_sequences"),
num_return_sequences=kwargs.get("num_return_sequences", None),
output_scores=True,
return_dict_in_generate=True,
)
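The `padding=True` above and the `padding_side="left"` tokenizer change below exist because batching prompts of unequal length through a decoder-only model needs left padding: it keeps each prompt's last real token adjacent to the tokens being generated. A standalone illustration (gpt2 tokenizer assumed; the pad-token choice is only for this demo):

```python
from transformers import AutoTokenizer

# Left-side truncation and padding, matching the tokenizer setup in this diff.
tok = AutoTokenizer.from_pretrained(
    "gpt2", truncation_side="left", padding_side="left"
)
tok.pad_token = tok.eos_token  # gpt2 ships without a pad token; demo-only choice
enc = tok(["Why is the sky green?", "Cats are"], padding=True, return_tensors="pt")
# The shorter prompt is padded on the left: its attention_mask row starts with
# zeros and its last real token stays in the final position.
print(enc["attention_mask"])
```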
@@ -168,14 +181,15 @@ class HuggingFaceModel(Model):
def __init__(
self,
model_name_or_path: str,
model_config: str,
cache_dir: str,
device: int,
use_accelerate: bool,
use_parallelize: bool,
use_bitsandbytes: bool,
perc_max_gpu_mem_red: float,
use_fp16: bool,
model_config: Optional[str] = None,
cache_dir: Optional[str] = None,
device: int = 0,
use_accelerate: bool = False,
use_parallelize: bool = False,
use_bitsandbytes: bool = False,
use_deepspeed: bool = False,
perc_max_gpu_mem_red: float = 1.0,
use_fp16: bool = False,
):
"""
Initialize model.
@@ -190,11 +204,15 @@ class HuggingFaceModel(Model):
use_accelerate: whether to use accelerate for multi-gpu inference.
use_parallelize: use HF default parallelize
use_bitsandbytes: use HF bits and bytes
use_deepspeed: use deepspeed
perc_max_gpu_mem_red: percent max memory reduction in accelerate
use_fp16: use fp16 for model weights.
"""
if use_accelerate and use_parallelize:
raise ValueError("Cannot use both accelerate and parallelize")
if sum([use_accelerate, use_parallelize, use_bitsandbytes, use_deepspeed]) > 1:
raise ValueError(
"Only one of use_accelerate, use_parallelize, "
"use_bitsandbytes, use_deepspeed can be set to True"
)
# Check if providing path
self.model_path = model_name_or_path
if Path(self.model_path).exists() and Path(self.model_path).is_dir():
@@ -206,17 +224,19 @@ class HuggingFaceModel(Model):
print("Model Name:", self.model_name, "Model Path:", self.model_path)
try:
tokenizer = AutoTokenizer.from_pretrained(
self.model_name, truncation_side="left"
self.model_name, truncation_side="left", padding_side="left"
)
except ValueError:
tokenizer = AutoTokenizer.from_pretrained(
self.model_name, truncation_side="left", use_fast=False
self.model_name,
truncation_side="left",
padding_side="left",
use_fast=False,
)
dtype = torch.float16 if use_fp16 else "auto"
if use_bitsandbytes:
print("WARNING!!! Cannot use sampling with bitsandbytes.")
max_memory = get_max_memory(perc_max_gpu_mem_red)
print(max_memory)
model = MODEL_REGISTRY[self.model_name].from_pretrained( # type: ignore
self.model_path,
cache_dir=cache_dir,
@@ -251,6 +271,9 @@ class HuggingFaceModel(Model):
elif use_parallelize:
model.parallelize()
device = 0
elif use_deepspeed:
self._dispatch_deepspeed_model(model)
device = 0
else:
if device > -1:
torch_device = (
@@ -258,21 +281,39 @@ class HuggingFaceModel(Model):
if (device == -1 or not torch.cuda.is_available())
else torch.device(f"cuda:{device}")
)
print("T", torch_device)
model = model.to(torch_device) # type: ignore
self.pipeline = Pipeline( # type: ignore
model=model,
tokenizer=tokenizer,
device=device,
bitsandbytes=use_bitsandbytes,
is_encdec=self.is_encdec,
)
# Autoregressive models generate the input, too
self.returns_input = not self.is_encdec
def get_init_params(self) -> Dict:
"""Return init params to determine what model is being used."""
return {"model_name": self.model_name, "model_path": self.model_path}
def _dispatch_deepspeed_model(
self, model: PreTrainedModel
) -> deepspeed.InferenceEngine:
"""
Load model with deepspeed.
Adapted from https://www.deepspeed.ai/tutorials/inference-tutorial/
Args:
model: loaded hugging face model
"""
model = deepspeed.init_inference(
model=model,
mp_size=1,
dtype=model.dtype,
replace_method="auto",
replace_with_kernel_inject=True,
)
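# Note: the call site above does not capture the returned engine; with
# replace_with_kernel_inject=True, DeepSpeed appears to swap the transformer
# blocks on `model` in place, so the original model reference keeps working.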
return model
def _dispatch_accelerate_model(
self, model: PreTrainedModel, perc_max_gpu_mem_red: float
) -> None:
@@ -286,9 +327,6 @@ class HuggingFaceModel(Model):
model: loaded hugging face model
perc_max_gpu_mem_red: percent memory reduction
"""
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils.modeling import get_max_memory
model.tie_weights() # type: ignore
# Get the model where we can infer devices from
if hasattr(model, "model"):
@@ -301,7 +339,7 @@ class HuggingFaceModel(Model):
model_getter = ""
# Decrease max mem
max_memory = {
k: int(perc_max_gpu_mem_red * v) for k, v in get_max_memory().items()
k: int(perc_max_gpu_mem_red * v) for k, v in acc_get_max_memory().items()
}
raw_device_map = infer_auto_device_map(
main_model,
@@ -333,7 +371,9 @@ class HuggingFaceModel(Model):
return
@torch.no_grad()
def generate(self, prompt: str, **kwargs: Any) -> List[Tuple[str, float]]:
def generate(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[Tuple[str, float]]:
"""
Generate the prompt from model.
@@ -345,16 +385,12 @@ class HuggingFaceModel(Model):
Returns:
list of generated text (list of length 1 for 1 generation).
"""
num_return = kwargs.get("n")
max_input_len = self.pipeline.max_length - kwargs.get("max_tokens")
# Add tokens for length
encoded_prompt_with_special = self.pipeline.tokenizer.encode(
prompt, max_length=max_input_len, truncation=True
)
num_return = kwargs.get("n", 1)
if isinstance(prompt, list) and num_return > 1:
raise ValueError("In batch generate, n must be 1.")
result = self.pipeline(
prompt,
max_input_length=max_input_len,
max_length=kwargs.get("max_tokens") + len(encoded_prompt_with_special),
max_new_tokens=kwargs.get("max_tokens"),
temperature=kwargs.get("temperature"),
repetition_penalty=kwargs.get("repetition_penalty"),
top_k=kwargs.get("top_k"),
@@ -362,24 +398,16 @@ class HuggingFaceModel(Model):
do_sample=kwargs.get("do_sample"),
num_return_sequences=num_return,
)
if num_return == 1:
final_results = [
(
cast(str, result[0]["generated_text"]),
sum(cast(List[float], result[0]["logprobs"])),
)
]
else:
final_results = [
(cast(str, r["generated_text"]), sum(cast(List[float], r["logprobs"])))
for r in result
]
final_results = [
(cast(str, r["generated_text"]), sum(cast(List[float], r["logprobs"])))
for r in result
]
return final_results
@torch.no_grad()
def logits_scoring(
self, prompt: str, gold_choices: List[str], **kwargs: Any
) -> Tuple[str, float]:
self, prompt: Union[str, List[str]], gold_choices: List[str], **kwargs: Any
) -> List[Tuple[str, float]]:
"""
Given the prompt and gold choices, choose the best choice with max logits.
@@ -390,6 +418,8 @@ class HuggingFaceModel(Model):
Returns:
the returned gold choice
"""
if isinstance(prompt, str):
prompt = [prompt]
max_input_len = self.pipeline.max_length
if self.is_encdec:
# Adapted from https://github.com/bigscience-workshop/t-zero
@@ -425,10 +455,13 @@ class HuggingFaceModel(Model):
}
# Add choice tokens + mask
features["labels"] = [
tokenized_targets[k]["input_ids"] for k in range(len(gold_choices))
[tokenized_targets[k]["input_ids"]] * len(tokenized_inputs["input_ids"])
for k in range(len(gold_choices))
]
features["labels_attention_mask"] = [
tokenized_targets[k]["attention_mask"] for k in range(len(gold_choices))
[tokenized_targets[k]["attention_mask"]]
* len(tokenized_inputs["input_ids"])
for k in range(len(gold_choices))
]
else:
tokenized_inputs = self.pipeline.tokenizer(
@@ -455,55 +488,70 @@ class HuggingFaceModel(Model):
max_effective_input_len = 0
for tokenized_targ in tokenized_targets:
for k in tokenized_inputs.keys():
# Make sure to leave room for the outputs
features[k].append(
tokenized_inputs[k][
: min(
len(tokenized_inputs[k]),
max_input_len - len(tokenized_targ[k]),
)
]
+ tokenized_targ[k]
)
max_effective_input_len = max(
max_effective_input_len, len(features[k][-1])
)
batched_features = []
for prompt_i in range(len(tokenized_inputs[k])):
# Make sure to leave room for the outputs
batched_features.append(
tokenized_inputs[k][prompt_i][
: min(
len(tokenized_inputs[k][prompt_i]),
max_input_len - len(tokenized_targ[k]),
)
]
+ tokenized_targ[k]
)
max_effective_input_len = max(
max_effective_input_len, len(batched_features[-1])
)
features[k].append(batched_features)
# Manually add labels_attention_mask
features["labels_attention_mask"].append(
[0]
* min(
len(tokenized_inputs["input_ids"]),
max_input_len - len(tokenized_targ["input_ids"]),
batched_features = []
for prompt_i in range(len(tokenized_inputs["input_ids"])):
batched_features.append(
[0]
* min(
len(tokenized_inputs["input_ids"][prompt_i]),
max_input_len - len(tokenized_targ["input_ids"]),
)
+ [1] * len(tokenized_targ["input_ids"])
)
+ [1] * len(tokenized_targ["input_ids"])
)
features["labels_attention_mask"].append(batched_features)
# Manually pad to max effective length
for k in features.keys():
for i in range(len(features[k])):
if k == "input_ids":
features[k][i] += [self.pipeline.tokenizer.pad_token_id] * (
max_effective_input_len - len(features[k][i])
)
elif k in ["attention_mask", "labels_attention_mask"]:
features[k][i] += [0] * (
max_effective_input_len - len(features[k][i])
)
else:
raise ValueError(f"Unknown key {k} for decoder only models")
for targ_i in range(len(features[k])):
for prompt_i in range(len(features[k][targ_i])):
if k == "input_ids":
features[k][targ_i][prompt_i] += [
self.pipeline.tokenizer.pad_token_id
] * (
max_effective_input_len
- len(features[k][targ_i][prompt_i])
)
elif k in ["attention_mask", "labels_attention_mask"]:
features[k][targ_i][prompt_i] += [0] * (
max_effective_input_len
- len(features[k][targ_i][prompt_i])
)
else:
raise ValueError(f"Unknown key {k} for decoder only models")
features["labels"] = features["input_ids"]
# Convert to tensors
tensor_features = {}
for k in features:
tensor_features[k] = torch.LongTensor(features[k]).to(self.pipeline.device)
if self.is_encdec:
gold_l, bsz, seq_len = tensor_features["labels"].shape
stacked_logits = self.pipeline.model( # type: ignore
input_ids=tensor_features["input_ids"],
attention_mask=tensor_features["attention_mask"],
labels=tensor_features["labels"],
input_ids=tensor_features["input_ids"].reshape(gold_l * bsz, -1),
attention_mask=tensor_features["attention_mask"].reshape(
gold_l * bsz, -1
),
labels=tensor_features["labels"].reshape(gold_l * bsz, -1),
).logits
stacked_logits = stacked_logits.reshape(gold_l, bsz, seq_len, -1)
# Adapted from https://github.com/bigscience-workshop/t-zero
masked_log_probs = tensor_features["labels_attention_mask"].unsqueeze(
-1
@@ -525,12 +573,57 @@ class HuggingFaceModel(Model):
* torch.log_softmax(stacked_logits.float(), dim=-1)[..., :-1, :]
)
seq_token_log_probs = torch.gather(
masked_log_probs, -1, tensor_features["labels"][:, 1:].unsqueeze(-1)
masked_log_probs, -1, tensor_features["labels"][..., 1:].unsqueeze(-1)
)
seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
seq_log_prob = seq_token_log_probs.sum(dim=-1)
# Averaging over output sequence length for GPT
if not self.is_encdec:
seq_log_prob = seq_log_prob * (1 / (seq_token_log_probs != 0).sum(dim=-1))
prediction = seq_log_prob.argmax(dim=-1).item()
return gold_choices[int(prediction)], seq_log_prob[int(prediction)].item()
prediction = seq_log_prob.argmax(dim=0)
return [
(gold_choices[int(p)], seq_log_prob[int(p), i].item())
for i, p in enumerate(prediction)
]
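Here `seq_log_prob` has one row per gold choice and one column per prompt, so `argmax(dim=0)` picks the best choice independently for each prompt in the batch. A toy sketch of that indexing (numbers invented for illustration):

```python
import torch

# Rows: gold choices. Columns: prompts in the batch (toy log probs).
seq_log_prob = torch.tensor(
    [[-7.0, -9.1],  # choice 0 scored on prompt 0 and prompt 1
     [-8.5, -6.2]]  # choice 1 scored on prompt 0 and prompt 1
)
prediction = seq_log_prob.argmax(dim=0)  # tensor([0, 1]): best choice per prompt
best = [(int(p), seq_log_prob[int(p), i].item()) for i, p in enumerate(prediction)]
# -> [(0, -7.0), (1, -6.2)] up to float precision
```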
@torch.no_grad()
def score_sequence(
self, prompt: Union[str, List[str]], **kwargs: Any
) -> List[float]:
"""
Score a sequence of choices.
Args:
prompt (:obj:`str` or :obj:`List[str]`):
The prompt to score the choices against.
**kwargs:
Additional keyword arguments passed along to the :obj:`__call__` method.
"""
if isinstance(prompt, str):
prompt = [prompt]
encoded_prompt = self.pipeline.tokenizer(
prompt,
max_length=self.pipeline.max_length,
truncation=True,
padding=True,
return_tensors="pt",
)
encoded_prompt["labels"] = encoded_prompt["input_ids"].clone()
encoded_prompt = encoded_prompt.to(self.pipeline.device)
logits = self.pipeline.model( # type: ignore
**encoded_prompt,
).logits
# For causal decoders, shift logits and labels
labels_attention_mask = encoded_prompt["attention_mask"].unsqueeze(-1)[
..., 1:, :
]
masked_log_probs = (
labels_attention_mask.float()
* torch.log_softmax(logits.float(), dim=-1)[..., :-1, :]
)
seq_token_log_probs = torch.gather(
masked_log_probs, -1, encoded_prompt["labels"][..., 1:].unsqueeze(-1)
)
seq_token_log_probs = seq_token_log_probs.squeeze(dim=-1)
seq_log_prob = seq_token_log_probs.sum(dim=-1)
return seq_log_prob.tolist()
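A usage sketch for the new `score_sequence` helper, mirroring `test_gpt_score` in the new tests (gpt2 on CPU; the remaining constructor flags rely on the new defaults):

```python
from manifest.api.models.huggingface import HuggingFaceModel

model = HuggingFaceModel(model_name_or_path="gpt2", device=-1)
# One summed token log-probability per prompt (test_gpt_score expects
# roughly -19.9 and -45.8 for these two inputs).
scores = model.score_sequence(["Why is the sky green?", "Cats are butterflies"])
```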

@@ -16,6 +16,7 @@ class Model(ABC):
use_accelerate: bool,
use_parallelize: bool,
use_bitsandbytes: bool,
use_deepspeed: bool,
perc_max_gpu_mem_red: float,
use_fp16: bool,
):
@@ -32,6 +33,7 @@ class Model(ABC):
use_accelerate: whether to use accelerate for multi-gpu inference.
use_parallelize: use HF default parallelize
use_bitsandbytes: use HF bits and bytes
use_deepspeed: use deepspeed
perc_max_gpu_mem_red: percent max memory reduction in accelerate
use_fp16: use fp16 for model weights.
"""
@@ -60,7 +62,7 @@ class Model(ABC):
@abstractmethod
def logits_scoring(
self, prompt: str, gold_choices: List[str], **kwargs: Any
) -> Tuple[str, float]:
) -> List[Tuple[str, float]]:
"""
Given the prompt and gold choices, choose the best choice with max logits.

@@ -25,6 +25,7 @@ class ZooModel(Model):
use_accelerate: bool,
use_parallelize: bool,
use_bitsandbytes: bool,
use_deepspeed: bool,
perc_max_gpu_mem_red: float,
use_fp16: bool,
):
@@ -41,6 +42,7 @@ class ZooModel(Model):
use_accelerate: whether to use accelerate for multi-gpu inference.
use_parallelize: use HF default parallelize
use_bitsandbytes: use HF bits and bytes
use_deepspeed: use deepspeed
perc_max_gpu_mem_red: percent max memory reduction in accelerate
use_fp16: use fp16 for model weights.
"""
@@ -82,7 +84,7 @@ class ZooModel(Model):
def logits_scoring(
self, prompt: str, gold_choices: List[str], **kwargs: Any
) -> Tuple[str, float]:
) -> List[Tuple[str, float]]:
"""
Given the prompt and gold choices, choose the best choice with max logits.

@@ -6,11 +6,12 @@ strict_optional = false
[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [
"deepspeed",
"numpy",
"tqdm",
"tqdm.auto",
"sqlitedict",
"dill",
"tqdm.auto",
"accelerate",
"accelerate.utils.modeling",
"transformers",

@@ -0,0 +1,264 @@
"""Test the HuggingFace API."""
import math
import os
from subprocess import PIPE, Popen
import pytest
from manifest.api.models.huggingface import HuggingFaceModel
NOCUDA = 0
try:
p = Popen(
[
"nvidia-smi",
(
"--query-gpu=index,utilization.gpu,memory.total,memory.used,"
"memory.free,driver_version,name,gpu_serial,display_active,"
"display_mode"
),
"--format=csv,noheader,nounits",
],
stdout=PIPE,
)
except OSError:
NOCUDA = 1
MAXGPU = 0
if NOCUDA == 0:
try:
p = os.popen("nvidia-smi --query-gpu=index --format=csv,noheader,nounits")
i = p.read().split("\n")
MAXGPU = int(i[-2]) + 1
except OSError:
NOCUDA = 1
def test_gpt_generate():
"""Test pipeline generation from a gpt model."""
model = HuggingFaceModel(
model_name_or_path="gpt2",
use_accelerate=False,
use_parallelize=False,
use_bitsandbytes=False,
use_deepspeed=False,
use_fp16=False,
device=-1,
)
inputs = "Why is the sky green?"
result = model.generate(inputs, max_tokens=10)
assert result is not None
assert len(result) == 1
assert result[0][0] == "\n\nThe sky is green.\n\nThe"
assert math.isclose(round(result[0][1], 3), -11.516)
result = model.generate("Cats are", max_tokens=10)
assert result is not None
assert len(result) == 1
assert result[0][0] == " not the only ones who are being targeted by the"
assert math.isclose(round(result[0][1], 3), -21.069)
result = model.generate(inputs, max_tokens=5)
assert result is not None
assert len(result) == 1
assert result[0][0] == "\n\nThe sky is"
assert math.isclose(round(result[0][1], 3), -6.046)
result = model.logits_scoring(inputs, gold_choices=[" blue sky", " green sky"])
assert result is not None
assert len(result) == 1
assert result[0][0] == " blue sky"
assert math.isclose(round(result[0][1], 3), -6.999)
# Truncate max length
model.pipeline.max_length = 5
result = model.generate(inputs, max_tokens=2)
assert result is not None
assert len(result) == 1
assert result[0][0] == "\n\n"
assert math.isclose(round(result[0][1], 3), -1.414)
def test_encdec_generate():
"""Test pipeline generation from an encoder-decoder model."""
model = HuggingFaceModel(
model_name_or_path="google/t5-small-lm-adapt",
use_accelerate=False,
use_parallelize=False,
use_bitsandbytes=False,
use_deepspeed=False,
use_fp16=False,
device=-1,
)
inputs = "Why is the sky green?"
result = model.generate(inputs, max_tokens=10)
assert result is not None
assert len(result) == 1
assert result[0][0] == "What is the sky green? What is the sky"
assert math.isclose(round(result[0][1], 3), -7.271)
result = model.generate("Cats are", max_tokens=10)
assert result is not None
assert len(result) == 1
assert result[0][0] == "a great way to get out of the house"
assert math.isclose(round(result[0][1], 3), -13.868)
result = model.generate(inputs, max_tokens=5)
assert result is not None
assert len(result) == 1
assert result[0][0] == "What is the sky green"
assert math.isclose(round(result[0][1], 3), -5.144)
result = model.logits_scoring(inputs, gold_choices=[" blue sky", " green sky"])
assert result is not None
assert len(result) == 1
assert result[0][0] == " green sky"
assert math.isclose(round(result[0][1], 3), -13.538)
# Truncate max length
model.pipeline.max_length = 5
result = model.generate(inputs, max_tokens=2)
assert result is not None
assert len(result) == 1
assert result[0][0] == "Is"
assert math.isclose(round(result[0][1], 3), -4.233)
def test_gpt_score():
"""Test sequence scoring from a gpt model."""
model = HuggingFaceModel(
model_name_or_path="gpt2",
use_accelerate=False,
use_parallelize=False,
use_bitsandbytes=False,
use_deepspeed=False,
use_fp16=False,
device=-1,
)
inputs = ["Why is the sky green?", "Cats are butterflies"]
result = model.score_sequence(inputs)
assert result is not None
assert len(result) == 2
assert math.isclose(round(result[0], 3), -19.935)
assert math.isclose(round(result[1], 3), -45.831)
def test_batch_gpt_generate():
"""Test batched pipeline generation from a gpt model."""
model = HuggingFaceModel(
model_name_or_path="gpt2",
use_accelerate=False,
use_parallelize=False,
use_bitsandbytes=False,
use_deepspeed=False,
use_fp16=False,
device=-1,
)
inputs = ["Why is the sky green?", "Cats are"]
result = model.generate(inputs, max_tokens=10)
assert result is not None
assert len(result) == 2
assert result[0][0] == "\n\nThe sky is green.\n\nThe"
assert math.isclose(round(result[0][1], 3), -11.516)
assert result[1][0] == " not the only ones who are being targeted by the"
assert math.isclose(round(result[1][1], 3), -21.069)
result = model.generate(inputs, max_tokens=5)
assert result is not None
assert len(result) == 2
assert result[0][0] == "\n\nThe sky is"
assert math.isclose(round(result[0][1], 3), -6.046)
assert result[1][0] == " not the only ones who"
assert math.isclose(round(result[1][1], 3), -9.978)
result = model.logits_scoring(
inputs, gold_choices=[" purple sky", " green sky", " blue sky"]
)
assert result is not None
assert len(result) == 2
assert result[0][0] == " blue sky"
assert math.isclose(round(result[0][1], 3), -6.999)
assert result[1][0] == " blue sky"
assert math.isclose(round(result[1][1], 3), -8.212)
# Truncate max length
model.pipeline.max_length = 5
result = model.generate(inputs, max_tokens=2)
assert result is not None
assert len(result) == 2
assert result[0][0] == "\n\n"
assert math.isclose(round(result[0][1], 3), -1.414)
assert result[1][0] == " not the"
assert math.isclose(round(result[1][1], 3), -6.246)
def test_batch_encdec_generate():
"""Test batched pipeline generation from an encoder-decoder model."""
model = HuggingFaceModel(
model_name_or_path="google/t5-small-lm-adapt",
use_accelerate=False,
use_parallelize=False,
use_bitsandbytes=False,
use_deepspeed=False,
use_fp16=False,
device=-1,
)
inputs = ["Why is the sky green?", "Cats are"]
result = model.generate(inputs, max_tokens=10)
assert result is not None
assert len(result) == 2
assert result[0][0] == "What is the sky green? What is the sky"
assert math.isclose(round(result[0][1], 3), -7.271)
assert result[1][0] == "a great way to get out of the house"
assert math.isclose(round(result[1][1], 3), -13.868)
result = model.generate(inputs, max_tokens=5)
assert result is not None
assert len(result) == 2
assert result[0][0] == "What is the sky green"
assert math.isclose(round(result[0][1], 3), -5.144)
assert result[1][0] == "a great way to"
assert math.isclose(round(result[1][1], 3), -6.353)
result = model.logits_scoring(
inputs, gold_choices=[" purple sky", " green sky", " blue sky"]
)
assert result is not None
assert len(result) == 2
assert result[0][0] == " green sky"
assert math.isclose(round(result[0][1], 3), -13.538)
assert result[1][0] == " blue sky"
assert math.isclose(round(result[1][1], 3), -41.503)
# Truncate max length
model.pipeline.max_length = 5
result = model.generate(inputs, max_tokens=2)
assert result is not None
assert len(result) == 2
assert result[0][0] == "Is"
assert math.isclose(round(result[0][1], 3), -4.233)
assert result[1][0] == "a"
assert math.isclose(round(result[1][1], 3), -1.840)
@pytest.mark.skipif(
(NOCUDA == 1 or MAXGPU == 0), reason="No cuda or GPUs found through nvidia-smi"
)
def test_gpt_deepspeed_generate():
"""Test deepspeed generation from a gpt model."""
model = HuggingFaceModel(
model_name_or_path="gpt2",
use_accelerate=False,
use_parallelize=False,
use_bitsandbytes=False,
use_deepspeed=True,
use_fp16=False,
device=0,
)
inputs = "Why is the sky green?"
result = model.generate(inputs, max_tokens=10)
assert result is not None
assert len(result) == 1
assert result[0][0] == "\n\nThe sky is green.\n\nThe"
assert math.isclose(round(result[0][1], 3), -11.517)