You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

819 lines
32 KiB
Python

### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
###############################################################################
# Functions
###############################################################################
def weights_init(m):
    """Initialise conv / batch-norm weights in place; meant for ``net.apply``.

    Modules whose class name contains ``'Conv2d'`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``'BatchNorm2d'`` get
    weights from N(1, 0.02) and zero biases. Matching is by class-name
    substring, so ``ConvTranspose2d`` (which does not contain ``'Conv2d'``)
    is left untouched.

    A matched module whose ``weight``/``bias`` is missing or ``None`` is
    skipped rather than raising. Examples are ``BatchNorm2d(affine=False)``,
    or a ``Conv2d`` whose ``weight`` parameter was removed by ``SpectralNorm``.

    Args:
        m: the module to initialise.
    """
    classname = m.__class__.__name__
    if 'Conv2d' in classname:
        weight = getattr(m, 'weight', None)
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
    """Return a factory that builds a 2-D normalisation layer from a channel count.

    Args:
        norm_type: ``'instance'`` (InstanceNorm2d without affine parameters)
            or ``'batch'`` (BatchNorm2d with affine parameters).

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=[]):
    """Build a generator, optionally move it to a GPU, and initialise its weights.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base number of generator filters.
        netG: generator type, ``'global'`` or ``'local'``.
        n_downsample_global: downsampling steps of the global generator.
        n_blocks_global: residual blocks of the global generator.
        n_local_enhancers: forwarded to ``LocalEnhancer``.
        n_blocks_local: forwarded to ``LocalEnhancer``.
        norm: normalisation name passed to ``get_norm_layer``.
        gpu_ids: if non-empty, the network is moved to ``gpu_ids[0]``
            (the list is only read, never modified).

    Returns:
        The generator, printed and with ``weights_init`` applied.

    Raises:
        NotImplementedError: if ``netG`` (or ``norm``) is not recognised.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        net = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
    elif netG == 'local':
        # NOTE(review): LocalEnhancer is not defined in the visible part of this
        # file; confirm it is importable before using netG='local'.
        net = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
                            n_local_enhancers, n_blocks_local, norm_layer)
    else:
        # Must raise an exception instance: `raise('...')` would itself fail
        # with "TypeError: exceptions must derive from BaseException".
        raise NotImplementedError('generator [%s] not implemented!' % netG)
    print(net)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.cuda(gpu_ids[0])
    # Initialise after moving to the GPU, as in the original ordering.
    net.apply(weights_init)
    return net
def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
    """Create the multi-scale discriminator.

    The network is printed. When ``gpu_ids`` is non-empty, CUDA availability
    is asserted and the network is moved to ``gpu_ids[0]``. ``weights_init``
    is applied last.

    Returns:
        The ``MultiscaleDiscriminator`` instance.
    """
    norm_factory = get_norm_layer(norm_type=norm)
    discriminator = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_factory,
                                            use_sigmoid, num_D, getIntermFeat)
    print(discriminator)
    if len(gpu_ids):
        assert torch.cuda.is_available()
        discriminator.cuda(gpu_ids[0])
    discriminator.apply(weights_init)
    return discriminator
def define_VAE(input_nc, gpu_ids=[]):
    """Create the mask VAE as ``VAE(19, 32, 32, 1024)``.

    That is 19 channels, 32 decoder/encoder base filters and a 1024-d latent.
    ``input_nc`` is currently ignored. Unlike the other factories,
    ``weights_init`` is not applied. The model is printed and, when
    ``gpu_ids`` is non-empty, moved to ``gpu_ids[0]``.
    """
    vae = VAE(19, 32, 32, 1024)
    print(vae)
    if len(gpu_ids):
        assert torch.cuda.is_available()
        vae.cuda(gpu_ids[0])
    return vae
def define_B(input_nc, output_nc, ngf, n_downsample_global=3, n_blocks_global=3, norm='instance', gpu_ids=[]):
    """Create the blending generator.

    Builds a ``BlendGenerator`` with the requested normalisation and prints
    it. When ``gpu_ids`` is non-empty, CUDA availability is asserted and the
    network is moved to ``gpu_ids[0]``. ``weights_init`` is applied last.
    """
    layer_factory = get_norm_layer(norm_type=norm)
    blender = BlendGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, layer_factory)
    print(blender)
    if len(gpu_ids):
        assert torch.cuda.is_available()
        blender.cuda(gpu_ids[0])
    blender.apply(weights_init)
    return blender
