import torch
import torch.autograd as autograd
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from util.image_pool import ImagePool
###############################################################################
# Functions
###############################################################################
class ContentLoss():
    def initialize(self, loss):
        self.criterion = loss

    def get_loss(self, fakeIm, realIm):
        return self.criterion(fakeIm, realIm)

    def __call__(self, fakeIm, realIm):
        return self.get_loss(fakeIm, realIm)

class PerceptualLoss():
    def contentFunc(self):
        # Feature extractor: VGG19 truncated after conv3_3 (index 14 of vgg19().features)
        conv_3_3_layer = 14
        cnn = models.vgg19(pretrained=True).features
        cnn = cnn.cuda()
        model = nn.Sequential()
        model = model.cuda()
        model = model.eval()
        for i, layer in enumerate(list(cnn)):
            model.add_module(str(i), layer)
            if i == conv_3_3_layer:
                break
        return model

    def initialize(self, loss):
        with torch.no_grad():
            self.criterion = loss
            self.contentFunc = self.contentFunc()
            self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def get_loss(self, fakeIm, realIm):
        # Inputs are expected in [-1, 1]; rescale to [0, 1] and apply ImageNet normalization.
        # Note that only the first image of the batch is normalized, which effectively
        # assumes a batch size of 1.
        fakeIm = (fakeIm + 1) / 2.0
        realIm = (realIm + 1) / 2.0
        fakeIm[0, :, :, :] = self.transform(fakeIm[0, :, :, :])
        realIm[0, :, :, :] = self.transform(realIm[0, :, :, :])
        f_fake = self.contentFunc.forward(fakeIm)
        f_real = self.contentFunc.forward(realIm)
        f_real_no_grad = f_real.detach()
        loss = self.criterion(f_fake, f_real_no_grad)
        # Weighted sum of VGG feature loss and pixel-space MSE
        return 0.006 * torch.mean(loss) + 0.5 * nn.MSELoss()(fakeIm, realIm)

    def __call__(self, fakeIm, realIm):
        return self.get_loss(fakeIm, realIm)

class GANLoss(nn.Module):
    def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        # L1 is used by the least-squares-style variants, BCE-with-logits by the vanilla GAN loss
        if use_l1:
            self.loss = nn.L1Loss()
        else:
            self.loss = nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input, target_is_real):
        # Lazily (re)build a constant target tensor matching the discriminator output size
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor.cuda()

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

class DiscLoss(nn.Module):
    def name(self):
        return 'DiscLoss'

    def __init__(self):
        super(DiscLoss, self).__init__()
        self.criterionGAN = GANLoss(use_l1=False)
        self.fake_AB_pool = ImagePool(50)

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        pred_fake = net.forward(fakeB)
        return self.criterionGAN(pred_fake, 1)

    def get_loss(self, net, fakeB, realB):
        # Fake
        # stop backprop to the generator by detaching fake_B
        # Generated Image Disc Output should be close to zero
        self.pred_fake = net.forward(fakeB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, 0)

        # Real
        self.pred_real = net.forward(realB)
        self.loss_D_real = self.criterionGAN(self.pred_real, 1)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        return self.loss_D

    def __call__(self, net, fakeB, realB):
        return self.get_loss(net, fakeB, realB)

class RelativisticDiscLoss(nn.Module):
    def name(self):
        return 'RelativisticDiscLoss'

    def __init__(self):
        super(RelativisticDiscLoss, self).__init__()
        self.criterionGAN = GANLoss(use_l1=False)
        self.fake_pool = ImagePool(50)  # create image buffer to store previously generated images
        self.real_pool = ImagePool(50)

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        self.pred_fake = net.forward(fakeB)

        # Real
        self.pred_real = net.forward(realB)
        errG = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 0) +
                self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 1)) / 2
        return errG

    def get_loss(self, net, fakeB, realB):
        # Fake
        # stop backprop to the generator by detaching fake_B
        # Generated Image Disc Output should be close to zero
        self.fake_B = fakeB.detach()
        self.real_B = realB
        self.pred_fake = net.forward(fakeB.detach())
        self.fake_pool.add(self.pred_fake)

        # Real
        self.pred_real = net.forward(realB)
        self.real_pool.add(self.pred_real)

        # Combined loss
        self.loss_D = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 1) +
                       self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 0)) / 2
        return self.loss_D

    def __call__(self, net, fakeB, realB):
        return self.get_loss(net, fakeB, realB)

class RelativisticDiscLossLS(nn.Module):
    def name(self):
        return 'RelativisticDiscLossLS'

    def __init__(self):
        super(RelativisticDiscLossLS, self).__init__()
        self.criterionGAN = GANLoss(use_l1=True)
        self.fake_pool = ImagePool(50)  # create image buffer to store previously generated images
        self.real_pool = ImagePool(50)

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        self.pred_fake = net.forward(fakeB)

        # Real
        self.pred_real = net.forward(realB)
        errG = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) + 1) ** 2) +
                torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) - 1) ** 2)) / 2
        return errG

    def get_loss(self, net, fakeB, realB):
        # Fake
        # stop backprop to the generator by detaching fake_B
        # Generated Image Disc Output should be close to zero
        self.fake_B = fakeB.detach()
        self.real_B = realB
        self.pred_fake = net.forward(fakeB.detach())
        self.fake_pool.add(self.pred_fake)

        # Real
        self.pred_real = net.forward(realB)
        self.real_pool.add(self.pred_real)

        # Combined loss
        self.loss_D = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) - 1) ** 2) +
                       torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) + 1) ** 2)) / 2
        return self.loss_D

    def __call__(self, net, fakeB, realB):
        return self.get_loss(net, fakeB, realB)

class DiscLossLS(DiscLoss):
    def name(self):
        return 'DiscLossLS'

    def __init__(self):
        super(DiscLossLS, self).__init__()
        self.criterionGAN = GANLoss(use_l1=True)

    def get_g_loss(self, net, fakeB, realB):
        return DiscLoss.get_g_loss(self, net, fakeB, realB)

    def get_loss(self, net, fakeB, realB):
        return DiscLoss.get_loss(self, net, fakeB, realB)

class DiscLossWGANGP(DiscLossLS):
    def name(self):
        return 'DiscLossWGAN-GP'

    def __init__(self):
        super(DiscLossWGANGP, self).__init__()
        self.LAMBDA = 10

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        self.D_fake = net.forward(fakeB)
        return -self.D_fake.mean()

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        # Gradient penalty on random interpolations between real and fake samples
        alpha = torch.rand(1, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.cuda()

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        interpolates = interpolates.cuda()
        interpolates = Variable(interpolates, requires_grad=True)
        disc_interpolates = netD.forward(interpolates)

        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
        return gradient_penalty

    def get_loss(self, net, fakeB, realB):
        self.D_fake = net.forward(fakeB.detach())
        self.D_fake = self.D_fake.mean()

        # Real
        self.D_real = net.forward(realB)
        self.D_real = self.D_real.mean()

        # Combined loss
        self.loss_D = self.D_fake - self.D_real
        gradient_penalty = self.calc_gradient_penalty(net, realB.data, fakeB.data)
        return self.loss_D + gradient_penalty

def get_loss(model):
    if model['content_loss'] == 'perceptual':
        content_loss = PerceptualLoss()
        content_loss.initialize(nn.MSELoss())
    elif model['content_loss'] == 'l1':
        content_loss = ContentLoss()
        content_loss.initialize(nn.L1Loss())
    else:
        raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])

    if model['disc_loss'] == 'wgan-gp':
        disc_loss = DiscLossWGANGP()
    elif model['disc_loss'] == 'lsgan':
        disc_loss = DiscLossLS()
    elif model['disc_loss'] == 'gan':
        disc_loss = DiscLoss()
    elif model['disc_loss'] == 'ragan':
        disc_loss = RelativisticDiscLoss()
    elif model['disc_loss'] == 'ragan-ls':
        disc_loss = RelativisticDiscLossLS()
    else:
        raise ValueError("GAN Loss [%s] not recognized." % model['disc_loss'])
    return content_loss, disc_loss
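

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: build a content/adversarial
    # loss pair via get_loss() and evaluate both on random tensors. The one-layer conv
    # "discriminator" below is only a hypothetical stand-in for the real critic; any module
    # whose forward() maps an image batch to a score map would work the same way. A CUDA
    # device is assumed, since the losses above move label tensors and VGG features to the
    # GPU, and the perceptual branch downloads VGG19 weights on first use.
    config = {'content_loss': 'perceptual', 'disc_loss': 'gan'}
    content_loss, disc_loss = get_loss(config)

    netD = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1).cuda()
    fakeB = (torch.rand(1, 3, 256, 256) * 2 - 1).cuda()  # stand-in generator output in [-1, 1]
    realB = (torch.rand(1, 3, 256, 256) * 2 - 1).cuda()  # stand-in sharp target in [-1, 1]

    loss_D = disc_loss(netD, fakeB, realB)  # discriminator objective
    loss_G = content_loss(fakeB, realB) + disc_loss.get_g_loss(netD, fakeB, realB)  # generator objective
    print('loss_D: %.4f  loss_G: %.4f' % (loss_D.item(), loss_G.item()))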