Add benchmark scripts (#319)
This PR: - Adds benchmark scripts for inference, forward pass, and full training step (e.g. used for experiments in our paper). - Fixes bug with dtypes in `petals.DistributedBloomForSequenceClassification`. - (minor refactor) Moves `DTYPE_MAP` to `petals.constants` as a useful constant.pull/337/head
parent
fecee8c4dc
commit
d126ee3053
@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
from time import perf_counter
|
||||
|
||||
import torch
|
||||
from hivemind.utils.logging import get_logger
|
||||
|
||||
from petals import AutoDistributedModel
|
||||
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, default="bigscience/bloom")
|
||||
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
|
||||
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--n_processes", type=str, default=1)
|
||||
parser.add_argument("--seq_len", type=int, default=128)
|
||||
parser.add_argument("--n_steps", type=int, default=100)
|
||||
parser.add_argument("--batch_size", type=int, required=True)
|
||||
parser.add_argument("--warmup_steps", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.n_processes == "n_gpus":
|
||||
args.n_processes = torch.cuda.device_count()
|
||||
else:
|
||||
args.n_processes = int(args.n_processes)
|
||||
|
||||
processes = [mp.Process(target=benchmark_forward, args=(i, args)) for i in range(args.n_processes)]
|
||||
for proc in processes:
|
||||
proc.start()
|
||||
for proc in processes:
|
||||
proc.join()
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def benchmark_forward(process_idx, args):
|
||||
model = AutoDistributedModel.from_pretrained(
|
||||
args.model,
|
||||
initial_peers=args.initial_peers,
|
||||
torch_dtype=DTYPE_MAP[args.torch_dtype],
|
||||
)
|
||||
logger.info(f"Created model: {process_idx=} {model.device=}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
for step in range(args.n_steps):
|
||||
if step == args.warmup_steps:
|
||||
start_time = perf_counter()
|
||||
|
||||
input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))
|
||||
|
||||
logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
|
||||
h = model(input_ids)
|
||||
# We don't use model.lm_head
|
||||
logger.info(f"{process_idx=} Fwd end")
|
||||
|
||||
if step >= args.warmup_steps:
|
||||
speed = step / (perf_counter() - start_time) * input_ids.numel()
|
||||
logger.info(f"{process_idx=} {step=} {speed=:.3f}")
|
||||
|
||||
logger.info(f"Final result: {process_idx=} {speed=:.3f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,64 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
from time import perf_counter
|
||||
|
||||
import torch
|
||||
from hivemind.utils.logging import get_logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from petals import AutoDistributedModelForCausalLM
|
||||
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, default="bigscience/bloom")
|
||||
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
|
||||
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--n_processes", type=str, default=1)
|
||||
parser.add_argument("--seq_len", type=int, default=2048)
|
||||
parser.add_argument("--warmup_steps", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.n_processes == "n_gpus":
|
||||
args.n_processes = torch.cuda.device_count()
|
||||
else:
|
||||
args.n_processes = int(args.n_processes)
|
||||
|
||||
processes = [mp.Process(target=benchmark_inference, args=(i, args)) for i in range(args.n_processes)]
|
||||
for proc in processes:
|
||||
proc.start()
|
||||
for proc in processes:
|
||||
proc.join()
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def benchmark_inference(process_idx, args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
model = AutoDistributedModelForCausalLM.from_pretrained(
|
||||
args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
|
||||
)
|
||||
logger.info(f"Created model: {process_idx=} {model.device=} {model.config.torch_dtype=}")
|
||||
|
||||
result = ""
|
||||
with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
|
||||
for step in range(args.seq_len):
|
||||
if step == args.warmup_steps:
|
||||
start_time = perf_counter()
|
||||
|
||||
outputs = model.generate(max_new_tokens=1, session=sess)
|
||||
result += tokenizer.decode(outputs[0])
|
||||
|
||||
if step >= args.warmup_steps:
|
||||
speed = step / (perf_counter() - start_time)
|
||||
logger.info(f"{process_idx=} {step=} {speed=:.3f}")
|
||||
|
||||
logger.info(f"Final result: {process_idx=} {speed=:.3f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
from time import perf_counter
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from hivemind.utils.logging import get_logger
|
||||
|
||||
from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification
|
||||
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str, default="bigscience/bloom")
|
||||
parser.add_argument("--device", type=str, default="cpu")
|
||||
parser.add_argument("--task", type=str, default="cls")
|
||||
parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
|
||||
parser.add_argument("--torch_dtype", type=str, default="bfloat16")
|
||||
parser.add_argument("--n_processes", type=str, default=1)
|
||||
parser.add_argument("--seq_len", type=int, default=128)
|
||||
parser.add_argument("--pre_seq_len", type=int, default=16)
|
||||
parser.add_argument("--n_steps", type=int, default=10)
|
||||
parser.add_argument("--batch_size", type=int, required=True)
|
||||
parser.add_argument("--warmup_steps", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.task in ["cls", "causal_lm"]
|
||||
|
||||
if args.n_processes == "n_gpus":
|
||||
args.n_processes = torch.cuda.device_count()
|
||||
else:
|
||||
args.n_processes = int(args.n_processes)
|
||||
|
||||
processes = [mp.Process(target=benchmark_training, args=(i, args)) for i in range(args.n_processes)]
|
||||
for proc in processes:
|
||||
proc.start()
|
||||
for proc in processes:
|
||||
proc.join()
|
||||
|
||||
|
||||
def benchmark_training(process_idx, args):
|
||||
if args.task == "cls":
|
||||
model = AutoDistributedModelForSequenceClassification.from_pretrained(
|
||||
args.model,
|
||||
initial_peers=args.initial_peers,
|
||||
torch_dtype=DTYPE_MAP[args.torch_dtype],
|
||||
tuning_mode="deep_ptune",
|
||||
pre_seq_len=args.pre_seq_len,
|
||||
num_labels=2,
|
||||
)
|
||||
elif args.task == "causal_lm":
|
||||
model = AutoDistributedModelForCausalLM.from_pretrained(
|
||||
args.model,
|
||||
initial_peers=args.initial_peers,
|
||||
torch_dtype=DTYPE_MAP[args.torch_dtype],
|
||||
tuning_mode="deep_ptune",
|
||||
pre_seq_len=args.pre_seq_len,
|
||||
)
|
||||
model = model.to(args.device)
|
||||
opt = torch.optim.Adam(model.parameters())
|
||||
logger.info(f"Created model: {process_idx=} {model.device=}")
|
||||
|
||||
torch.manual_seed(42)
|
||||
fwd_times = []
|
||||
bwd_times = []
|
||||
for step in range(args.n_steps):
|
||||
input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
|
||||
if args.task == "cls":
|
||||
labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
|
||||
else:
|
||||
labels = input_ids
|
||||
|
||||
logger.info(f"{process_idx=} {step=} Forward")
|
||||
start_time = perf_counter()
|
||||
outputs = model(input_ids, labels=labels)
|
||||
fwd_times.append(perf_counter() - start_time)
|
||||
|
||||
logger.info(f"{process_idx=} {step=} Backward")
|
||||
start_time = perf_counter()
|
||||
outputs.loss.backward()
|
||||
bwd_times.append(perf_counter() - start_time)
|
||||
|
||||
logger.info(f"{process_idx=} {step=} Optimizer step")
|
||||
opt.step()
|
||||
opt.zero_grad()
|
||||
|
||||
if step >= args.warmup_steps:
|
||||
fwd_speed = input_ids.numel() / np.mean(fwd_times[1:])
|
||||
bwd_speed = input_ids.numel() / np.mean(bwd_times[1:])
|
||||
logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
|
||||
|
||||
logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue