# GIMP-ML/gimp-plugins/DeblurGANv2/models/losses.py

import torch
import torch.autograd as autograd
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from util.image_pool import ImagePool
###############################################################################
# Functions
###############################################################################
class ContentLoss():
    def initialize(self, loss):
        self.criterion = loss

    def get_loss(self, fakeIm, realIm):
        return self.criterion(fakeIm, realIm)

    def __call__(self, fakeIm, realIm):
        return self.get_loss(fakeIm, realIm)
class PerceptualLoss():
    def contentFunc(self):
        # Truncate a pretrained VGG19 at feature index 14 (relu3_3) to use as
        # the perceptual feature extractor.
        conv_3_3_layer = 14
        cnn = models.vgg19(pretrained=True).features
        cnn = cnn.cuda()
        model = nn.Sequential()
        model = model.cuda()
        model = model.eval()
        for i, layer in enumerate(list(cnn)):
            model.add_module(str(i), layer)
            if i == conv_3_3_layer:
                break
        return model

    def initialize(self, loss):
        with torch.no_grad():
            self.criterion = loss
            self.contentFunc = self.contentFunc()
            self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def get_loss(self, fakeIm, realIm):
        # Map images from [-1, 1] to [0, 1], then apply ImageNet normalization
        # before feeding them to the VGG feature extractor.
        fakeIm = (fakeIm + 1) / 2.0
        realIm = (realIm + 1) / 2.0
        fakeIm[0, :, :, :] = self.transform(fakeIm[0, :, :, :])
        realIm[0, :, :, :] = self.transform(realIm[0, :, :, :])
        f_fake = self.contentFunc.forward(fakeIm)
        f_real = self.contentFunc.forward(realIm)
        f_real_no_grad = f_real.detach()
        loss = self.criterion(f_fake, f_real_no_grad)
        # Weighted sum of the VGG feature loss and a pixel-space MSE term.
        return 0.006 * torch.mean(loss) + 0.5 * nn.MSELoss()(fakeIm, realIm)

    def __call__(self, fakeIm, realIm):
        return self.get_loss(fakeIm, realIm)
class GANLoss(nn.Module):
    def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_l1:
            self.loss = nn.L1Loss()
        else:
            self.loss = nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input, target_is_real):
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor.cuda()

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
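
# Usage sketch (assumed shapes, not part of the original file): with a
# discriminator output `pred` of shape (N, 1, H, W),
# `GANLoss(use_l1=False)(pred, 1)` scores it against an all-ones target with
# BCEWithLogitsLoss, and `(pred, 0)` against all zeros; with use_l1=True an
# L1 match to the target is used instead. The target tensor is moved to
# CUDA, so this class assumes a GPU is available.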
class DiscLoss(nn.Module):
    def name(self):
        return 'DiscLoss'

    def __init__(self):
        super(DiscLoss, self).__init__()
        self.criterionGAN = GANLoss(use_l1=False)
        self.fake_AB_pool = ImagePool(50)

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        pred_fake = net.forward(fakeB)
        return self.criterionGAN(pred_fake, 1)

    def get_loss(self, net, fakeB, realB):
        # Fake
        # stop backprop to the generator by detaching fake_B
        # Generated Image Disc Output should be close to zero
        self.pred_fake = net.forward(fakeB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, 0)

        # Real
        self.pred_real = net.forward(realB)
        self.loss_D_real = self.criterionGAN(self.pred_real, 1)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        return self.loss_D

    def __call__(self, net, fakeB, realB):
        return self.get_loss(net, fakeB, realB)
class RelativisticDiscLoss(nn.Module):
    def name(self):
        return 'RelativisticDiscLoss'

    def __init__(self):
        super(RelativisticDiscLoss, self).__init__()
        self.criterionGAN = GANLoss(use_l1=False)
        self.fake_pool = ImagePool(50)  # create image buffer to store previously generated images
        self.real_pool = ImagePool(50)

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        self.pred_fake = net.forward(fakeB)

        # Real
        self.pred_real = net.forward(realB)
        errG = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 0) +
                self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 1)) / 2
        return errG

    def get_loss(self, net, fakeB, realB):
        # Fake
        # stop backprop to the generator by detaching fake_B
        # Generated Image Disc Output should be close to zero
        self.fake_B = fakeB.detach()
        self.real_B = realB
        self.pred_fake = net.forward(fakeB.detach())
        self.fake_pool.add(self.pred_fake)

        # Real
        self.pred_real = net.forward(realB)
        self.real_pool.add(self.pred_real)

        # Combined loss
        self.loss_D = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 1) +
                       self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 0)) / 2
        return self.loss_D

    def __call__(self, net, fakeB, realB):
        return self.get_loss(net, fakeB, realB)
class RelativisticDiscLossLS(nn.Module):
    def name(self):
        return 'RelativisticDiscLossLS'

    def __init__(self):
        super(RelativisticDiscLossLS, self).__init__()
        self.criterionGAN = GANLoss(use_l1=True)
        self.fake_pool = ImagePool(50)  # create image buffer to store previously generated images
        self.real_pool = ImagePool(50)

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        self.pred_fake = net.forward(fakeB)

        # Real
        self.pred_real = net.forward(realB)
        errG = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) + 1) ** 2) +
                torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) - 1) ** 2)) / 2
        return errG

    def get_loss(self, net, fakeB, realB):
        # Fake
        # stop backprop to the generator by detaching fake_B
        # Generated Image Disc Output should be close to zero
        self.fake_B = fakeB.detach()
        self.real_B = realB
        self.pred_fake = net.forward(fakeB.detach())
        self.fake_pool.add(self.pred_fake)

        # Real
        self.pred_real = net.forward(realB)
        self.real_pool.add(self.pred_real)

        # Combined loss
        self.loss_D = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) - 1) ** 2) +
                       torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) + 1) ** 2)) / 2
        return self.loss_D

    def __call__(self, net, fakeB, realB):
        return self.get_loss(net, fakeB, realB)
class DiscLossLS(DiscLoss):
    def name(self):
        return 'DiscLossLS'

    def __init__(self):
        super(DiscLossLS, self).__init__()
        self.criterionGAN = GANLoss(use_l1=True)

    def get_g_loss(self, net, fakeB, realB):
        # Pass realB through as well; DiscLoss.get_g_loss expects it in its
        # signature even though only fakeB is used.
        return DiscLoss.get_g_loss(self, net, fakeB, realB)

    def get_loss(self, net, fakeB, realB):
        return DiscLoss.get_loss(self, net, fakeB, realB)
class DiscLossWGANGP(DiscLossLS):
    def name(self):
        return 'DiscLossWGAN-GP'

    def __init__(self):
        super(DiscLossWGANGP, self).__init__()
        self.LAMBDA = 10

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        self.D_fake = net.forward(fakeB)
        return -self.D_fake.mean()

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        # Sample a random point on the line between the real and fake batches
        # and penalize deviations of the discriminator's gradient norm from 1
        # at that point (WGAN-GP gradient penalty).
        alpha = torch.rand(1, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.cuda()

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        interpolates = interpolates.cuda()
        interpolates = Variable(interpolates, requires_grad=True)
        disc_interpolates = netD.forward(interpolates)

        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
        return gradient_penalty

    def get_loss(self, net, fakeB, realB):
        self.D_fake = net.forward(fakeB.detach())
        self.D_fake = self.D_fake.mean()

        # Real
        self.D_real = net.forward(realB)
        self.D_real = self.D_real.mean()

        # Combined loss
        self.loss_D = self.D_fake - self.D_real
        gradient_penalty = self.calc_gradient_penalty(net, realB.data, fakeB.data)
        return self.loss_D + gradient_penalty
def get_loss(model):
    if model['content_loss'] == 'perceptual':
        content_loss = PerceptualLoss()
        content_loss.initialize(nn.MSELoss())
    elif model['content_loss'] == 'l1':
        content_loss = ContentLoss()
        content_loss.initialize(nn.L1Loss())
    else:
        raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])

    if model['disc_loss'] == 'wgan-gp':
        disc_loss = DiscLossWGANGP()
    elif model['disc_loss'] == 'lsgan':
        disc_loss = DiscLossLS()
    elif model['disc_loss'] == 'gan':
        disc_loss = DiscLoss()
    elif model['disc_loss'] == 'ragan':
        disc_loss = RelativisticDiscLoss()
    elif model['disc_loss'] == 'ragan-ls':
        disc_loss = RelativisticDiscLossLS()
    else:
        raise ValueError("GAN Loss [%s] not recognized." % model['disc_loss'])
    return content_loss, disc_loss
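
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original training code): the
# factory above expects a config mapping with 'content_loss' and 'disc_loss'
# keys taking the string values handled in get_loss(). The 'l1' content loss
# runs on CPU; 'perceptual' builds a CUDA VGG19 feature extractor, and the
# discriminator losses additionally need a discriminator network.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    config = {'content_loss': 'l1', 'disc_loss': 'lsgan'}  # assumed config layout
    content_loss, disc_loss = get_loss(config)
    fake = torch.rand(1, 3, 64, 64) * 2 - 1  # stand-in generator output in [-1, 1]
    real = torch.rand(1, 3, 64, 64) * 2 - 1  # stand-in ground-truth image
    print('content loss:', content_loss(fake, real).item())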