def print_network(net):
    """Print ``net`` and then its total number of parameters.

    If ``net`` is a list, only its first element is reported.
    """
    target = net[0] if isinstance(net, list) else net
    total = sum(p.numel() for p in target.parameters())
    print(target)
    print('Total number of parameters: %d' % total)
##############################################################################
# Losses
##############################################################################
class GANLoss(nn.Module):
    """Adversarial loss against constant real/fake targets.

    Uses ``nn.MSELoss`` (LSGAN) when ``use_lsgan`` is true, else ``nn.BCELoss``.

    ``input`` is either a list of per-discriminator output lists (multi-scale)
    or a single output list. Only the last entry of each list (the
    prediction) is scored. Target tensors are created with ``tensor``, a
    legacy tensor type such as ``torch.FloatTensor`` that decides their
    device. They are cached and reused only while the prediction shape
    stays the same.
    """
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target (real or fake label) shaped exactly like ``input``.

        The cache is rebuilt whenever the shape differs. Comparing only
        ``numel`` would reuse, e.g., a 2x3 target for a 3x2 prediction, which
        the loss cannot broadcast.
        """
        if target_is_real:
            if self.real_label_var is None or self.real_label_var.size() != input.size():
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            return self.real_label_var
        if self.fake_label_var is None or self.fake_label_var.size() != input.size():
            fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
            self.fake_label_var = Variable(fake_tensor, requires_grad=False)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        """Return the loss of the final prediction(s) against the chosen label.

        For multi-scale input the per-discriminator losses are summed.
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss
        target_tensor = self.get_target_tensor(input[-1], target_is_real)
        return self.loss(input[-1], target_tensor)
class VGGLoss(nn.Module):
    """Weighted L1 distance between VGG-19 feature maps of two images.

    ``gpu_ids`` is accepted but not used. The VGG network is always moved to
    the current CUDA device with ``.cuda()``.
    """
    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        # Powers of two (1/32, 1/16, 1/8, 1/4, 1): one weight per returned
        # feature map, with later (deeper) maps weighted more.
        self.weights = [2.0 ** -k for k in (5, 4, 3, 2, 0)]

    def forward(self, x, y):
        """Return ``sum_i weights[i] * L1(vgg(x)[i], vgg(y)[i])``.

        ``y``'s features are detached, so no gradient flows into ``y``.
        """
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        total = 0
        for idx, fx in enumerate(feats_x):
            total += self.weights[idx] * self.criterion(fx, feats_y[idx].detach())
        return total
##############################################################################
# Generator
##############################################################################
class GlobalGenerator(nn.Module):
    """Encoder / residual / decoder generator whose residual blocks use AdaIN.

    ``self.model`` is, in order:
      - reflect-pad and 7x7 conv stem;
      - ``n_downsampling`` stride-2 3x3 convs, doubling the channels;
      - ``n_blocks`` ``ResnetBlock``s with ``norm_type='adain'``;
      - ``n_downsampling`` stride-2 transposed convs, halving the channels;
      - reflect-pad and 7x7 conv, then Tanh.

    ``norm_layer`` is used only outside the residual blocks.

    The AdaIN layers start without affine values. On every ``forward``,
    ``enc_style`` predicts their scale/shift from a reference image and
    reference-label features, and ``assign_adain_params`` writes them in
    before ``self.model`` runs.
    """
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert(n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        # A single parameter-free ReLU instance is reused at every activation site.
        activation = nn.ReLU(True)
        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), activation]
        ### resnet blocks
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_type='adain', padding_type=padding_type)]
        ### upsample
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                      norm_layer(int(ngf * mult / 2)), activation]
        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)
        # style encoder: 5 stride-2 stages, 3 input channels, 16 base filters.
        # Its output size equals the total AdaIN value count needed by self.model.
        self.enc_style = StyleEncoder(5, 3, 16, self.get_num_adain_params(self.model), norm='none', activ='relu', pad_type='reflect')
        # label encoder: 19 input channels, 16 base filters.
        # The 64 passed as style_dim is not used by LabelEncoder.
        self.enc_label = LabelEncoder(5, 19, 16, 64, norm='none', activ='relu', pad_type='reflect')

    def assign_adain_params(self, adain_params, model):
        """Write AdaIN scale/shift values into every AdaptiveInstanceNorm2d of ``model``.

        Columns of ``adain_params`` are consumed in module order. For a layer
        with C features, the first C columns become its ``bias`` (mean) and
        the next C its ``weight`` (std). Each slice is flattened with
        ``view(-1)``. ``adain_params`` is presumably (B, N, 1, 1) as produced
        by StyleEncoder, giving length-B*C vectors — confirm for other callers.
        """
        # assign the adain_params to the AdaIN layers in model
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                mean = adain_params[:, :m.num_features]
                std = adain_params[:, m.num_features:2*m.num_features]
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                # drop the consumed columns before moving on to the next AdaIN layer
                if adain_params.size(1) > 2*m.num_features:
                    adain_params = adain_params[:, 2*m.num_features:]

    def get_num_adain_params(self, model):
        """Return the number of AdaIN values ``model`` needs: 2 * num_features per AdaIN layer (mean + std)."""
        # return the number of AdaIN parameters needed by the model
        num_adain_params = 0
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2*m.num_features
        return num_adain_params

    def forward(self, input, input_ref, image_ref):
        """Generate from ``input`` using AdaIN parameters derived from the references.

        Args:
            input: tensor passed through ``self.model``.
            input_ref: reference label map for ``enc_label`` (19 channels, per its constructor).
            image_ref: reference image for ``enc_style`` (3 channels, per its constructor).
        """
        fea1, fea2 = self.enc_label(input_ref)
        adain_params = self.enc_style((image_ref, fea1, fea2))
        self.assign_adain_params(adain_params, self.model)
        return self.model(input)
