From 7737fe1facc6e647f4567cf8fe0f2c9feab83869 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Thu, 25 May 2023 12:03:50 +0000
Subject: [PATCH] benchmark_training.py: Add device arg

---
 src/petals/cli/benchmark_training.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/petals/cli/benchmark_training.py b/src/petals/cli/benchmark_training.py
index 8a04d14..95a1d15 100755
--- a/src/petals/cli/benchmark_training.py
+++ b/src/petals/cli/benchmark_training.py
@@ -19,6 +19,7 @@ petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
+    parser.add_argument("--device", type=str, default="cpu")
     parser.add_argument("--task", type=str, default="cls")
     parser.add_argument("-i", "--initial_peers", type=str, nargs='+', default=["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
@@ -53,6 +54,7 @@ def benchmark_training(process_idx, args):
     model = DistributedBloomForCausalLM.from_pretrained(
         args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune", pre_seq_len=args.pre_seq_len)
+    model = model.to(args.device)
     opt = torch.optim.Adam(model.parameters())

     logger.info(f"Created model: {process_idx=} {model.device=}")
@@ -60,9 +62,9 @@ def benchmark_training(process_idx, args):
     fwd_times = []
     bwd_times = []
     for step in range(args.n_steps):
-        input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
+        input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len), device=args.device)
         if args.task == "cls":
-            labels = torch.randint(0, 2, size=[args.batch_size])
+            labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
         else:
             labels = input_ids