From ad33b83a48d8aefb09799f2f15d5ff5be0f6fd97 Mon Sep 17 00:00:00 2001 From: Zach Nussbaum Date: Tue, 4 Apr 2023 23:25:37 +0000 Subject: [PATCH] fix: eval func --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 6c2c0515..1f0b4852 100644 --- a/train.py +++ b/train.py @@ -169,7 +169,7 @@ def train(accelerator, config): accelerator.save_state(f"{config['output_dir']}/step_{curr_step}") if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1): - val_loss = evaluate(config, model, val_dataloader) + val_loss = evaluate(model, val_dataloader) log_train = { "train_loss": train_loss.compute()