# Mirror of https://github.com/kritiksoman/GIMP-ML
import torch
import os
import math
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
# from torch.utils.serialization import load_lua
from lib.nn import SynchronizedBatchNorm2d as SynBN2d

###############################################################################
# Functions
###############################################################################

def pad_tensor(input):
    height_org, width_org = input.shape[2], input.shape[3]
    divide = 16

    if width_org % divide != 0 or height_org % divide != 0:
        width_res = width_org % divide
        height_res = height_org % divide
        if width_res != 0:
            width_div = divide - width_res
            pad_left = int(width_div / 2)
            pad_right = int(width_div - pad_left)
        else:
            pad_left = 0
            pad_right = 0

        if height_res != 0:
            height_div = divide - height_res
            pad_top = int(height_div / 2)
            pad_bottom = int(height_div - pad_top)
        else:
            pad_top = 0
            pad_bottom = 0

        padding = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom))
        input = padding(input)
    else:
        pad_left = 0
        pad_right = 0
        pad_top = 0
        pad_bottom = 0

    height, width = input.data.shape[2], input.data.shape[3]
    assert width % divide == 0, 'width must be divisible by the stride'
    assert height % divide == 0, 'height must be divisible by the stride'

    return input, pad_left, pad_right, pad_top, pad_bottom

def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom):
    height, width = input.shape[2], input.shape[3]
    return input[:, :, pad_top: height - pad_bottom, pad_left: width - pad_right]

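# Usage sketch (illustration, not part of the original module): pad to the
# next multiple of 16 with reflection padding, then crop back to the
# original size. The input shape below is an arbitrary example.
def _pad_roundtrip_example():
    x = torch.randn(1, 3, 100, 150)                 # H, W not divisible by 16
    padded, l, r, t, b = pad_tensor(x)              # -> 1 x 3 x 112 x 160
    restored = pad_tensor_back(padded, l, r, t, b)  # -> 1 x 3 x 100 x 150
    assert restored.shape == x.shape
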
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    elif norm_type == 'synBN':
        norm_layer = functools.partial(SynBN2d, affine=True)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer

def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[], skip=False, opt=None):
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    # if use_gpu:
    #     assert(torch.cuda.is_available())

    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'unet_512':
        netG = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'sid_unet':
        # Note: Unet (and Unet_pixelshuffle below) are not defined in this
        # file; these branches rely on definitions elsewhere in the project.
        netG = Unet(opt, skip)
    elif which_model_netG == 'sid_unet_shuffle':
        netG = Unet_pixelshuffle(opt, skip)
    elif which_model_netG == 'sid_unet_resize':
        netG = Unet_resize_conv(opt, skip)
    elif which_model_netG == 'DnCNN':
        netG = DnCNN(opt, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if torch.cuda.is_available() and not opt.cFlag:
        netG.cuda(device=gpu_ids[0])
        # netG = torch.nn.DataParallel(netG, gpu_ids)
    netG.apply(weights_init)
    return netG

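# Usage sketch (illustration): construct the resize-conv U-Net generator
# defined below. opt is the options namespace used throughout this file;
# with opt.cFlag set (CPU mode) no GPU is required.
def _define_g_example(opt):
    return define_G(3, 3, 64, 'sid_unet_resize', norm='synBN',
                    gpu_ids=[], skip=True, opt=opt)
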
def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[], patch=False):
    netD = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'no_norm':
        netD = NoNormDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'no_norm_4':
        netD = NoNormDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'no_patchgan':
        netD = FCDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, patch=patch)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)
    if use_gpu:
        netD.cuda(device=gpu_ids[0])
        netD = torch.nn.DataParallel(netD, gpu_ids)
    netD.apply(weights_init)
    return netD

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

##############################################################################
# Classes
##############################################################################

# Defines the GAN loss, which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create a target label tensor
# of the same size as the input.
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

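# Usage sketch (illustration, not part of the original training code): with
# use_lsgan=True the criterion is an MSE against an automatically created
# real/fake label tensor of the same size as the prediction.
def _gan_loss_example():
    criterion = GANLoss(use_lsgan=True)
    pred_fake = torch.rand(4, 1, 30, 30)    # e.g. a PatchGAN score map
    loss_g = criterion(pred_fake, True)     # generator wants "real"
    loss_d = criterion(pred_fake, False)    # discriminator wants "fake"
    return loss_g, loss_d
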
class DiscLossWGANGP():
    def __init__(self):
        self.LAMBDA = 10

    def name(self):
        return 'DiscLossWGAN-GP'

    def initialize(self, opt, tensor):
        # DiscLossLS.initialize(self, opt, tensor)
        self.LAMBDA = 10

    # def get_g_loss(self, net, realA, fakeB):
    #     # First, G(A) should fake the discriminator
    #     self.D_fake = net.forward(fakeB)
    #     return -self.D_fake.mean()

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        alpha = torch.rand(1, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.cuda()

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)

        interpolates = interpolates.cuda()
        interpolates = Variable(interpolates, requires_grad=True)

        disc_interpolates = netD.forward(interpolates)

        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                        grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                                        create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
        return gradient_penalty

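# Usage sketch (illustration only): the gradient penalty is added to the
# critic loss during discriminator updates. netD, real and fake are assumed
# to be supplied by the caller; calc_gradient_penalty as written requires CUDA.
def _wgan_gp_example(netD, real, fake):
    d_loss = netD(fake).mean() - netD(real).mean()
    gp = DiscLossWGANGP().calc_gradient_penalty(netD, real.data, fake.data)
    return d_loss + gp
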
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout)]

        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)

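# Construction sketch (illustration): the 9-block variant selected by
# which_model_netG='resnet_9blocks'; spatial size is preserved end to end.
def _resnet_generator_example():
    netG = ResnetGenerator(3, 3, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=9)
    y = netG(torch.randn(1, 3, 256, 256))   # output in [-1, 1] via Tanh
    return y.shape                          # torch.Size([1, 3, 256, 256])
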
# Define a Resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out

# Defines the Unet generator.
# |num_downs|: number of downsamplings in the Unet. For example,
# if |num_downs| == 7, a 128x128 input image becomes 1x1
# at the bottleneck.
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], skip=False, opt=None):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids
        self.opt = opt
        # currently support only input_nc == output_nc
        assert(input_nc == output_nc)

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True, opt=opt)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout, opt=opt)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer, opt=opt)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer, opt=opt)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer, opt=opt)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer, opt=opt)

        if skip:
            skipmodule = SkipModule(unet_block, opt)
            self.model = skipmodule
        else:
            self.model = unet_block

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)

class SkipModule(nn.Module):
    def __init__(self, submodule, opt):
        super(SkipModule, self).__init__()
        self.submodule = submodule
        self.opt = opt

    def forward(self, x):
        latent = self.submodule(x)
        return self.opt.skip * x + latent, latent

# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, opt=None):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if opt.use_norm == 0:
            if outermost:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downconv]
                up = [uprelu, upconv, nn.Tanh()]
                model = down + [submodule] + up
            elif innermost:
                upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downrelu, downconv]
                up = [uprelu, upconv]
                model = down + up
            else:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downrelu, downconv]
                up = [uprelu, upconv]

                if use_dropout:
                    model = down + [submodule] + up + [nn.Dropout(0.5)]
                else:
                    model = down + [submodule] + up
        else:
            if outermost:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downconv]
                up = [uprelu, upconv, nn.Tanh()]
                model = down + [submodule] + up
            elif innermost:
                upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downrelu, downconv]
                up = [uprelu, upconv, upnorm]
                model = down + up
            else:
                upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                            kernel_size=4, stride=2,
                                            padding=1)
                down = [downrelu, downconv, downnorm]
                up = [uprelu, upconv, upnorm]

                if use_dropout:
                    model = down + [submodule] + up + [nn.Dropout(0.5)]
                else:
                    model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([self.model(x), x], 1)

# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = int(np.ceil((kw-1)/2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
        #     return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        # else:
        return self.model(input)

class NoNormDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
        super(NoNormDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = int(np.ceil((kw-1)/2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
        #     return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        # else:
        return self.model(input)

class FCDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[], patch=False):
        super(FCDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        self.use_sigmoid = use_sigmoid
        kw = 4
        padw = int(np.ceil((kw-1)/2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        if patch:
            self.linear = nn.Linear(7*7, 1)
        else:
            self.linear = nn.Linear(13*13, 1)
        if use_sigmoid:
            self.sigmoid = nn.Sigmoid()

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        batchsize = input.size()[0]
        output = self.model(input)
        output = output.view(batchsize, -1)
        # print(output.size())
        output = self.linear(output)
        if self.use_sigmoid:
            print("sigmoid")
            output = self.sigmoid(output)
        return output

class Unet_resize_conv(nn.Module):
    def __init__(self, opt, skip):
        super(Unet_resize_conv, self).__init__()

        self.opt = opt
        self.skip = skip
        p = 1
        # self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
        if opt.self_attention:
            self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
            # self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
            self.downsample_1 = nn.MaxPool2d(2)
            self.downsample_2 = nn.MaxPool2d(2)
            self.downsample_3 = nn.MaxPool2d(2)
            self.downsample_4 = nn.MaxPool2d(2)
        else:
            self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
        self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_2 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.max_pool1 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p)
        self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.max_pool2 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p)
        self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.max_pool3 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv4_1 = nn.Conv2d(128, 256, 3, padding=p)
        self.LReLU4_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv4_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.max_pool4 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p)
        self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_1 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p)
        self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_2 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)

        # self.deconv5 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.deconv5 = nn.Conv2d(512, 256, 3, padding=p)
        self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU6_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)

        # self.deconv6 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.deconv6 = nn.Conv2d(256, 128, 3, padding=p)
        self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p)
        self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)

        # self.deconv7 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.deconv7 = nn.Conv2d(128, 64, 3, padding=p)
        self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p)
        self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)

        # self.deconv8 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.deconv8 = nn.Conv2d(64, 32, 3, padding=p)
        self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p)
        self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn9_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True)

        self.conv10 = nn.Conv2d(32, 3, 1)
        if self.opt.tanh:
            self.tanh = nn.Tanh()
    def depth_to_space(self, input, block_size):
        block_size_sq = block_size * block_size
        output = input.permute(0, 2, 3, 1)
        (batch_size, d_height, d_width, d_depth) = output.size()
        s_depth = int(d_depth / block_size_sq)
        s_width = int(d_width * block_size)
        s_height = int(d_height * block_size)
        t_1 = output.resize(batch_size, d_height, d_width, block_size_sq, s_depth)
        spl = t_1.split(block_size, 3)
        stack = [t_t.resize(batch_size, d_height, s_width, s_depth) for t_t in spl]
        output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).resize(batch_size, s_height, s_width, s_depth)
        output = output.permute(0, 3, 1, 2)
        return output
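    # Note (illustration): depth_to_space rearranges a (B, C*r*r, H, W)
    # tensor into (B, C, H*r, W*r); torch.nn.PixelShuffle performs a closely
    # related rearrangement (the exact channel ordering may differ). The
    # method is currently unused: the call in forward is commented out.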
    def forward(self, input, gray):
        flag = 0
        if input.size()[3] > 2200:
            avg = nn.AvgPool2d(2)
            input = avg(input)
            gray = avg(gray)
            flag = 1
            # pass
        input, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(input)
        gray, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(gray)
        if self.opt.self_attention:
            gray_2 = self.downsample_1(gray)
            gray_3 = self.downsample_2(gray_2)
            gray_4 = self.downsample_3(gray_3)
            gray_5 = self.downsample_4(gray_4)
        if self.opt.use_norm == 1:
            if self.opt.self_attention:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1))))
                # x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            else:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x)))
            x = self.max_pool1(conv1)

            x = self.bn2_1(self.LReLU2_1(self.conv2_1(x)))
            conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x)))
            x = self.max_pool2(conv2)

            x = self.bn3_1(self.LReLU3_1(self.conv3_1(x)))
            conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x)))
            x = self.max_pool3(conv3)

            x = self.bn4_1(self.LReLU4_1(self.conv4_1(x)))
            conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))
            x = self.max_pool4(conv4)

            x = self.bn5_1(self.LReLU5_1(self.conv5_1(x)))
            x = x * gray_5 if self.opt.self_attention else x
            conv5 = self.bn5_2(self.LReLU5_2(self.conv5_2(x)))

            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4 * gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6)))
            conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x)))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3 * gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7)))
            conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x)))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2 * gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8)))
            conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x)))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1 * gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9)))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent * gray

            # output = self.depth_to_space(conv10, 2)
            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    input = (input - torch.min(input)) / (torch.max(input) - torch.min(input))
                    output = latent + input * self.opt.skip
                    output = output * 2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    output = latent + input * self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                output = output / torch.max(torch.abs(output))
        elif self.opt.use_norm == 0:
            if self.opt.self_attention:
                x = self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1)))
            else:
                x = self.LReLU1_1(self.conv1_1(input))
            conv1 = self.LReLU1_2(self.conv1_2(x))
            x = self.max_pool1(conv1)

            x = self.LReLU2_1(self.conv2_1(x))
            conv2 = self.LReLU2_2(self.conv2_2(x))
            x = self.max_pool2(conv2)

            x = self.LReLU3_1(self.conv3_1(x))
            conv3 = self.LReLU3_2(self.conv3_2(x))
            x = self.max_pool3(conv3)

            x = self.LReLU4_1(self.conv4_1(x))
            conv4 = self.LReLU4_2(self.conv4_2(x))
            x = self.max_pool4(conv4)

            x = self.LReLU5_1(self.conv5_1(x))
            x = x * gray_5 if self.opt.self_attention else x
            conv5 = self.LReLU5_2(self.conv5_2(x))

            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4 * gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.LReLU6_1(self.conv6_1(up6))
            conv6 = self.LReLU6_2(self.conv6_2(x))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3 * gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.LReLU7_1(self.conv7_1(up7))
            conv7 = self.LReLU7_2(self.conv7_2(x))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2 * gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.LReLU8_1(self.conv8_1(up8))
            conv8 = self.LReLU8_2(self.conv8_2(x))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1 * gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.LReLU9_1(self.conv9_1(up9))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent * gray

            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    input = (input - torch.min(input)) / (torch.max(input) - torch.min(input))
                    output = latent + input * self.opt.skip
                    output = output * 2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent)) / (torch.max(latent) - torch.min(latent))
                    output = latent + input * self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                output = output / torch.max(torch.abs(output))

        output = pad_tensor_back(output, pad_left, pad_right, pad_top, pad_bottom)
        latent = pad_tensor_back(latent, pad_left, pad_right, pad_top, pad_bottom)
        gray = pad_tensor_back(gray, pad_left, pad_right, pad_top, pad_bottom)
        if flag == 1:
            output = F.upsample(output, scale_factor=2, mode='bilinear')
            gray = F.upsample(gray, scale_factor=2, mode='bilinear')
        if self.skip:
            return output, latent
        else:
            return output

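# Usage sketch (illustration, not part of the original module): the forward
# pass takes the RGB input plus a one-channel attention map ("gray"); with
# skip enabled it returns both the enhanced output and the latent residual.
# opt is assumed to be the options namespace used throughout this file.
def _unet_resize_conv_example(opt):
    net = Unet_resize_conv(opt, skip=1)
    x = torch.randn(1, 3, 256, 256)         # H, W already multiples of 16
    gray = torch.rand(1, 1, 256, 256)
    output, latent = net(x, gray)
    return output.shape                     # torch.Size([1, 3, 256, 256])
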
class DnCNN(nn.Module):
    def __init__(self, opt=None, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []

        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y + out
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

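# Usage sketch (illustration): DnCNN learns a residual, so forward returns
# input + network(input); a denoising call would feed a noisy one-channel
# image. The opt argument is accepted but unused by the constructor.
def _dncnn_example(opt):
    net = DnCNN(opt, depth=17, n_channels=64, image_channels=1)
    noisy = torch.randn(1, 1, 64, 64)
    return net(noisy).shape                 # torch.Size([1, 1, 64, 64])
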
class Vgg16(nn.Module):
    def __init__(self):
        super(Vgg16, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
    def forward(self, X, opt):
        h = F.relu(self.conv1_1(X), inplace=True)
        h = F.relu(self.conv1_2(h), inplace=True)
        # relu1_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv2_1(h), inplace=True)
        h = F.relu(self.conv2_2(h), inplace=True)
        # relu2_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv3_1(h), inplace=True)
        h = F.relu(self.conv3_2(h), inplace=True)
        h = F.relu(self.conv3_3(h), inplace=True)
        # relu3_3 = h
        if opt.vgg_choose != "no_maxpool":
            h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv4_1(h), inplace=True)
        relu4_1 = h
        h = F.relu(self.conv4_2(h), inplace=True)
        relu4_2 = h
        conv4_3 = self.conv4_3(h)
        h = F.relu(conv4_3, inplace=True)
        relu4_3 = h

        if opt.vgg_choose != "no_maxpool":
            if opt.vgg_maxpooling:
                h = F.max_pool2d(h, kernel_size=2, stride=2)

        relu5_1 = F.relu(self.conv5_1(h), inplace=True)
        relu5_2 = F.relu(self.conv5_2(relu5_1), inplace=True)
        conv5_3 = self.conv5_3(relu5_2)
        h = F.relu(conv5_3, inplace=True)
        relu5_3 = h
        if opt.vgg_choose == "conv4_3":
            return conv4_3
        elif opt.vgg_choose == "relu4_2":
            return relu4_2
        elif opt.vgg_choose == "relu4_1":
            return relu4_1
        elif opt.vgg_choose == "relu4_3":
            return relu4_3
        elif opt.vgg_choose == "conv5_3":
            return conv5_3
        elif opt.vgg_choose == "relu5_1":
            return relu5_1
        elif opt.vgg_choose == "relu5_2":
            return relu5_2
        elif opt.vgg_choose == "relu5_3" or opt.vgg_choose == "maxpool":
            return relu5_3

def vgg_preprocess(batch, opt):
    tensortype = type(batch.data)
    (r, g, b) = torch.chunk(batch, 3, dim=1)
    batch = torch.cat((b, g, r), dim=1)  # convert RGB to BGR
    batch = (batch + 1) * 255 * 0.5  # [-1, 1] -> [0, 255]
    if opt.vgg_mean:
        mean = tensortype(batch.data.size())
        mean[:, 0, :, :] = 103.939
        mean[:, 1, :, :] = 116.779
        mean[:, 2, :, :] = 123.680
        batch = batch.sub(Variable(mean))  # subtract mean
    return batch

class PerceptualLoss(nn.Module):
    def __init__(self, opt):
        super(PerceptualLoss, self).__init__()
        self.opt = opt
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

    def compute_vgg_loss(self, vgg, img, target):
        img_vgg = vgg_preprocess(img, self.opt)
        target_vgg = vgg_preprocess(target, self.opt)
        img_fea = vgg(img_vgg, self.opt)
        target_fea = vgg(target_vgg, self.opt)
        if self.opt.no_vgg_instance:
            return torch.mean((img_fea - target_fea) ** 2)
        else:
            return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

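# Usage sketch (illustration): vgg is the wrapped Vgg16 returned by
# load_vgg16 below; img and target are images in [-1, 1].
def _perceptual_loss_example(opt, vgg, img, target):
    criterion = PerceptualLoss(opt)
    return criterion.compute_vgg_loss(vgg, img, target)
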
def load_vgg16(model_dir, gpu_ids):
    """ Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    # if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')):
    #     if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')):
    #         os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7'))
    #     vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7'))
    #     vgg = Vgg16()
    #     for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
    #         dst.data[:] = src
    #     torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight'))
    vgg = Vgg16()
    # vgg.cuda()
    vgg.cuda(device=gpu_ids[0])
    vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight')))
    vgg = torch.nn.DataParallel(vgg, gpu_ids)
    return vgg

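# FCN32s._initialize_weights below calls get_upsampling_weight, which is not
# defined in this file. The helper below is the standard bilinear-kernel
# initializer from the common PyTorch FCN reference implementation, added
# here so the method is runnable (an assumption about the intended helper).
def get_upsampling_weight(in_channels, out_channels, kernel_size):
    # Build a 2D bilinear upsampling kernel.
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight).float()
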
class FCN32s(nn.Module):
    def __init__(self, n_class=21):
        super(FCN32s, self).__init__()
        # conv1
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/2

        # conv2
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/4

        # conv3
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/8

        # conv4
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/16

        # conv5
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/32

        # fc6
        self.fc6 = nn.Conv2d(512, 4096, 7)
        self.relu6 = nn.ReLU(inplace=True)
        self.drop6 = nn.Dropout2d()

        # fc7
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        self.relu7 = nn.ReLU(inplace=True)
        self.drop7 = nn.Dropout2d()

        self.score_fr = nn.Conv2d(4096, n_class, 1)
        self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32,
                                          bias=False)
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.zero_()
                if m.bias is not None:
                    m.bias.data.zero_()
            if isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                initial_weight = get_upsampling_weight(
                    m.in_channels, m.out_channels, m.kernel_size[0])
                m.weight.data.copy_(initial_weight)
    def forward(self, x):
        h = x
        h = self.relu1_1(self.conv1_1(h))
        h = self.relu1_2(self.conv1_2(h))
        h = self.pool1(h)

        h = self.relu2_1(self.conv2_1(h))
        h = self.relu2_2(self.conv2_2(h))
        h = self.pool2(h)

        h = self.relu3_1(self.conv3_1(h))
        h = self.relu3_2(self.conv3_2(h))
        h = self.relu3_3(self.conv3_3(h))
        h = self.pool3(h)

        h = self.relu4_1(self.conv4_1(h))
        h = self.relu4_2(self.conv4_2(h))
        h = self.relu4_3(self.conv4_3(h))
        h = self.pool4(h)

        h = self.relu5_1(self.conv5_1(h))
        h = self.relu5_2(self.conv5_2(h))
        h = self.relu5_3(self.conv5_3(h))
        h = self.pool5(h)

        h = self.relu6(self.fc6(h))
        h = self.drop6(h)

        h = self.relu7(self.fc7(h))
        h = self.drop7(h)

        h = self.score_fr(h)

        h = self.upscore(h)
        h = h[:, :, 19:19 + x.size()[2], 19:19 + x.size()[3]].contiguous()
        return h

def load_fcn(model_dir):
    fcn = FCN32s()
    fcn.load_state_dict(torch.load(os.path.join(model_dir, 'fcn32s_from_caffe.pth')))
    fcn.cuda()
    return fcn

class SemanticLoss(nn.Module):
    def __init__(self, opt):
        super(SemanticLoss, self).__init__()
        self.opt = opt
        self.instancenorm = nn.InstanceNorm2d(21, affine=False)

    def compute_fcn_loss(self, fcn, img, target):
        img_fcn = vgg_preprocess(img, self.opt)
        target_fcn = vgg_preprocess(target, self.opt)
        img_fea = fcn(img_fcn)
        target_fea = fcn(target_fcn)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)