From 838b19bea53fd323dc107771315f9e4dc7c7da47 Mon Sep 17 00:00:00 2001
From: Zach
Date: Wed, 5 Apr 2023 20:42:22 +0000
Subject: [PATCH] fix: try except push

---
 train.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index 0f35bc5b..2cf92aa8 100644
--- a/train.py
+++ b/train.py
@@ -13,6 +13,7 @@ from torchmetrics import MeanMetric
 from tqdm import tqdm
 import wandb
 
+torch.backends.cuda.matmul.allow_tf32 = True
 
 def format_metrics(metrics, split, prefix=""):
     log = f"[{split}]" + prefix
@@ -192,9 +193,20 @@ def train(accelerator, config):
             accelerator.print(f"Pushing to HF hub")
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            if accelerator.is_main_process:
-                unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
+            try:
+                if accelerator.is_main_process:
+                    unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
+            except Exception as e:
+                accelerator.print(e)
+                accelerator.print(f"Failed to push to hub")
+
+            unwrapped_model.save_pretrained(
+                f"{config['output_dir']}/-epoch_{epoch}",
+                is_main_process=accelerator.is_main_process,
+                save_function=accelerator.save,
+                state_dict=accelerator.get_state_dict(model),
+            )
 
     accelerator.wait_for_everyone()
     unwrapped_model = accelerator.unwrap_model(model)
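
For context, the change amounts to this pattern: attempt the Hub push on the main process, log any exception rather than crashing the run, and always write a local checkpoint through Accelerate's save utilities. The sketch below is illustrative only, not part of the patch; push_or_save_locally is a hypothetical helper name, and the accelerator, model, config, and epoch objects are assumed to match those used in train.py.

def push_or_save_locally(accelerator, model, config, epoch):
    """Hypothetical helper mirroring the patched block in train.py."""
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)

    # Push from the main process only; a network or auth failure
    # should not kill the training run.
    try:
        if accelerator.is_main_process:
            unwrapped_model.push_to_hub(
                config["save_name"] + f"-epoch_{epoch}", private=True
            )
    except Exception as e:
        accelerator.print(e)
        accelerator.print("Failed to push to hub")

    # Regardless of the push outcome, keep a local checkpoint.
    unwrapped_model.save_pretrained(
        f"{config['output_dir']}/-epoch_{epoch}",
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(model),
    )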