|
|
@ -169,7 +169,7 @@ def train(accelerator, config):
|
|
|
|
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
|
|
|
|
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
|
|
|
|
|
|
|
|
|
|
|
|
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
|
|
|
|
if step > 0 and (step % config["eval_every"] == 0 or step == len(train_dataloader) - 1):
|
|
|
|
val_loss = evaluate(config, model, val_dataloader)
|
|
|
|
val_loss = evaluate(model, val_dataloader)
|
|
|
|
|
|
|
|
|
|
|
|
log_train = {
|
|
|
|
log_train = {
|
|
|
|
"train_loss": train_loss.compute()
|
|
|
|
"train_loss": train_loss.compute()
|
|
|
|