import logging
from functools import partial

import cv2
import torch
import torch.optim as optim
import tqdm
import yaml
from joblib import cpu_count
from torch.utils.data import DataLoader

from adversarial_trainer import GANFactory
from dataset import PairedDataset
from metric_counter import MetricCounter
from models.losses import get_loss
from models.models import get_model
from models.networks import get_nets
from schedulers import LinearDecay, WarmRestart

cv2.setNumThreads(0)


class Trainer:
    """Drive generator/discriminator training and validation from a config dict.

    The config is read for (at least) the keys ``model``, ``experiment_desc``,
    ``warmup_num``, ``num_epochs``, ``optimizer`` and ``scheduler``; the
    optional keys ``train_batches_per_epoch`` and ``val_batches_per_epoch``
    cap the number of batches processed per epoch.
    """

    def __init__(self, config, train: DataLoader, val: DataLoader):
        """Store the config and loaders; no networks are built until ``train()``.

        Args:
            config: Nested mapping with the experiment settings.
            train: Loader iterated once per training epoch.
            val: Loader iterated once per validation pass.
        """
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.adv_lambda = config['model']['adv_lambda']
        self.metric_counter = MetricCounter(config['experiment_desc'])
        self.warmup_epochs = config['warmup_num']

    def train(self):
        """Run the full training loop, checkpointing after every epoch.

        Writes ``last_<experiment_desc>.h5`` every epoch and
        ``best_<experiment_desc>.h5`` whenever the metric counter reports a
        new best model.
        """
        self._init_params()
        experiment = self.config['experiment_desc']
        # Use the instance config, not a module-level global, so the class
        # also works when imported from another module.
        for epoch in range(self.config['num_epochs']):
            if self.warmup_epochs != 0 and epoch == self.warmup_epochs:
                # Warm-up is over: unfreeze the generator and rebuild its
                # optimizer/scheduler over all of its parameters.
                # NOTE(review): assumes netG is wrapped so that `.module`
                # exists (e.g. DataParallel) - confirm against get_nets.
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)

            self._run_epoch(epoch)
            self._validate(epoch)
            # NOTE(review): ReduceLROnPlateau ('plateau') expects a metric in
            # step(); which metric to pass cannot be determined from this file.
            self.scheduler_G.step()
            self.scheduler_D.step()

            if self.metric_counter.update_best_model():
                torch.save({'model': self.netG.state_dict()},
                           'best_{}.h5'.format(experiment))
            torch.save({'model': self.netG.state_dict()},
                       'last_{}.h5'.format(experiment))

            print(self.metric_counter.loss_message())
            # Lazy %-style arguments: the message is only formatted when the
            # DEBUG level is actually enabled.
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s",
                          experiment, epoch, self.metric_counter.loss_message())

    def _run_epoch(self, epoch):
        """Train for one epoch and write the collected metrics to tensorboard.

        Args:
            epoch: Zero-based epoch index, used for the progress bar and logs.
        """
        self.metric_counter.clear()
        # Shown in the progress bar; the last param group's lr wins.
        for param_group in self.optimizer_G.param_groups:
            lr = param_group['lr']

        epoch_size = (self.config.get('train_batches_per_epoch')
                      or len(self.train_dataset))
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        for step, data in enumerate(tq):
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)

            # Discriminator step first, then the generator step.
            loss_D = self._update_d(outputs, targets)
            self.optimizer_G.zero_grad()
            loss_content = self.criterionG(outputs, targets)
            loss_adv = self.adv_trainer.loss_g(outputs, targets)
            loss_G = loss_content + self.adv_lambda * loss_adv
            loss_G.backward()
            self.optimizer_G.step()

            self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(
                inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            if step == 0:
                # Only the first batch of the epoch is kept for visualisation.
                self.metric_counter.add_image(img_for_vis, tag='train')
            # Stop after exactly `epoch_size` batches (the previous `>` check
            # let one extra batch through, past the progress bar's total).
            if step + 1 >= epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        """Evaluate on the validation loader and write metrics to tensorboard.

        Args:
            epoch: Zero-based epoch index the validation results belong to.
        """
        self.metric_counter.clear()
        epoch_size = (self.config.get('val_batches_per_epoch')
                      or len(self.val_dataset))
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        for step, data in enumerate(tq):
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            loss_content = self.criterionG(outputs, targets)
            loss_adv = self.adv_trainer.loss_g(outputs, targets)
            loss_G = loss_content + self.adv_lambda * loss_adv

            self.metric_counter.add_losses(loss_G.item(), loss_content.item())
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(
                inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            if step == 0:
                self.metric_counter.add_image(img_for_vis, tag='val')
            # Same cap as in training: exactly `epoch_size` batches.
            if step + 1 >= epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _update_d(self, outputs, targets):
        """Run one discriminator optimisation step.

        Returns:
            The scaled discriminator loss as a Python number, or ``0`` when
            the configured discriminator is ``'no_gan'``.
        """
        if self.config['model']['d_name'] == 'no_gan':
            return 0
        self.optimizer_D.zero_grad()
        loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
        # The graph is retained because the generator loss is back-propagated
        # through the same `outputs` right after this call.
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        """Build the optimizer named in ``config['optimizer']['name']``.

        Args:
            params: Iterable of parameters to optimise.

        Raises:
            ValueError: If the optimizer name is not adam/sgd/adadelta.
        """
        name = self.config['optimizer']['name']
        if name == 'adam':
            optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
        elif name == 'sgd':
            optimizer = optim.SGD(params, lr=self.config['optimizer']['lr'])
        elif name == 'adadelta':
            optimizer = optim.Adadelta(params, lr=self.config['optimizer']['lr'])
        else:
            raise ValueError("Optimizer [%s] not recognized." % name)
        return optimizer

    def _get_scheduler(self, optimizer):
        """Build the LR scheduler named in ``config['scheduler']['name']``.

        Args:
            optimizer: The optimizer whose learning rate is scheduled.

        Raises:
            ValueError: If the scheduler name is not plateau/sgdr/linear.
        """
        scheduler_cfg = self.config['scheduler']
        name = scheduler_cfg['name']
        if name == 'plateau':
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='min',
                patience=scheduler_cfg['patience'],
                factor=scheduler_cfg['factor'],
                min_lr=scheduler_cfg['min_lr'])
        elif name == 'sgdr':
            # Previously compared the *optimizer* name against 'sgdr', which
            # made this branch unreachable for any supported optimizer.
            scheduler = WarmRestart(optimizer)
        elif name == 'linear':
            scheduler = LinearDecay(
                optimizer,
                min_lr=scheduler_cfg['min_lr'],
                num_epochs=self.config['num_epochs'],
                start_epoch=scheduler_cfg['start_epoch'])
        else:
            raise ValueError("Scheduler [%s] not recognized." % name)
        return scheduler

    @staticmethod
    def _get_adversarial_trainer(d_name, net_d, criterion_d):
        """Create the adversarial trainer matching the discriminator name.

        Args:
            d_name: One of 'no_gan', 'patch_gan', 'multi_scale', 'double_gan'.
            net_d: Discriminator network(s) handed to the factory.
            criterion_d: Discriminator loss handed to the factory.

        Raises:
            ValueError: If ``d_name`` is none of the names above.
        """
        if d_name == 'no_gan':
            return GANFactory.create_model('NoGAN')
        elif d_name in ('patch_gan', 'multi_scale'):
            return GANFactory.create_model('SingleGAN', net_d, criterion_d)
        elif d_name == 'double_gan':
            return GANFactory.create_model('DoubleGAN', net_d, criterion_d)
        else:
            raise ValueError("Discriminator Network [%s] not recognized." % d_name)

    def _init_params(self):
        """Create losses, networks, adversarial trainer, optimizers, schedulers."""
        model_cfg = self.config['model']
        self.criterionG, criterionD = get_loss(model_cfg)
        self.netG, netD = get_nets(model_cfg)
        self.netG.cuda()
        self.adv_trainer = self._get_adversarial_trainer(model_cfg['d_name'], netD, criterionD)
        self.model = get_model(model_cfg)
        # Only parameters that currently require gradients are optimised; the
        # generator optimizer is rebuilt in train() once warm-up ends.
        self.optimizer_G = self._get_optim(
            filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)


if __name__ == '__main__':
    with open('config/config.yaml', 'r') as f:
        # safe_load: plain yaml.load() without a Loader can construct
        # arbitrary objects and is rejected outright by PyYAML >= 6.
        config = yaml.safe_load(f)

    batch_size = config.pop('batch_size')
    get_dataloader = partial(DataLoader,
                             batch_size=batch_size,
                             num_workers=cpu_count(),
                             shuffle=True,
                             drop_last=True)

    # The 'train' and 'val' sections are removed from the config and turned
    # into datasets, then loaders, in that order.
    datasets = map(config.pop, ('train', 'val'))
    datasets = map(PairedDataset.from_config, datasets)
    train, val = map(get_dataloader, datasets)
    trainer = Trainer(config, train=train, val=val)
    trainer.train()