|
|
|
@ -13,6 +13,7 @@ from torchmetrics import MeanMetric
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
import wandb
|
|
|
|
|
|
|
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
|
|
|
|
|
|
def format_metrics(metrics, split, prefix=""):
|
|
|
|
|
log = f"[{split}]" + prefix
|
|
|
|
@ -192,9 +193,20 @@ def train(accelerator, config):
|
|
|
|
|
accelerator.print(f"Pushing to HF hub")
|
|
|
|
|
accelerator.wait_for_everyone()
|
|
|
|
|
unwrapped_model = accelerator.unwrap_model(model)
|
|
|
|
|
if accelerator.is_main_process:
|
|
|
|
|
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if accelerator.is_main_process:
|
|
|
|
|
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
accelerator.print(e)
|
|
|
|
|
accelerator.print(f"Failed to push to hub")
|
|
|
|
|
|
|
|
|
|
unwrapped_model.save_pretrained(
|
|
|
|
|
f"{config['output_dir']}/-epoch_{epoch}",
|
|
|
|
|
is_main_process=accelerator.is_main_process,
|
|
|
|
|
save_function=accelerator.save,
|
|
|
|
|
state_dict=accelerator.get_state_dict(model),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
accelerator.wait_for_everyone()
|
|
|
|
|
unwrapped_model = accelerator.unwrap_model(model)
|
|
|
|
|