fix: saving name

This commit is contained in:
Zach Nussbaum 2023-04-08 20:56:13 +00:00
parent 305fe3d444
commit 6de58dd1fc

View File

@ -192,7 +192,7 @@ def train(accelerator, config):
accelerator.print(f"Failed to push to hub")
unwrapped_model.save_pretrained(
f"{config['output_dir']}/-epoch_{epoch}",
f"{config['output_dir']}/epoch_{epoch}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),