class BlendGenerator(nn.Module):
    """Predict a soft mask from two inputs and blend them with it.

    The layout follows GlobalGenerator's encoder / residual / decoder layout,
    with two differences: the residual blocks use plain instance norm
    (``norm_type='in'``), and the head ends in a Sigmoid, so the predicted
    map lies in (0, 1).
    """
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert(n_blocks >= 0)
        super(BlendGenerator, self).__init__()
        # A single parameter-free ReLU instance is reused at every activation site.
        activation = nn.ReLU(True)
        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), activation]
        ### resnet blocks
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_type='in', padding_type=padding_type)]
        ### upsample
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                      norm_layer(int(ngf * mult / 2)), activation]
        # Sigmoid head: the output is used as a per-pixel blending weight.
        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Sigmoid()]
        self.model = nn.Sequential(*model)

    def forward(self, input1, input2):
        """Return ``(input1 * m + input2 * (1 - m), m)``.

        Here ``m = self.model(cat([input1, input2], dim=1))``. The constructor's
        ``input_nc`` must equal the sum of the two inputs' channel counts.
        """
        m = self.model(torch.cat([input1, input2], 1))
        return input1 * m + input2 * (1-m), m
# Define the Multiscale Discriminator.
class MultiscaleDiscriminator(nn.Module):
    """Apply ``num_D`` PatchGAN discriminators to an average-pooled image pyramid.

    Discriminator ``i`` is registered either as ``layer{i}`` (one Sequential)
    or, with ``getIntermFeat``, as sub-blocks ``scale{i}_layer{j}``, so every
    intermediate feature map can be returned.

    ``forward`` applies ``layer{num_D-1}`` to the full-resolution input and
    ``layer0`` to the coarsest one. Between consecutive discriminators it
    average-pools with a 3x3 kernel and stride 2. It returns one list per
    discriminator: all sub-block outputs (``getIntermFeat``), or a
    one-element list holding the final output.
    """
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat
        # NLayerDiscriminator registers n_layers + 2 conv sub-blocks, plus a
        # trailing Sigmoid sub-block when use_sigmoid is set. Copying and
        # running only n_layers + 2 would silently drop the Sigmoid.
        self.n_sub_layers = n_layers + 2 + (1 if use_sigmoid else 0)
        for i in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if getIntermFeat:
                for j in range(self.n_sub_layers):
                    setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
            else:
                setattr(self, 'layer'+str(i), netD.model)
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one discriminator and return its outputs as a list.

        ``model`` is a list of sub-blocks (``getIntermFeat``), or a single
        module otherwise.
        """
        if self.getIntermFeat:
            result = [input]
            for layer in model:
                result.append(layer(result[-1]))
            return result[1:]
        return [model(input)]

    def forward(self, input):
        num_D = self.num_D
        result = []
        input_downsampled = input
        for i in range(num_D):
            idx = num_D - 1 - i
            if self.getIntermFeat:
                model = [getattr(self, 'scale'+str(idx)+'_layer'+str(j)) for j in range(self.n_sub_layers)]
            else:
                model = getattr(self, 'layer'+str(idx))
            result.append(self.singleD_forward(model, input_downsampled))
            if i != (num_D-1):
                input_downsampled = self.downsample(input_downsampled)
        return result
# Define the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from 4x4 convs, ending in a 1-channel score map.

    Sub-blocks, in order:
      - conv (stride 2) + LeakyReLU(0.2);
      - (n_layers - 1) x [conv (stride 2) + norm + LeakyReLU], channels
        doubling, capped at 512;
      - conv (stride 1) + norm + LeakyReLU;
      - conv (stride 1) to 1 channel;
      - Sigmoid, only if ``use_sigmoid``.

    With ``getIntermFeat`` each sub-block is registered as ``model{n}`` and
    ``forward`` returns every sub-block's output. Otherwise all layers form
    ``self.model`` and ``forward`` returns its output.
    """
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers
        kw = 4
        padw = int(np.ceil((kw-1.0)/2))  # = 2 for kw = 4
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]
        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]
        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]
        # n_layers + 2 sub-blocks, or n_layers + 3 with the Sigmoid. forward
        # must visit all of them; stopping at n_layers + 2 skips the Sigmoid.
        self.n_sub_models = len(sequence)
        if getIntermFeat:
            for n, block in enumerate(sequence):
                setattr(self, 'model'+str(n), nn.Sequential(*block))
        else:
            self.model = nn.Sequential(*[layer for block in sequence for layer in block])

    def forward(self, input):
        """Return all sub-block outputs (``getIntermFeat``) or the final score map."""
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_sub_models):
                res.append(getattr(self, 'model'+str(n))(res[-1]))
            return res[1:]
        return self.model(input)
