import logging
from functools import partial
import cv2
import torch
import torch.optim as optim
import tqdm
import yaml
from joblib import cpu_count
from torch.utils.data import DataLoader
from adversarial_trainer import GANFactory
from dataset import PairedDataset
from metric_counter import MetricCounter
from models.losses import get_loss
from models.models import get_model
from models.networks import get_nets
from schedulers import LinearDecay, WarmRestart
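# OpenCV maintains its own thread pool; with multi-process DataLoader workers
# this can oversubscribe the CPU, so OpenCV's internal threading is disabled.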
cv2.setNumThreads(0)


class Trainer:
    """Orchestrates adversarial training: generator and discriminator updates,
    validation, and checkpointing."""

    def __init__(self, config, train: DataLoader, val: DataLoader):
        self.config = config
        self.train_dataset = train
        self.val_dataset = val
        self.adv_lambda = config['model']['adv_lambda']
        self.metric_counter = MetricCounter(config['experiment_desc'])
        self.warmup_epochs = config['warmup_num']

    def train(self):
        self._init_params()
        for epoch in range(self.config['num_epochs']):
            # Once the warm-up phase ends, unfreeze the full generator and
            # rebuild its optimizer and scheduler over the new parameter set.
            if epoch == self.warmup_epochs and self.warmup_epochs != 0:
                self.netG.module.unfreeze()
                self.optimizer_G = self._get_optim(self.netG.parameters())
                self.scheduler_G = self._get_scheduler(self.optimizer_G)

            self._run_epoch(epoch)
            self._validate(epoch)
            self.scheduler_G.step()
            self.scheduler_D.step()

            # Keep a separate checkpoint for the best validation metric,
            # and always overwrite the latest one.
            if self.metric_counter.update_best_model():
                torch.save({
                    'model': self.netG.state_dict()
                }, 'best_{}.h5'.format(self.config['experiment_desc']))
            torch.save({
                'model': self.netG.state_dict()
            }, 'last_{}.h5'.format(self.config['experiment_desc']))
            print(self.metric_counter.loss_message())
            logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s",
                          self.config['experiment_desc'], epoch, self.metric_counter.loss_message())

    def _run_epoch(self, epoch):
        self.metric_counter.clear()
        lr = self.optimizer_G.param_groups[0]['lr']
        epoch_size = self.config.get('train_batches_per_epoch') or len(self.train_dataset)
        tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
        tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
        i = 0
        for data in tq:
            inputs, targets = self.model.get_input(data)
            outputs = self.netG(inputs)
            # Discriminator step first, then the generator step on the same batch.
            loss_D = self._update_d(outputs, targets)
            self.optimizer_G.zero_grad()
            loss_content = self.criterionG(outputs, targets)
            loss_adv = self.adv_trainer.loss_g(outputs, targets)
            loss_G = loss_content + self.adv_lambda * loss_adv
            loss_G.backward()
            self.optimizer_G.step()
            self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
            curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
            self.metric_counter.add_metrics(curr_psnr, curr_ssim)
            tq.set_postfix(loss=self.metric_counter.loss_message())
            if not i:
                # Log one sample image per epoch.
                self.metric_counter.add_image(img_for_vis, tag='train')
            i += 1
            if i >= epoch_size:
                break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch)

    def _validate(self, epoch):
        self.metric_counter.clear()
        epoch_size = self.config.get('val_batches_per_epoch') or len(self.val_dataset)
        tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
        tq.set_description('Validation')
        i = 0
        with torch.no_grad():  # no gradients needed during validation
            for data in tq:
                inputs, targets = self.model.get_input(data)
                outputs = self.netG(inputs)
                loss_content = self.criterionG(outputs, targets)
                loss_adv = self.adv_trainer.loss_g(outputs, targets)
                loss_G = loss_content + self.adv_lambda * loss_adv
                self.metric_counter.add_losses(loss_G.item(), loss_content.item())
                curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
                self.metric_counter.add_metrics(curr_psnr, curr_ssim)
                if not i:
                    self.metric_counter.add_image(img_for_vis, tag='val')
                i += 1
                if i >= epoch_size:
                    break
        tq.close()
        self.metric_counter.write_to_tensorboard(epoch, validation=True)
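
    # Note on _update_d below: the discriminator is stepped before each
    # generator update on the same batch, with its loss scaled by the same
    # adv_lambda that weights the generator's adversarial term.
    # retain_graph=True keeps the autograd graph of `outputs` alive so the
    # generator's backward pass can reuse it afterwards.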

    def _update_d(self, outputs, targets):
        if self.config['model']['d_name'] == 'no_gan':
            return 0
        self.optimizer_D.zero_grad()
        loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
        loss_D.backward(retain_graph=True)
        self.optimizer_D.step()
        return loss_D.item()

    def _get_optim(self, params):
        if self.config['optimizer']['name'] == 'adam':
            optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'sgd':
            optimizer = optim.SGD(params, lr=self.config['optimizer']['lr'])
        elif self.config['optimizer']['name'] == 'adadelta':
            optimizer = optim.Adadelta(params, lr=self.config['optimizer']['lr'])
        else:
            raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
        return optimizer

    def _get_scheduler(self, optimizer):
        if self.config['scheduler']['name'] == 'plateau':
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                             mode='min',
                                                             patience=self.config['scheduler']['patience'],
                                                             factor=self.config['scheduler']['factor'],
                                                             min_lr=self.config['scheduler']['min_lr'])
        elif self.config['scheduler']['name'] == 'sgdr':
            scheduler = WarmRestart(optimizer)
        elif self.config['scheduler']['name'] == 'linear':
            scheduler = LinearDecay(optimizer,
                                    min_lr=self.config['scheduler']['min_lr'],
                                    num_epochs=self.config['num_epochs'],
                                    start_epoch=self.config['scheduler']['start_epoch'])
        else:
            raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
        return scheduler

    @staticmethod
    def _get_adversarial_trainer(d_name, net_d, criterion_d):
        if d_name == 'no_gan':
            return GANFactory.create_model('NoGAN')
        elif d_name in ('patch_gan', 'multi_scale'):
            return GANFactory.create_model('SingleGAN', net_d, criterion_d)
        elif d_name == 'double_gan':
            return GANFactory.create_model('DoubleGAN', net_d, criterion_d)
        else:
            raise ValueError("Discriminator Network [%s] not recognized." % d_name)

    def _init_params(self):
        self.criterionG, criterionD = get_loss(self.config['model'])
        self.netG, netD = get_nets(self.config['model'])
        self.netG.cuda()
        self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
        self.model = get_model(self.config['model'])
        # Only parameters left trainable by the warm-up freeze are optimized
        # here; train() rebuilds the optimizer once the generator is unfrozen.
        self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
        self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
        self.scheduler_G = self._get_scheduler(self.optimizer_G)
        self.scheduler_D = self._get_scheduler(self.optimizer_D)
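
# Entry point: load the experiment YAML, pop out the options consumed here
# (batch size and the two dataset configs), and pass the rest to Trainer.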
if __name__ == '__main__':
    with open('config/config.yaml', 'r') as f:
        config = yaml.safe_load(f)

    batch_size = config.pop('batch_size')
    get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True,
                             drop_last=True)
    datasets = map(config.pop, ('train', 'val'))
    datasets = map(PairedDataset.from_config, datasets)
    train, val = map(get_dataloader, datasets)
    trainer = Trainer(config, train=train, val=val)
    trainer.train()
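
# A minimal sketch of the config/config.yaml layout this script expects; the
# key names are taken from the accesses above, the values are assumptions:
#
#   experiment_desc: fpn_inception
#   num_epochs: 200
#   warmup_num: 3
#   batch_size: 1
#   train_batches_per_epoch: 1000   # optional; defaults to len(dataloader)
#   val_batches_per_epoch: 100      # optional
#   train: { ... }                  # PairedDataset.from_config options
#   val: { ... }
#   model:
#     adv_lambda: 0.001
#     d_name: double_gan            # no_gan | patch_gan | multi_scale | double_gan
#   optimizer:
#     name: adam                    # adam | sgd | adadelta
#     lr: 0.0001
#   scheduler:
#     name: linear                  # plateau | sgdr | linear
#     start_epoch: 50
#     min_lr: 0.0000001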