fix: eval func

This commit is contained in:
Zach Nussbaum 2023-04-04 23:25:37 +00:00
parent 8dd99cc00a
commit ad33b83a48

View File

@ -169,7 +169,7 @@ def train(accelerator, config):
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
val_loss = evaluate(config, model, val_dataloader)
val_loss = evaluate(model, val_dataloader)
log_train = {
"train_loss": train_loss.compute()