#!/usr/bin/python
# -*- encoding: utf-8 -*-
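# Distributed training of BiSeNet for face parsing on CelebAMask-HQ.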
from logger import setup_logger
from model import BiSeNet
from face_dataset import FaceMask
from loss import OhemCELoss
from evaluate import evaluate
from optimizer import Optimizer
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.distributed as dist
import os
import os.path as osp
import logging
import time
import datetime
import argparse
respth = './res'
if not osp.exists(respth):
    os.makedirs(respth)
logger = logging.getLogger()

def parse_args():
    parse = argparse.ArgumentParser()
    parse.add_argument(
        '--local_rank',
        dest='local_rank',
        type=int,
        default=-1,
    )
    return parse.parse_args()
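
# --local_rank is filled in per process by the launcher, e.g. (legacy-style):
#   python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py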
def train():
    args = parse_args()
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:33241',
        world_size=torch.cuda.device_count(),
        rank=args.local_rank,
    )
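    # NOTE (assumption): one process per GPU on a single machine; world_size is
    # taken from torch.cuda.device_count() and the rendezvous address is a
    # hard-coded local TCP endpoint, so multi-node runs would need changes here.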
    setup_logger(respth)

    # dataset
    n_classes = 19
    n_img_per_gpu = 16
    n_workers = 8
    cropsize = [448, 448]
    data_root = '/home/zll/data/CelebAMask-HQ/'
    ds = FaceMask(data_root, cropsize=cropsize, mode='train')
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size=n_img_per_gpu,
                    shuffle=False,
                    sampler=sampler,
                    num_workers=n_workers,
                    pin_memory=True,
                    drop_last=True)
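    # shuffle=False because the DistributedSampler already shuffles and shards
    # the data, giving each rank a disjoint subset of the dataset per epoch.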
    # model
    ignore_idx = -100
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    net.train()
    net = nn.parallel.DistributedDataParallel(net,
                                              device_ids=[args.local_rank],
                                              output_device=args.local_rank)

    score_thres = 0.7
    n_min = n_img_per_gpu * cropsize[0] * cropsize[1] // 16
    LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
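    # OHEM cross-entropy: presumably keeps only the hardest pixels (those with
    # loss above score_thres), but never fewer than n_min, i.e. 1/16 of the
    # pixels in a batch. One loss per BiSeNet head: the fused output plus the
    # two auxiliary outputs (out16, out32).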
    ## optimizer
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = 80000
    power = 0.9
    warmup_steps = 1000
    warmup_start_lr = 1e-5
    optim = Optimizer(
        model=net.module,
        lr0=lr_start,
        momentum=momentum,
        wd=weight_decay,
        warmup_steps=warmup_steps,
        warmup_start_lr=warmup_start_lr,
        max_iter=max_iter,
        power=power)
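    # The custom Optimizer wrapper (optimizer.py) presumably implements the
    # usual "poly" schedule, lr = lr0 * (1 - it / max_iter) ** power, after a
    # linear warmup from warmup_start_lr over the first warmup_steps iterations.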
    ## train loop
    msg_iter = 50
    loss_avg = []
    st = glob_st = time.time()
    diter = iter(dl)
    epoch = 0
    for it in range(max_iter):
        try:
            im, lb = next(diter)
            if im.size()[0] != n_img_per_gpu:
                raise StopIteration
        except StopIteration:
            epoch += 1
            # re-seed the sampler so the next epoch gets a fresh shuffle order
            sampler.set_epoch(epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()
        H, W = im.size()[2:]
        # labels arrive as (N, 1, H, W); cross-entropy expects (N, H, W)
        lb = torch.squeeze(lb, 1)
        optim.zero_grad()
        out, out16, out32 = net(im)
        lossp = LossP(out, lb)
        loss2 = Loss2(out16, lb)
        loss3 = Loss3(out32, lb)
        loss = lossp + loss2 + loss3
        loss.backward()
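        # DDP has already all-reduced (averaged) the gradients across ranks
        # inside backward(), so every process applies the same update below.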
        optim.step()
        loss_avg.append(loss.item())
        # print training log message
        if (it + 1) % msg_iter == 0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            eta = int((max_iter - (it + 1)) * (glob_t_intv / (it + 1)))
            eta = str(datetime.timedelta(seconds=eta))
            msg = ', '.join([
                'it: {it}/{max_it}',
                'lr: {lr:.4f}',
                'loss: {loss:.4f}',
                'eta: {eta}',
                'time: {time:.4f}',
            ]).format(
                it=it + 1,
                max_it=max_iter,
                lr=lr,
                loss=loss_avg,
                time=t_intv,
                eta=eta,
            )
            logger.info(msg)
            loss_avg = []
            st = ed
        # checkpoint and evaluate every 5000 iterations, on rank 0 only
        if dist.get_rank() == 0 and (it + 1) % 5000 == 0:
            state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
            cp_dir = osp.join(respth, 'cp')
            if not osp.exists(cp_dir):
                os.makedirs(cp_dir)
            torch.save(state, osp.join(cp_dir, '{}_iter.pth'.format(it)))
            evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it))
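            # NOTE: evaluate() here runs on rank 0 only; the other ranks run
            # ahead and block at the next gradient all-reduce until it returns.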
    # dump the final model
    save_pth = osp.join(respth, 'model_final_diss.pth')
    # net.cpu()
    state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))

if __name__ == "__main__":
    train()