|
|
|
@ -19,6 +19,7 @@ petals.client.sequential_autograd.MAX_TOKENS_IN_BATCH = 1024
|
|
|
|
|
def main():
|
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
|
parser.add_argument("--model", type=str, default="bigscience/bloom-petals")
|
|
|
|
|
parser.add_argument("--device", type=str, default="cpu")
|
|
|
|
|
parser.add_argument("--task", type=str, default="cls")
|
|
|
|
|
parser.add_argument("-i", "--initial_peers", type=str, nargs='+',
|
|
|
|
|
default=["/dns/bench.petals.ml/tcp/31337/p2p/QmehSoMKScoMF3HczLwaLVnw2Lgsap4bhAMrULEzGc1fSV"])
|
|
|
|
@ -53,6 +54,7 @@ def benchmark_training(process_idx, args):
|
|
|
|
|
model = DistributedBloomForCausalLM.from_pretrained(
|
|
|
|
|
args.model, initial_peers=args.initial_peers, tuning_mode="deep_ptune",
|
|
|
|
|
pre_seq_len=args.pre_seq_len)
|
|
|
|
|
model = model.to(args.device)
|
|
|
|
|
opt = torch.optim.Adam(model.parameters())
|
|
|
|
|
logger.info(f"Created model: {process_idx=} {model.device=}")
|
|
|
|
|
|
|
|
|
@ -60,9 +62,9 @@ def benchmark_training(process_idx, args):
|
|
|
|
|
fwd_times = []
|
|
|
|
|
bwd_times = []
|
|
|
|
|
for step in range(args.n_steps):
|
|
|
|
|
input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len))
|
|
|
|
|
input_ids = torch.randint(100, 10000, size=(args.batch_size, args.seq_len), device=args.device)
|
|
|
|
|
if args.task == "cls":
|
|
|
|
|
labels = torch.randint(0, 2, size=[args.batch_size])
|
|
|
|
|
labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
|
|
|
|
|
else:
|
|
|
|
|
labels = input_ids
|
|
|
|
|
|
|
|
|
|