fix: num training steps for lr decay

pull/335/head
Zach Nussbaum 1 year ago
parent 195f8a7d4e
commit 9dfd8e1a7c

@@ -100,7 +100,7 @@ def train(accelerator, config):
             name="cosine",
             optimizer=optimizer,
             num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
-            num_training_steps=total_num_steps * accelerator.num_processes,
+            num_training_steps=total_num_steps,
         )
     else:
         scheduler = DummyScheduler(

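For context, below is a minimal sketch of what the scheduler setup looks like after this change. Only the get_scheduler call shown in the hunk comes from the diff; the function wrapper, the DeepSpeed check, and the DummyScheduler arguments are assumptions reconstructed from the hunk's context lines and common Accelerate usage. The effect of dropping the `* accelerator.num_processes` factor is that the cosine decay now runs over the intended number of optimizer steps rather than a schedule stretched by the process count.

```python
# Sketch of the scheduler setup after this commit. build_scheduler and the
# DeepSpeed check are illustrative assumptions; config, optimizer,
# total_num_steps, and accelerator come from train() in the training script.
from accelerate.utils import DummyScheduler
from transformers import get_scheduler


def build_scheduler(accelerator, config, optimizer, total_num_steps):
    deepspeed_plugin = accelerator.state.deepspeed_plugin
    use_hf_scheduler = (
        deepspeed_plugin is None
        or "scheduler" not in deepspeed_plugin.deepspeed_config
    )
    if use_hf_scheduler:
        scheduler = get_scheduler(
            name="cosine",
            optimizer=optimizer,
            # Warmup scaling by num_processes is kept as in the original code.
            num_warmup_steps=config["warmup_steps"] * accelerator.num_processes,
            # The fix: decay over total_num_steps itself, not
            # total_num_steps * accelerator.num_processes, so the cosine
            # schedule finishes within the planned training run.
            num_training_steps=total_num_steps,
        )
    else:
        # A DeepSpeed-configured scheduler owns the schedule; DummyScheduler
        # just carries the step counts through accelerator.prepare().
        scheduler = DummyScheduler(
            optimizer,
            total_num_steps=total_num_steps,
            warmup_num_steps=config["warmup_steps"],
        )
    return scheduler
```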