From 812b8070033021dcb97026f219609addbc754a67 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Tue, 28 Mar 2023 18:47:58 +0000
Subject: [PATCH] fix: log for multiple epochs

---
 train.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index 4eddccb5..4344ee24 100644
--- a/train.py
+++ b/train.py
@@ -127,7 +127,8 @@ def train(accelerator, config):
             # log LR in case something weird happens
             if step > 0 and step % (config["eval_every"] // 10) == 0:
                 if config["wandb"]:
-                    accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=step)
+                    curr_step = step + epoch * len(train_dataloader)
+                    accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
 
             if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                 optimizer.step()
@@ -151,7 +152,8 @@ def train(accelerator, config):
                 }
 
                 if config["wandb"]:
-                    accelerator.log({**log_train, **log_val}, step=step)
+                    curr_step = step + epoch * len(train_dataloader)
+                    accelerator.log({**log_train, **log_val}, step=curr_step)
 
                 accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
                 accelerator.print(format_metrics(log_train, "train", f" step {step} "))
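
Note (not part of the patch): a minimal Python sketch of why the change matters. wandb drops data logged with a step lower than one it has already seen, so a `step` that resets to 0 at the start of every epoch causes later epochs to collide with earlier ones. The `steps_per_epoch` value and the toy loop below are illustrative stand-ins, not code from train.py.

# Toy comparison of the old per-epoch step vs. the patched global step.
steps_per_epoch = 3  # stands in for len(train_dataloader)

per_epoch_steps = []  # what was passed to accelerator.log before the fix
global_steps = []     # what is passed after the fix

for epoch in range(2):
    for step in range(steps_per_epoch):
        per_epoch_steps.append(step)
        global_steps.append(step + epoch * steps_per_epoch)

print(per_epoch_steps)  # [0, 1, 2, 0, 1, 2]  repeats across epochs
print(global_steps)     # [0, 1, 2, 3, 4, 5]  strictly increasing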