Add benchmark scripts (#319)

This PR:

- Adds benchmark scripts for inference, forward passes, and full training steps (e.g., used for the experiments in our paper).
- Fixes a dtype bug in `petals.DistributedBloomForSequenceClassification`.
- (minor refactor) Moves `DTYPE_MAP` to `petals.constants`, since it is useful across modules.
Alexander Borzunov committed d126ee3053 (parent fecee8c4dc)

@@ -0,0 +1,69 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import torch
from hivemind.utils.logging import get_logger

from petals import AutoDistributedModel
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="bigscience/bloom")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
    parser.add_argument("--n_processes", type=str, default=1)
    parser.add_argument("--seq_len", type=int, default=128)
    parser.add_argument("--n_steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, required=True)
    parser.add_argument("--warmup_steps", type=int, default=1)
    args = parser.parse_args()

    # "--n_processes n_gpus" runs one benchmark process per visible GPU
    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    processes = [mp.Process(target=benchmark_forward, args=(i, args)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()


@torch.inference_mode()
def benchmark_forward(process_idx, args):
    model = AutoDistributedModel.from_pretrained(
        args.model,
        initial_peers=args.initial_peers,
        torch_dtype=DTYPE_MAP[args.torch_dtype],
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    for step in range(args.n_steps):
        if step == args.warmup_steps:
            start_time = perf_counter()

        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))

        logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
        h = model(input_ids)
        # We don't use model.lm_head
        logger.info(f"{process_idx=} Fwd end")

        if step >= args.warmup_steps:
            # Throughput in tokens/sec, measured from the end of warmup
            speed = step / (perf_counter() - start_time) * input_ids.numel()
            logger.info(f"{process_idx=} {step=} {speed=:.3f}")

    logger.info(f"Final result: {process_idx=} {speed=:.3f}")


if __name__ == "__main__":
    main()
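For reference, a minimal launch sketch for this script. The file name `benchmark_forward.py` is an assumption; the flags mirror the argparse definitions above.

# Hypothetical launcher for the forward-pass benchmark; the script path is an
# assumption, the flags come from the argparse definitions above.
import subprocess

subprocess.run(
    [
        "python", "benchmark_forward.py",
        "--model", "bigscience/bloom",
        "--batch_size", "16",       # required flag: sequences per forward pass
        "--n_steps", "100",
        "--n_processes", "n_gpus",  # one benchmark process per visible GPU
    ],
    check=True,
)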

@@ -0,0 +1,64 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import torch
from hivemind.utils.logging import get_logger
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="bigscience/bloom")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
    parser.add_argument("--n_processes", type=str, default=1)
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--warmup_steps", type=int, default=1)
    args = parser.parse_args()

    # "--n_processes n_gpus" runs one benchmark process per visible GPU
    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    processes = [mp.Process(target=benchmark_inference, args=(i, args)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()


@torch.inference_mode()
def benchmark_inference(process_idx, args):
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoDistributedModelForCausalLM.from_pretrained(
        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    logger.info(f"Created model: {process_idx=} {model.device=} {model.config.torch_dtype=}")

    result = ""
    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
        for step in range(args.seq_len):
            if step == args.warmup_steps:
                start_time = perf_counter()

            outputs = model.generate(max_new_tokens=1, session=sess)
            result += tokenizer.decode(outputs[0])

            if step >= args.warmup_steps:
                # Generation speed in tokens/sec (per process), measured from the end of warmup
                speed = step / (perf_counter() - start_time)
                logger.info(f"{process_idx=} {step=} {speed=:.3f}")

    logger.info(f"Final result: {process_idx=} {speed=:.3f}")


if __name__ == "__main__":
    main()
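This benchmark measures single-token autoregressive steps by reusing one inference session. A sketch of the same pattern outside the benchmark, assuming `model` and `tokenizer` were created exactly as in `benchmark_inference()`:

# Session-reuse pattern from the script above: one shared session keeps the
# server-side attention caches alive, so each generate() call performs a
# single-token step instead of re-processing the whole prefix.
with model.transformer.h.inference_session(max_length=128) as sess:
    text = ""
    for _ in range(16):
        outputs = model.generate(max_new_tokens=1, session=sess)
        text += tokenizer.decode(outputs[0])
print(text)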

@@ -0,0 +1,101 @@
#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger

from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="bigscience/bloom")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--task", type=str, default="cls")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
    parser.add_argument("--n_processes", type=str, default=1)
    parser.add_argument("--seq_len", type=int, default=128)
    parser.add_argument("--pre_seq_len", type=int, default=16)
    parser.add_argument("--n_steps", type=int, default=10)
    parser.add_argument("--batch_size", type=int, required=True)
    parser.add_argument("--warmup_steps", type=int, default=1)
    args = parser.parse_args()

    assert args.task in ["cls", "causal_lm"]

    # "--n_processes n_gpus" runs one benchmark process per visible GPU
    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    processes = [mp.Process(target=benchmark_training, args=(i, args)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()


def benchmark_training(process_idx, args):
    if args.task == "cls":
        model = AutoDistributedModelForSequenceClassification.from_pretrained(
            args.model,
            initial_peers=args.initial_peers,
            torch_dtype=DTYPE_MAP[args.torch_dtype],
            tuning_mode="deep_ptune",
            pre_seq_len=args.pre_seq_len,
            num_labels=2,
        )
    elif args.task == "causal_lm":
        model = AutoDistributedModelForCausalLM.from_pretrained(
            args.model,
            initial_peers=args.initial_peers,
            torch_dtype=DTYPE_MAP[args.torch_dtype],
            tuning_mode="deep_ptune",
            pre_seq_len=args.pre_seq_len,
        )
    model = model.to(args.device)

    opt = torch.optim.Adam(model.parameters())
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    fwd_times = []
    bwd_times = []
    for step in range(args.n_steps):
        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
        if args.task == "cls":
            labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
        else:
            labels = input_ids

        logger.info(f"{process_idx=} {step=} Forward")
        start_time = perf_counter()
        outputs = model(input_ids, labels=labels)
        fwd_times.append(perf_counter() - start_time)

        logger.info(f"{process_idx=} {step=} Backward")
        start_time = perf_counter()
        outputs.loss.backward()
        bwd_times.append(perf_counter() - start_time)

        logger.info(f"{process_idx=} {step=} Optimizer step")
        opt.step()
        opt.zero_grad()

        if step >= args.warmup_steps:
            # Tokens/sec for forward and backward, excluding the first (warmup) measurement
            fwd_speed = input_ids.numel() / np.mean(fwd_times[1:])
            bwd_speed = input_ids.numel() / np.mean(bwd_times[1:])
            logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")

    logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")


if __name__ == "__main__":
    main()
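A minimal launch sketch for the training benchmark. The file name `benchmark_training.py` is an assumption; the flags mirror the argparse definitions above.

# Hypothetical launcher for the full-training-step benchmark.
import subprocess

subprocess.run(
    [
        "python", "benchmark_training.py",
        "--task", "cls",          # or "causal_lm"
        "--device", "cpu",
        "--batch_size", "8",      # required flag
        "--pre_seq_len", "16",    # length of the trainable prompts used by deep_ptune
    ],
    check=True,
)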

@@ -6,8 +6,8 @@ from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size

-from petals.constants import PUBLIC_INITIAL_PEERS
-from petals.server.server import DTYPE_MAP, Server
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+from petals.server.server import Server
 from petals.utils.version import validate_version

 logger = get_logger(__name__)

@ -1,3 +1,5 @@
import torch
PUBLIC_INITIAL_PEERS = [
"/dns/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
"/dns6/bootstrap1.petals.ml/tcp/31337/p2p/QmedTaZXmULqwspJXz44SsPZyTNKxhnnFvYRajfH7MGhCY",
@ -7,3 +9,5 @@ PUBLIC_INITIAL_PEERS = [
# The reachability API is currently used only when connecting to the public swarm
REACHABILITY_API_URL = "http://health.petals.ml"
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
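Since `DTYPE_MAP` now lives in `petals.constants`, client and server code resolve dtype strings from one place. A minimal usage sketch:

# Resolving a user-facing dtype string via the shared constant.
import torch
from petals.constants import DTYPE_MAP

assert DTYPE_MAP["bfloat16"] is torch.bfloat16
assert DTYPE_MAP["auto"] == "auto"  # "auto" defers the choice to the weight loader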

@@ -128,7 +128,7 @@ class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
         self.num_labels = config.num_labels

         self.transformer = DistributedBloomModel(config)
-        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

         # Initialize weights and apply final processing
        self.post_init()
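The removed cast forced the classification head to `config.torch_dtype` at construction time, before `from_pretrained` decides the final dtype; dropping it lets the head follow the same dtype as the rest of the model. A toy illustration of one way such a mismatch can surface (sizes and dtypes are made up for the example):

# Toy illustration (hypothetical values): a head cast eagerly to bfloat16 no
# longer matches activations from a model later materialized in float32.
import torch
from torch import nn

head = nn.Linear(1024, 2, bias=False).to(torch.bfloat16)  # old behavior: premature cast
x = torch.randn(1, 1024, dtype=torch.float32)
try:
    head(x)
except RuntimeError as e:
    print("dtype mismatch:", e)  # bfloat16 weights vs. float32 inputs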

@@ -19,6 +19,7 @@ from huggingface_hub import get_hf_file_metadata, hf_hub_url
 from transformers import PretrainedConfig
 from transformers.utils import get_file_from_repo

+from petals.constants import DTYPE_MAP
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
@@ -170,6 +171,3 @@ def _load_state_dict_from_file(
         except Exception as e:
             logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
             time.sleep(delay)
-
-
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

@@ -16,13 +16,13 @@ from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig

-from petals.constants import PUBLIC_INITIAL_PEERS
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
-from petals.server.from_pretrained import DTYPE_MAP, load_pretrained_block
+from petals.server.from_pretrained import load_pretrained_block
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
