fix: eval func

pull/913/head
Zach Nussbaum 1 year ago
parent 8dd99cc00a
commit ad33b83a48

@ -169,7 +169,7 @@ def train(accelerator, config):
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(config, model, val_dataloader)
val_loss = evaluate(model, val_dataloader)
log_train = {
"train_loss": train_loss.compute()

Loading…
Cancel
Save