GIMP-ML/gimp-plugins/DeblurGANv2/train.py
2020-04-27 10:02:33 +05:30

182 lines
8.0 KiB
Python

import logging
from functools import partial
import cv2
import torch
import torch.optim as optim
import tqdm
import yaml
from joblib import cpu_count
from torch.utils.data import DataLoader
from adversarial_trainer import GANFactory
from dataset import PairedDataset
from metric_counter import MetricCounter
from models.losses import get_loss
from models.models import get_model
from models.networks import get_nets
from schedulers import LinearDecay, WarmRestart
# Turn off OpenCV's internal threading: per the OpenCV docs, a thread count of 0
# disables threading optimizations and runs cv2 functions sequentially.
# NOTE(review): presumably done so cv2 calls inside the DataLoader worker
# processes (num_workers=cpu_count() below) don't oversubscribe the CPU — confirm.
cv2.setNumThreads(0)
class Trainer:
    """Train a generator (and optional discriminator) from a parsed config dict.

    Config keys read by this class: ``model`` (``adv_lambda``, ``d_name`` and
    whatever ``get_loss``/``get_nets``/``get_model`` consume), ``experiment_desc``,
    ``warmup_num``, ``num_epochs``, ``optimizer`` (``name``, ``lr``),
    ``scheduler`` (``name`` plus per-scheduler keys) and the optional
    ``train_batches_per_epoch`` / ``val_batches_per_epoch``.
    """

    def __init__(self, config, train: DataLoader, val: DataLoader):
        """Store the config and data loaders; networks are built in ``train()``.

        Args:
            config: parsed configuration dict.
            train: loader iterated once per epoch for optimisation.
            val: loader iterated once per epoch for validation.
        """
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.adv_lambda = config['model']['adv_lambda']
        self.metric_counter = MetricCounter(config['experiment_desc'])
        self.warmup_epochs = config['warmup_num']

    def train(self):
        """Run ``num_epochs`` epochs of training followed by validation.

        After every epoch both schedulers are stepped, ``best_<experiment_desc>.h5``
        is written when the metric counter reports a new best model, and
        ``last_<experiment_desc>.h5`` is always overwritten.
        """
        self._init_params()
        # Use the instance config: the module-level ``config`` only exists when
        # this file is executed as a script.
        for epoch in range(self.config['num_epochs']):
            if self.warmup_epochs and epoch == self.warmup_epochs:
                # End of warm-up: unfreeze G and rebuild its optimizer/scheduler so
                # the newly trainable parameters are optimised too.
                # (``.module`` implies netG is a wrapper such as DataParallel —
                # confirm in get_nets.)
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)
            self._run_epoch(epoch)
            self._validate(epoch)
            # NOTE(review): torch's ReduceLROnPlateau.step() requires a metrics
            # argument, so the 'plateau' scheduler fails on these calls as
            # written — pass the validation loss if that scheduler is used.
            self.scheduler_G.step()
            self.scheduler_D.step()

            if self.metric_counter.update_best_model():
                torch.save({'model': self.netG.state_dict()},
                           'best_{}.h5'.format(self.config['experiment_desc']))
            torch.save({'model': self.netG.state_dict()},
                       'last_{}.h5'.format(self.config['experiment_desc']))
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s",
                          self.config['experiment_desc'], epoch,
                          self.metric_counter.loss_message())

    def _batches_per_epoch(self, key, loader):
        """Return ``config[key]`` when set and non-zero, otherwise ``len(loader)``."""
        return self.config.get(key) or len(loader)

    def _run_epoch(self, epoch):
        """Optimise G (and D, unless ``d_name == 'no_gan'``) for one epoch.

        Processes at most ``train_batches_per_epoch`` batches (default: the whole
        loader), records losses and PSNR/SSIM metrics, stores the first batch's
        visualisation image and writes the epoch summary to TensorBoard.
        """
        self.metric_counter.clear()
        # Shown in the progress bar; the last param group's lr, as before.
        lr = self.optimizer_G.param_groups[-1]['lr']
        epoch_size = self._batches_per_epoch('train_batches_per_epoch', self.train_dataset)
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        for batch_idx, data in enumerate(tq):
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            # D is updated first; its backward retains the graph so loss_G can
            # still backpropagate through the same ``outputs`` below.
            loss_D = self._update_d(outputs, targets)
            # Zero G's grads after the D step, discarding any gradient D's
            # backward may have propagated into G's parameters.
            self.optimizer_G.zero_grad()
            loss_content = self.criterionG(outputs, targets)
            loss_adv = self.adv_trainer.loss_g(outputs, targets)
            loss_G = loss_content + self.adv_lambda * loss_adv
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            if batch_idx == 0:
                self.metric_counter.add_image(img_for_vis, tag='train')
            # Stop after exactly ``epoch_size`` batches (the former
            # ``i > epoch_size`` check ran one extra batch).
            if batch_idx + 1 >= epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        """Evaluate G on up to ``val_batches_per_epoch`` validation batches.

        Runs under ``torch.no_grad()`` because nothing is backpropagated here.
        The networks' train/eval mode is left unchanged.
        """
        self.metric_counter.clear()
        epoch_size = self._batches_per_epoch('val_batches_per_epoch', self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        with torch.no_grad():
            for batch_idx, data in enumerate(tq):
                inputs, targets = self.model.get_input(data)
                outputs = self.netG(inputs)
                loss_content = self.criterionG(outputs, targets)
                loss_adv = self.adv_trainer.loss_g(outputs, targets)
                loss_G = loss_content + self.adv_lambda * loss_adv
                self.metric_counter.add_losses(loss_G.item(), loss_content.item())
                curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
                self.metric_counter.add_metrics(curr_psnr, curr_ssim)
                if batch_idx == 0:
                    self.metric_counter.add_image(img_for_vis, tag='val')
                # Stop after exactly ``epoch_size`` batches.
                if batch_idx + 1 >= epoch_size:
                    break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)

    def _update_d(self, outputs, targets):
        """Take one discriminator step and return its (adv_lambda-scaled) loss.

        Returns 0 without touching any optimizer when ``d_name == 'no_gan'``.
        """
        if self.config['model']['d_name'] == 'no_gan':
            return 0
        self.optimizer_D.zero_grad()
        loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
        # retain_graph: _run_epoch backpropagates loss_G through the same
        # generator outputs after this call.
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        """Build the optimizer named by ``config['optimizer']['name']``.

        Raises:
            ValueError: if the name is not 'adam', 'sgd' or 'adadelta'.
        """
        name = self.config['optimizer']['name']
        optimizers = {
            'adam': optim.Adam,
            'sgd': optim.SGD,
            'adadelta': optim.Adadelta,
        }
        if name not in optimizers:
            raise ValueError("Optimizer [%s] not recognized." % name)
        return optimizers[name](params, lr=self.config['optimizer']['lr'])

    def _get_scheduler(self, optimizer):
        """Build the LR scheduler named by ``config['scheduler']['name']``.

        Raises:
            ValueError: if the name is not 'plateau', 'sgdr' or 'linear'.
        """
        scheduler_cfg = self.config['scheduler']
        name = scheduler_cfg['name']
        if name == 'plateau':
            return optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                        mode='min',
                                                        patience=scheduler_cfg['patience'],
                                                        factor=scheduler_cfg['factor'],
                                                        min_lr=scheduler_cfg['min_lr'])
        if name == 'sgdr':
            # Previously keyed on the *optimizer* name, which _get_optim never
            # allows to be 'sgdr', so this scheduler was unreachable.
            return WarmRestart(optimizer)
        if name == 'linear':
            return LinearDecay(optimizer,
                               min_lr=scheduler_cfg['min_lr'],
                               num_epochs=self.config['num_epochs'],
                               start_epoch=scheduler_cfg['start_epoch'])
        raise ValueError("Scheduler [%s] not recognized." % name)

    @staticmethod
    def _get_adversarial_trainer(d_name, net_d, criterion_d):
        """Map a discriminator name to the matching GANFactory model.

        Raises:
            ValueError: for an unknown ``d_name``.
        """
        if d_name == 'no_gan':
            return GANFactory.create_model('NoGAN')
        elif d_name == 'patch_gan' or d_name == 'multi_scale':
            return GANFactory.create_model('SingleGAN', net_d, criterion_d)
        elif d_name == 'double_gan':
            return GANFactory.create_model('DoubleGAN', net_d, criterion_d)
        else:
            raise ValueError("Discriminator Network [%s] not recognized." % d_name)

    def _init_params(self):
        """Build losses, networks, adversarial trainer, optimizers and schedulers.

        Called at the start of ``train()`` rather than in ``__init__``.
        """
        self.criterionG, criterionD = get_loss(self.config['model'])
        self.netG, netD = get_nets(self.config['model'])
        self.netG.cuda()
        self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
        self.model = get_model(self.config['model'])
        # Only currently trainable G params; presumably some are frozen during
        # warm-up, and train() rebuilds this optimizer once they are unfrozen.
        self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)
if __name__ == '__main__':
    # Script entry point: load the YAML config, build train/val loaders and train.
    # ``config`` stays module-level on purpose (Trainer code may read it as a global).
    with open('config/config.yaml', 'r') as f:
        # yaml.load without an explicit Loader is deprecated (PyYAML 5.1) and a
        # TypeError in PyYAML >= 6; the config is plain data, so safe_load suffices.
        config = yaml.safe_load(f)
    batch_size = config.pop('batch_size')
    get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(),
                             shuffle=True, drop_last=True)
    # The 'train'/'val' sections are popped so only trainer settings remain in config.
    train = get_dataloader(PairedDataset.from_config(config.pop('train')))
    val = get_dataloader(PairedDataset.from_config(config.pop('val')))
    trainer = Trainer(config, train=train, val=val)
    trainer.train()