from torchvision import models
class Vgg19(torch.nn.Module):
    """Pretrained VGG-19 feature extractor returning five intermediate activations.

    The torchvision VGG-19 ``features`` stack is split into five consecutive
    slices covering layer indices [0, 2), [2, 7), [7, 12), [12, 21) and
    [21, 30). ``forward`` returns each slice's output; the names used
    upstream, h_relu1..h_relu5, suggest these are ReLU outputs. Unless
    ``requires_grad`` is true, all parameters are frozen.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        features = models.vgg19(pretrained=True).features
        bounds = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))
        # Register all five (empty) slices first, then fill them in order.
        for k in range(1, 6):
            setattr(self, 'slice%d' % k, torch.nn.Sequential())
        for k, (start, stop) in enumerate(bounds, 1):
            target = getattr(self, 'slice%d' % k)
            for idx in range(start, stop):
                target.add_module(str(idx), features[idx])
        if not requires_grad:
            for p in self.parameters():
                p.requires_grad = False

    def forward(self, X):
        """Return the list of the five slice outputs, each slice fed the previous one's output."""
        outputs = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            outputs.append(h)
        return outputs
# Define the MaskVAE
class VAE(nn.Module):
    """Convolutional variational auto-encoder (the "MaskVAE").

    Encoder: seven 4x4 stride-2 convs (``ndf`` up to ``ndf*64`` channels),
    each followed by BatchNorm and LeakyReLU(0.2). Two linear heads then give
    ``mu`` and ``logvar`` of size ``latent_variable_size``. The heads flatten
    a ``ndf*64`` x 4x4 map, so inputs are expected to be 512x512 (4 * 2**7).

    Decoder: a linear layer with ReLU produces a ``ngf*64`` x 4x4 map. Then
    come seven stages of nearest x2 upsampling, replication pad and 3x3 conv;
    all but the last are followed by BatchNorm (eps=1e-3) and LeakyReLU.
    The result is ``nc`` channels at 512x512, with no output activation.
    """
    def __init__(self, nc, ngf, ndf, latent_variable_size):
        super(VAE, self).__init__()
        self.nc = nc
        self.ngf = ngf
        self.ndf = ndf
        self.latent_variable_size = latent_variable_size
        # encoder
        self.e1 = nn.Conv2d(nc, ndf, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(ndf)
        self.e2 = nn.Conv2d(ndf, ndf*2, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(ndf*2)
        self.e3 = nn.Conv2d(ndf*2, ndf*4, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(ndf*4)
        self.e4 = nn.Conv2d(ndf*4, ndf*8, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(ndf*8)
        self.e5 = nn.Conv2d(ndf*8, ndf*16, 4, 2, 1)
        self.bn5 = nn.BatchNorm2d(ndf*16)
        self.e6 = nn.Conv2d(ndf*16, ndf*32, 4, 2, 1)
        self.bn6 = nn.BatchNorm2d(ndf*32)
        self.e7 = nn.Conv2d(ndf*32, ndf*64, 4, 2, 1)
        self.bn7 = nn.BatchNorm2d(ndf*64)
        self.fc1 = nn.Linear(ndf*64*4*4, latent_variable_size)
        self.fc2 = nn.Linear(ndf*64*4*4, latent_variable_size)
        # decoder
        self.d1 = nn.Linear(latent_variable_size, ngf*64*4*4)
        self.up1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd1 = nn.ReplicationPad2d(1)
        self.d2 = nn.Conv2d(ngf*64, ngf*32, 3, 1)
        self.bn8 = nn.BatchNorm2d(ngf*32, 1.e-3)
        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd2 = nn.ReplicationPad2d(1)
        self.d3 = nn.Conv2d(ngf*32, ngf*16, 3, 1)
        self.bn9 = nn.BatchNorm2d(ngf*16, 1.e-3)
        self.up3 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd3 = nn.ReplicationPad2d(1)
        self.d4 = nn.Conv2d(ngf*16, ngf*8, 3, 1)
        self.bn10 = nn.BatchNorm2d(ngf*8, 1.e-3)
        self.up4 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd4 = nn.ReplicationPad2d(1)
        self.d5 = nn.Conv2d(ngf*8, ngf*4, 3, 1)
        self.bn11 = nn.BatchNorm2d(ngf*4, 1.e-3)
        self.up5 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd5 = nn.ReplicationPad2d(1)
        self.d6 = nn.Conv2d(ngf*4, ngf*2, 3, 1)
        self.bn12 = nn.BatchNorm2d(ngf*2, 1.e-3)
        self.up6 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd6 = nn.ReplicationPad2d(1)
        self.d7 = nn.Conv2d(ngf*2, ngf, 3, 1)
        self.bn13 = nn.BatchNorm2d(ngf, 1.e-3)
        self.up7 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd7 = nn.ReplicationPad2d(1)
        self.d8 = nn.Conv2d(ngf, nc, 3, 1)
        self.leakyrelu = nn.LeakyReLU(0.2)
        self.relu = nn.ReLU()
        # NOTE: maxpool is not used by any method below; kept so the module
        # structure (and its repr) is unchanged.
        self.maxpool = nn.MaxPool2d((2, 2), (2, 2))

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch ``x``; see the class docstring for the expected size."""
        h1 = self.leakyrelu(self.bn1(self.e1(x)))
        h2 = self.leakyrelu(self.bn2(self.e2(h1)))
        h3 = self.leakyrelu(self.bn3(self.e3(h2)))
        h4 = self.leakyrelu(self.bn4(self.e4(h3)))
        h5 = self.leakyrelu(self.bn5(self.e5(h4)))
        h6 = self.leakyrelu(self.bn6(self.e6(h5)))
        h7 = self.leakyrelu(self.bn7(self.e7(h6)))
        h7 = h7.view(-1, self.ndf*64*4*4)
        return self.fc1(h7), self.fc2(h7)

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        ``eps`` comes from ``torch.randn_like(std)``, so it matches ``std``'s
        device and dtype. The model therefore samples on CPU as well as on
        GPU; a hard-coded ``torch.cuda.FloatTensor`` would fail without CUDA.
        """
        std = logvar.mul(0.5).exp_()
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Map latent codes ``z`` back to ``nc``-channel maps (raw values, no activation)."""
        h1 = self.relu(self.d1(z))
        h1 = h1.view(-1, self.ngf*64, 4, 4)
        h2 = self.leakyrelu(self.bn8(self.d2(self.pd1(self.up1(h1)))))
        h3 = self.leakyrelu(self.bn9(self.d3(self.pd2(self.up2(h2)))))
        h4 = self.leakyrelu(self.bn10(self.d4(self.pd3(self.up3(h3)))))
        h5 = self.leakyrelu(self.bn11(self.d5(self.pd4(self.up4(h4)))))
        h6 = self.leakyrelu(self.bn12(self.d6(self.pd5(self.up5(h5)))))
        h7 = self.leakyrelu(self.bn13(self.d7(self.pd6(self.up6(h6)))))
        return self.d8(self.pd7(self.up7(h7)))

    def get_latent_var(self, x):
        """Encode ``x`` and return ``(z, mu, std)``, where ``std = exp(0.5 * logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z, mu, logvar.mul(0.5).exp_()

    def forward(self, x):
        """Return ``(reconstruction, x, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)
        return res, x, mu, logvar
# style encode part
class StyleEncoder(nn.Module):
    """Encode a reference image, modulated by label features, into AdaIN parameters.

    Pipeline, in order:
      - ``self.model``: 7x7 ConvBlock, then two 4x4 stride-2 ConvBlocks that
        double the channels each time;
      - ``sft1``: SFT modulation by the first label feature;
      - ``self.model_middle``: ``n_downsample - 2`` further stride-2
        ConvBlocks, channels unchanged;
      - ``sft2``: SFT modulation by the second label feature;
      - ``self.model_last``: global average pooling, then a 1x1 conv to
        ``style_dim`` channels.
    """
    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        super(StyleEncoder, self).__init__()
        self.model = []
        self.model_middle = []
        self.model_last = []
        self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        for i in range(2):
            self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        # SFTLayer is hard-wired to 64 channels, so dim must be 64 from here on
        # (i.e. the constructor's dim must be 16).
        for i in range(n_downsample - 2):
            self.model_middle += [ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
        self.model_last += [nn.AdaptiveAvgPool2d(1)] # global average pooling
        self.model_last += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
        self.model = nn.Sequential(*self.model)
        self.model_middle = nn.Sequential(*self.model_middle)
        self.model_last = nn.Sequential(*self.model_last)
        self.output_dim = dim
        self.sft1 = SFTLayer()
        self.sft2 = SFTLayer()

    def forward(self, x):
        """Return the style code, shape (B, style_dim, 1, 1).

        ``x`` is a tuple ``(image, label_fea1, label_fea2)``. ``label_fea1``
        must broadcast against the output of ``self.model``, and
        ``label_fea2`` against the output of ``self.model_middle``.
        """
        fea = self.model(x[0])
        fea = self.sft1((fea, x[1]))
        fea = self.model_middle(fea)
        fea = self.sft2((fea, x[2]))
        return self.model_last(fea)
# label encode part
class LabelEncoder(nn.Module):
    """Extract two feature maps from a label map, used to condition SFT layers.

    ``self.model`` applies a 7x7 ConvBlock, a stride-2 ConvBlock, and a
    stride-2 ConvBlock without activation; the channels go dim, 2*dim, 4*dim.
    Its output is the first feature.

    ``self.model_last`` applies ReLU, then ``n_downsample - 3`` stride-2
    ConvBlocks with activation, then one stride-2 ConvBlock without
    activation, with channels unchanged. Its output is the second feature.

    ``style_dim`` is accepted but not used.
    """
    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        super(LabelEncoder, self).__init__()
        self.model = []
        # Non-inplace ReLU, so the first returned feature is not modified.
        self.model_last = [nn.ReLU()]
        self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
        dim *= 2
        self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation='none', pad_type=pad_type)]
        dim *= 2
        # n_downsample - 2 stride-2 blocks in total here, the same count as
        # StyleEncoder.model_middle, so the two SFT inputs stay spatially aligned.
        for i in range(n_downsample - 3):
            self.model_last += [ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
        self.model_last += [ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation='none', pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)
        self.model_last = nn.Sequential(*self.model_last)
        self.output_dim = dim

    def forward(self, x):
        """Return ``(fea, deeper_fea)``: the output of ``self.model`` and of ``self.model_last`` applied to it."""
        fea = self.model(x)
        return fea, self.model_last(fea)
# Define the basic block
class ConvBlock(nn.Module):
    """Explicit padding, then Conv2d, then optional normalisation and activation.

    Args:
        input_dim, output_dim: channel counts.
        kernel_size, stride: forwarded to ``nn.Conv2d`` (which adds no padding itself).
        padding: amount of explicit padding applied before the conv.
        norm: 'bn', 'in', 'ln', 'adain', 'sn' (spectral norm on the conv and
            no norm layer) or 'none'.
        activation: 'relu', 'lrelu', 'prelu', 'selu', 'tanh' or 'none'.
        pad_type: 'reflect', 'replicate' or 'zero'.
    """
    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(ConvBlock, self).__init__()
        self.use_bias = True

        pad_classes = {
            'reflect': nn.ReflectionPad2d,
            'replicate': nn.ReplicationPad2d,
            'zero': nn.ZeroPad2d,
        }
        if pad_type in pad_classes:
            self.pad = pad_classes[pad_type](padding)
        else:
            assert 0, "Unsupported padding type: {}".format(pad_type)

        norm_dim = output_dim
        norm_makers = {
            'bn': lambda: nn.BatchNorm2d(norm_dim),
            'in': lambda: nn.InstanceNorm2d(norm_dim),
            'ln': lambda: LayerNorm(norm_dim),
            'adain': lambda: AdaptiveInstanceNorm2d(norm_dim),
            'none': lambda: None,
            'sn': lambda: None,
        }
        if norm in norm_makers:
            self.norm = norm_makers[norm]()
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        activation_makers = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': lambda: nn.PReLU(),
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': lambda: nn.Tanh(),
            'none': lambda: None,
        }
        if activation in activation_makers:
            self.activation = activation_makers[activation]()
        else:
            assert 0, "Unsupported activation: {}".format(activation)

        # The conv is created last, after any norm layer, as in the original ordering.
        conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
        self.conv = SpectralNorm(conv) if norm == 'sn' else conv

    def forward(self, x):
        out = self.conv(self.pad(x))
        for stage in (self.norm, self.activation):
            if stage:
                out = stage(out)
        return out
class LinearBlock(nn.Module):
    """Linear layer followed by optional normalisation and activation.

    Args:
        input_dim, output_dim: feature sizes.
        norm: 'bn', 'in', 'ln', 'sn' (spectral norm on the linear layer and
            no norm layer) or 'none'.
        activation: 'relu', 'lrelu', 'prelu', 'selu', 'tanh' or 'none'.
    """
    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True
        linear = nn.Linear(input_dim, output_dim, bias=use_bias)
        self.fc = SpectralNorm(linear) if norm == 'sn' else linear

        norm_dim = output_dim
        norm_makers = {
            'bn': lambda: nn.BatchNorm1d(norm_dim),
            'in': lambda: nn.InstanceNorm1d(norm_dim),
            'ln': lambda: LayerNorm(norm_dim),
            'none': lambda: None,
            'sn': lambda: None,
        }
        if norm in norm_makers:
            self.norm = norm_makers[norm]()
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        activation_makers = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': lambda: nn.PReLU(),
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': lambda: nn.Tanh(),
            'none': lambda: None,
        }
        if activation in activation_makers:
            self.activation = activation_makers[activation]()
        else:
            assert 0, "Unsupported activation: {}".format(activation)

    def forward(self, x):
        result = self.fc(x)
        for stage in (self.norm, self.activation):
            if stage:
                result = stage(result)
        return result
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block computing ``x + body(x)``.

    ``body`` is two 3x3 stride-1 ConvBlocks with padding 1: the first with
    ReLU, the second without activation. ``use_dropout`` is accepted but has
    no effect.
    """
    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, norm_type, padding_type, use_dropout)

    def build_conv_block(self, dim, norm_type, padding_type, use_dropout):
        """Return the two-ConvBlock body as an ``nn.Sequential``."""
        layers = [ConvBlock(dim, dim, 3, 1, 1, norm=norm_type, activation=act, pad_type=padding_type)
                  for act in ('relu', 'none')]
        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)
class SFTLayer(nn.Module):
    """Spatial feature transform: ``x[0] * scale(x[1]) + shift(x[1])``.

    ``scale`` and ``shift`` are each two 1x1 convs (64 -> 64 channels) with
    LeakyReLU(0.1) in between. The condition ``x[1]`` must therefore have 64
    channels, and ``x[0]`` must broadcast against the 64-channel result.
    """
    def __init__(self):
        super(SFTLayer, self).__init__()
        self.SFT_scale_conv1 = nn.Conv2d(64, 64, 1)
        self.SFT_scale_conv2 = nn.Conv2d(64, 64, 1)
        self.SFT_shift_conv1 = nn.Conv2d(64, 64, 1)
        self.SFT_shift_conv2 = nn.Conv2d(64, 64, 1)

    def forward(self, x):
        feat, cond = x[0], x[1]

        def branch(first, second):
            # The in-place LeakyReLU acts on first(cond)'s fresh output, never on cond itself.
            return second(F.leaky_relu(first(cond), 0.1, inplace=True))

        scale = branch(self.SFT_scale_conv1, self.SFT_scale_conv2)
        shift = branch(self.SFT_shift_conv1, self.SFT_shift_conv2)
        return feat * scale + shift
class ConvBlock_SFT(nn.Module):
    """SFT modulation, then a 4x4 stride-2 ConvBlock and ReLU, with a skip connection.

    Takes and returns a ``(features, condition)`` pair, so blocks can be chained.

    NOTE(review): ``conv1`` halves the spatial size (kernel 4, stride 2,
    padding 1), while the skip adds ``x[0]`` at full resolution. The shapes
    of ``x[0]`` and ``fea`` therefore generally differ, and ``x[0] + fea``
    will fail to broadcast — confirm the intended stride. The class is not
    referenced in the visible part of this file.
    """
    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        # super() must name this class. The previously referenced
        # ResnetBlock_SFT is undefined, so construction raised NameError.
        super(ConvBlock_SFT, self).__init__()
        self.sft1 = SFTLayer()
        self.conv1 = ConvBlock(dim, dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type)

    def forward(self, x):
        """Return ``(x[0] + relu(conv1(sft1(x))), x[1])``."""
        fea = self.sft1((x[0], x[1]))
        fea = F.relu(self.conv1(fea), inplace=True)
        return (x[0] + fea, x[1])
class ConvBlock_SFT_last(nn.Module):
    """SFT modulation, then a 4x4 stride-2 ConvBlock and ReLU, with a skip connection.

    Returns only the features, not the condition.

    NOTE(review): ``conv1`` halves the spatial size (kernel 4, stride 2,
    padding 1), while the skip adds ``x[0]`` at full resolution. The shapes
    of ``x[0]`` and ``fea`` therefore generally differ — confirm the
    intended stride. The class is not referenced in the visible part of this
    file.
    """
    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        # super() must name this class. The previously referenced
        # ResnetBlock_SFT_last is undefined, so construction raised NameError.
        super(ConvBlock_SFT_last, self).__init__()
        self.sft1 = SFTLayer()
        self.conv1 = ConvBlock(dim, dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type)

    def forward(self, x):
        """Return ``x[0] + relu(conv1(sft1(x)))``."""
        fea = self.sft1((x[0], x[1]))
        fea = F.relu(self.conv1(fea), inplace=True)
        return x[0] + fea
# Definition of normalization layer
class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalisation whose affine parameters are supplied from outside.

    ``weight`` and ``bias`` start as ``None``. They must be assigned (as
    vectors of length ``batch * num_features``) before ``forward`` is called.

    The running-stat buffers are placeholders. Each call works on repeated
    copies of them, so the registered buffers themselves are never updated.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned
        self.weight = None
        self.bias = None
        # just dummy buffers, not used
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        b, c, *spatial = x.size()
        # Folding the batch into the channel axis makes batch_norm's
        # per-channel statistics per-(sample, channel), i.e. instance norm.
        stacked = x.contiguous().view(1, b * c, *spatial)
        normed = F.batch_norm(
            stacked, self.running_mean.repeat(b), self.running_var.repeat(b),
            self.weight, self.bias, True, self.momentum, self.eps)
        return normed.view(b, c, *spatial)

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'
class LayerNorm(nn.Module):
    """Normalise each sample over all of its non-batch elements.

    Uses the unbiased standard deviation and adds ``eps`` to the std, not to
    the variance. With ``affine``, a per-channel scale ``gamma``
    (initialised U(0, 1)) and shift ``beta`` (zeros) are applied along dim 1.
    """
    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        stat_shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # Single sample: reduce over the whole flattened tensor (noted
            # upstream as faster in PyTorch 0.4 than the per-row path).
            flat = x.view(-1)
            mean, std = flat.mean(), flat.std()
        else:
            flat = x.view(x.size(0), -1)
            mean, std = flat.mean(1), flat.std(1)
        x = (x - mean.view(*stat_shape)) / (std.view(*stat_shape) + self.eps)
        if not self.affine:
            return x
        param_shape = [1, -1] + [1] * (x.dim() - 2)
        return x * self.gamma.view(*param_shape) + self.beta.view(*param_shape)
def l2normalize(v, eps=1e-12):
    """Return ``v`` divided by its L2 norm; ``eps`` keeps a zero vector finite (it stays zero)."""
    denom = v.norm() + eps
    return v / denom
class SpectralNorm(nn.Module):
    """
    Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
    and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

    Wraps ``module`` so that its parameter ``name`` is divided by an
    estimate of its largest singular value before every forward call.

    At construction, the parameter ``<name>`` is removed from ``module`` and
    three parameters are registered in its place:
      - ``<name>_bar``: the trainable raw weight;
      - ``<name>_u`` and ``<name>_v``: power-iteration vectors with
        ``requires_grad=False``.

    Each forward runs ``power_iterations`` power-iteration steps on the
    weight reshaped to ``(out_dim, -1)``. It then sets ``module.<name>`` to
    ``<name>_bar / sigma``, where ``sigma = u . (W v)``.
    """
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        u, v, w = (getattr(self.module, self.name + suffix) for suffix in ("_u", "_v", "_bar"))
        height = w.data.shape[0]
        w_mat = w.view(height, -1)
        for _ in range(self.power_iterations):
            # Power iteration on detached data; only the final sigma is differentiated.
            v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = l2normalize(torch.mv(w_mat.data, v.data))
        sigma = u.dot(w_mat.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        return all(hasattr(self.module, self.name + suffix) for suffix in ("_u", "_v", "_bar"))

    def _make_params(self):
        w = getattr(self.module, self.name)
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]
        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)
        del self.module._parameters[self.name]
        for suffix, param in (("_u", u), ("_v", v), ("_bar", w_bar)):
            self.module.register_parameter(self.name + suffix, param)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)