fix: try except push

This commit is contained in:
Zach 2023-04-05 20:42:22 +00:00
parent 399a65e779
commit a57adb0344

View File

@ -13,6 +13,7 @@ from torchmetrics import MeanMetric
from tqdm import tqdm
import wandb
torch.backends.cuda.matmul.allow_tf32 = True
def format_metrics(metrics, split, prefix=""):
log = f"[{split}]" + prefix
@ -192,9 +193,20 @@ def train(accelerator, config):
accelerator.print(f"Pushing to HF hub")
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
try:
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
except Exception as e:
accelerator.print(e)
accelerator.print(f"Failed to push to hub")
unwrapped_model.save_pretrained(
f"{config['output_dir']}/-epoch_{epoch}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=accelerator.get_state_dict(model),
)
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)