2023-06-29 21:12:59 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
import multiprocessing as mp
|
|
|
|
from time import perf_counter
|
|
|
|
|
2023-06-30 00:18:43 +00:00
|
|
|
import numpy as np
|
2023-06-29 21:12:59 +00:00
|
|
|
import torch
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
|
|
|
|
from petals import AutoDistributedModel
|
|
|
|
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
|
|
|
|
|
|
|
|
# Logger shared by main() and the benchmark_forward() worker processes.
logger = get_logger()
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Parse CLI arguments, run ``benchmark_forward`` in parallel processes and log the mean speed.

    Each worker process sends exactly one float (its measured throughput) through a shared
    one-way pipe. The final result is the mean over all workers.

    Raises:
        SystemExit: via ``parser.error`` on invalid ``--n_processes``, ``--n_steps`` or
            ``--warmup_steps`` values.
        RuntimeError: if any worker exits with a non-zero code or fails to report a result.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", type=str, required=True, help="Model")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
    parser.add_argument("--torch_dtype", type=str, default="float32", help="Torch dtype")
    parser.add_argument(
        "--n_processes",
        type=str,
        default="1",
        help="Number of concurrent processes (or 'n_gpus' to use one per visible CUDA device)",
    )
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of benchmark steps")
    parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
    parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
    args = parser.parse_args()

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        try:
            args.n_processes = int(args.n_processes)
        except ValueError:
            parser.error(f"--n_processes must be an integer or 'n_gpus', got {args.n_processes!r}")
    # Zero processes would make np.mean([]) return nan; "n_gpus" yields 0 on a machine without CUDA.
    if args.n_processes < 1:
        parser.error(f"need at least one process, got n_processes={args.n_processes}")
    # Workers only compute a speed on timed steps, so at least one is required.
    if args.n_steps < 1:
        parser.error(f"--n_steps must be >= 1, got {args.n_steps}")
    if args.warmup_steps < 0:
        parser.error(f"--warmup_steps must be >= 0, got {args.warmup_steps}")

    pipe_recv, pipe_send = mp.Pipe(duplex=False)
    processes = [mp.Process(target=benchmark_forward, args=(i, args, pipe_send)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    # Drop the parent's copy of the write end. Once every child has exited, recv() then raises
    # EOFError instead of blocking forever when some child crashed before sending its result.
    pipe_send.close()

    # Drain results before joining: a child blocked in send() could otherwise never be joined.
    speeds = []
    for _ in range(args.n_processes):
        try:
            speeds.append(pipe_recv.recv())
        except EOFError:
            break
    for proc in processes:
        proc.join()
    pipe_recv.close()

    failed = [i for i, proc in enumerate(processes) if proc.exitcode != 0]
    if failed or len(speeds) != args.n_processes:
        raise RuntimeError(
            f"{len(speeds)}/{args.n_processes} processes reported a result; "
            f"process indices with non-zero exit code: {failed}"
        )

    speed = np.mean(speeds)
    logger.info(f"Final result: {speed=:.2f}")
|
|
|
|
|
2023-06-29 21:12:59 +00:00
|
|
|
|
|
|
|
@torch.inference_mode()
def benchmark_forward(process_idx, args, result_pipe):
    """Measure forward-pass throughput of a distributed model inside one worker process.

    Loads ``args.model`` with ``AutoDistributedModel.from_pretrained``. It then runs
    ``args.warmup_steps`` untimed steps followed by ``args.n_steps`` timed steps. Each step is
    a forward pass over random token ids in ``[0, vocab_size)`` of shape
    ``(args.batch_size, args.seq_len)``.

    Args:
        process_idx: index of this worker; used only in log messages.
        args: parsed CLI namespace (``model``, ``initial_peers``, ``torch_dtype``,
            ``batch_size``, ``seq_len``, ``warmup_steps``, ``n_steps``).
        result_pipe: write end of a ``multiprocessing.Pipe``; receives exactly one float.

    Sends:
        Input tokens per second, averaged over the timed steps. Sends ``nan`` when no step was
        timed (``n_steps <= 0``), so the parent is never left waiting for a missing result.
    """
    model = AutoDistributedModel.from_pretrained(
        args.model,
        initial_peers=args.initial_peers,
        torch_dtype=DTYPE_MAP[args.torch_dtype],
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    torch.manual_seed(42)
    step_times = []
    # Pre-bound so the final send() works even when the timed branch never runs.
    speed = float("nan")
    for step in range(args.warmup_steps + args.n_steps):
        # The timing window includes input generation and logging, not just the forward pass.
        start_time = perf_counter()

        input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))

        logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
        model(input_ids)  # Output hidden states are discarded; we don't use model.lm_head
        logger.info(f"{process_idx=} Fwd end")

        if step >= args.warmup_steps:
            step_times.append(perf_counter() - start_time)
            # Running average: tokens per step / mean seconds per timed step.
            speed = input_ids.numel() / np.mean(step_times)
            logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    if not step_times:
        logger.warning(f"{process_idx=} No timed steps were run ({args.n_steps=}); reporting speed=nan")

    result_pipe.send(speed)
|
2023-06-29 21:12:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the CLI and launch the benchmark workers.
    # Kept behind this guard so spawned worker processes can import the module safely.
    main()
|