You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1182 lines
47 KiB
Python

import torch
import os
import math
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
# from torch.utils.serialization import load_lua
from lib.nn import SynchronizedBatchNorm2d as SynBN2d
###############################################################################
# Functions
###############################################################################
def pad_tensor(input, divide=16):
    """Reflection-pad an NCHW tensor so its height and width are multiples of ``divide``.

    The extra rows/columns needed on each axis are split as evenly as possible;
    when the amount is odd the extra pixel goes to the bottom/right side.

    Args:
        input: 4-D tensor indexed as (N, C, H, W).
        divide: required divisor for H and W. Defaults to 16, matching the four
            2x poolings of the U-Net generators in this module.

    Returns:
        ``(padded, pad_left, pad_right, pad_top, pad_bottom)``. The four pad
        amounts can be handed to ``pad_tensor_back`` to crop the result back.
        If no padding is needed, ``input`` itself is returned unchanged.
    """
    height_org, width_org = input.shape[2], input.shape[3]

    def _split(size):
        # Amount missing up to the next multiple of `divide`, split before/after.
        remainder = size % divide
        if remainder == 0:
            return 0, 0
        total = divide - remainder
        before = total // 2
        return before, total - before

    pad_left, pad_right = _split(width_org)
    pad_top, pad_bottom = _split(height_org)
    if pad_left or pad_right or pad_top or pad_bottom:
        # ReflectionPad2d takes (left, right, top, bottom).
        input = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom))(input)

    height, width = input.shape[2], input.shape[3]
    assert width % divide == 0, 'width cant divided by stride'
    assert height % divide == 0, 'height cant divided by stride'
    return input, pad_left, pad_right, pad_top, pad_bottom
def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom):
    """Undo ``pad_tensor``: crop the given margins off the last two axes of an NCHW tensor."""
    rows = slice(pad_top, input.shape[2] - pad_bottom)
    cols = slice(pad_left, input.shape[3] - pad_right)
    return input[:, :, rows, cols]
def weights_init(m):
    """Re-initialise a single module in place; meant for ``net.apply(weights_init)``.

    Any module whose class name contains 'Conv' (including ConvTranspose) gets
    weights drawn from N(0, 0.02); conv biases are left untouched. Any module
    whose class name contains 'BatchNorm2d' gets weights from N(1, 0.02) and
    zero bias. Everything else is ignored.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
        return
    if 'BatchNorm2d' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
    """Return a constructor ``f(num_features)`` for the requested normalisation layer.

    Args:
        norm_type: 'batch' (affine BatchNorm2d), 'instance' (non-affine
            InstanceNorm2d) or 'synBN' (affine synchronized BatchNorm).

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    elif norm_type == 'synBN':
        norm_layer = functools.partial(SynBN2d, affine=True)
    else:
        # Previously formatted the undefined name `norm`, which turned this
        # into a NameError instead of the intended NotImplementedError.
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=None, skip=False, opt=None):
    """Build the generator named by ``which_model_netG``, place it and initialise it.

    Args:
        input_nc, output_nc: channel counts for the ResNet / U-Net generators.
        ngf: base number of generator filters.
        which_model_netG: one of 'resnet_9blocks', 'resnet_6blocks', 'unet_128',
            'unet_256', 'unet_512', 'sid_unet', 'sid_unet_shuffle',
            'sid_unet_resize', 'DnCNN'.
        norm: normalisation type forwarded to ``get_norm_layer``.
        use_dropout: forwarded to the ResNet / UnetGenerator builders.
        gpu_ids: GPU indices (default: none). The network is moved to
            ``gpu_ids[0]`` when CUDA is available, ``gpu_ids`` is non-empty and
            ``opt.cFlag`` is not set.
        skip: forwarded to the U-Net style generators.
        opt: options object read by the U-Net style generators.

    Returns:
        The generator with ``weights_init`` applied.

    Raises:
        NotImplementedError: for an unknown ``which_model_netG``.
    """
    if gpu_ids is None:
        gpu_ids = []
    norm_layer = get_norm_layer(norm_type=norm)
    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        # Pass skip/opt like the other U-Net sizes: UnetSkipConnectionBlock reads
        # opt.use_norm, so building this variant with opt=None always failed.
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'unet_512':
        netG = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt)
    elif which_model_netG == 'sid_unet':
        netG = Unet(opt, skip)
    elif which_model_netG == 'sid_unet_shuffle':
        netG = Unet_pixelshuffle(opt, skip)
    elif which_model_netG == 'sid_unet_resize':
        netG = Unet_resize_conv(opt, skip)
    elif which_model_netG == 'DnCNN':
        netG = DnCNN(opt, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)

    # Only touch the GPU when one was requested. The previous check read
    # opt.cFlag unconditionally (AttributeError for the default opt=None) and
    # indexed gpu_ids[0] even when gpu_ids was empty (IndexError).
    force_cpu = opt is not None and getattr(opt, 'cFlag', False)
    if torch.cuda.is_available() and gpu_ids and not force_cpu:
        netG.cuda(device=gpu_ids[0])
    # NOTE(review): for 'DnCNN' this overwrites the orthogonal conv init that
    # DnCNN._initialize_weights already applied (class names contain 'Conv' /
    # 'BatchNorm2d'). Confirm that is intended.
    netG.apply(weights_init)
    return netG
def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[], patch=False):
    """Build the discriminator named by ``which_model_netD`` and initialise it.

    'basic' is a 3-layer PatchGAN and 'n_layers' uses ``n_layers_D`` layers.
    'no_norm' and 'no_norm_4' both build the normalisation-free PatchGAN.
    'no_patchgan' adds a fully connected head (``patch`` selects its input size).
    With a non-empty ``gpu_ids`` the net is moved to ``gpu_ids[0]`` and
    wrapped in ``torch.nn.DataParallel``.

    Raises:
        NotImplementedError: for an unknown ``which_model_netD``.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert(torch.cuda.is_available())

    if which_model_netD in ('basic', 'n_layers'):
        depth = 3 if which_model_netD == 'basic' else n_layers_D
        net = NLayerDiscriminator(input_nc, ndf, depth, norm_layer=norm_layer,
                                  use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD in ('no_norm', 'no_norm_4'):
        net = NoNormDiscriminator(input_nc, ndf, n_layers_D,
                                  use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'no_patchgan':
        net = FCDiscriminator(input_nc, ndf, n_layers_D,
                              use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, patch=patch)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)

    if wants_gpu:
        net.cuda(device=gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    net.apply(weights_init)
    return net
def print_network(net):
    """Print the module and then its total number of parameter elements."""
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
##############################################################################
# Classes
##############################################################################
# Defines the GAN loss, which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is essentially the same as MSELoss,
# but it abstracts away the need to create a target label tensor
# of the same size as the input.
class GANLoss(nn.Module):
    """GAN criterion that creates (and caches) the constant target tensor for you.

    With ``use_lsgan=True`` the criterion is ``nn.MSELoss`` (least-squares GAN),
    otherwise ``nn.BCELoss``, which expects probabilities as input.

    Targets are built with the ``tensor`` constructor (``torch.FloatTensor`` by
    default) and filled with ``target_real_label`` / ``target_fake_label``. One
    target per label is cached and reused while the prediction's shape stays
    the same.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def _make_target(self, cached, input, value):
        """Return ``cached`` if it has ``input``'s shape, else a new tensor filled with ``value``."""
        # Compare full shapes rather than numel(): with the old check a (3, 2)
        # prediction reused a cached (2, 3) target and the loss call failed.
        if cached is not None and cached.size() == input.size():
            return cached
        return Variable(self.Tensor(input.size()).fill_(value), requires_grad=False)

    def get_target_tensor(self, input, target_is_real):
        """Return the real- or fake-label target matching ``input``'s shape."""
        if target_is_real:
            self.real_label_var = self._make_target(self.real_label_var, input, self.real_label)
            return self.real_label_var
        self.fake_label_var = self._make_target(self.fake_label_var, input, self.fake_label)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
class DiscLossWGANGP():
    """Gradient-penalty term for a WGAN-GP discriminator, weighted by ``LAMBDA`` (10)."""

    def __init__(self):
        self.LAMBDA = 10

    def name(self):
        return 'DiscLossWGAN-GP'

    def initialize(self, opt, tensor):
        # `opt` and `tensor` are accepted for interface compatibility; only
        # LAMBDA is (re)set here.
        self.LAMBDA = 10

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        """Return ``LAMBDA * mean((||grad D(x_hat)||_2 - 1) ** 2)``.

        ``x_hat = alpha * real_data + (1 - alpha) * fake_data``, where alpha is
        drawn uniformly in [0, 1). All tensors live on ``real_data``'s device
        instead of being forced onto the default CUDA device, so this also
        works on CPU. ``fake_data`` must be on the same device.

        NOTE(review): alpha is a single scalar for the whole batch, and the norm
        is taken over dim=1 only. The WGAN-GP paper draws one alpha per sample
        and takes the norm over all non-batch dims. Kept as-is to preserve
        training behaviour.
        """
        # Draw on CPU (same RNG stream as before), then move next to the data.
        alpha = torch.rand(1, 1)
        alpha = alpha.expand(real_data.size()).to(real_data.device)

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)
        interpolates = Variable(interpolates, requires_grad=True)

        disc_interpolates = netD.forward(interpolates)
        gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                        grad_outputs=torch.ones_like(disc_interpolates),
                                        create_graph=True, retain_graph=True, only_inputs=True)[0]
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    """Encoder / ResNet-blocks / decoder generator ending in Tanh.

    Layout: 7x7 conv stem -> two stride-2 convs (channels x2 each) ->
    ``n_blocks`` ResnetBlocks -> two stride-2 transposed convs (channels /2
    each) -> 7x7 conv to ``output_nc`` -> Tanh. The order of entries in
    ``self.model`` defines the state_dict keys and must stay fixed.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids

        steps = 2
        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        width = ngf
        for _ in range(steps):
            layers.extend([nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                           norm_layer(width * 2),
                           nn.ReLU(True)])
            width *= 2

        layers.extend(ResnetBlock(width, padding_type=padding_type, norm_layer=norm_layer,
                                  use_dropout=use_dropout)
                      for _ in range(n_blocks))

        for _ in range(steps):
            half = width // 2
            layers.extend([nn.ConvTranspose2d(width, half, kernel_size=3, stride=2,
                                              padding=1, output_padding=1),
                           norm_layer(half),
                           nn.ReLU(True)])
            width = half

        layers.extend([nn.ReflectionPad2d(3),
                       nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                       nn.Tanh()])
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        # Use data_parallel only when GPU ids were given and the input is a CUDA float tensor.
        use_parallel = self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor)
        if not use_parallel:
            return self.model(input)
        return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: ``x + [pad, conv3x3, norm, ReLU, (dropout), pad, conv3x3, norm](x)``."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)

    def _padding(self, padding_type):
        """Return (explicit padding layers, conv ``padding`` argument) for one 3x3 conv."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
        """Assemble the two conv stages; only the first gets ReLU and optional dropout."""
        layers = []
        for stage in (0, 1):
            pad_layers, conv_pad = self._padding(padding_type)
            layers += pad_layers
            layers += [nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad), norm_layer(dim)]
            if stage == 0:
                layers.append(nn.ReLU(True))
                if use_dropout:
                    layers.append(nn.Dropout(0.5))
        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)
# Defines the Unet generator.
# |num_downs|: number of downsamplings in the UNet. For example,
# if |num_downs| == 7, an image of size 128x128 becomes 1x1
# at the bottleneck.
class UnetGenerator(nn.Module):
    """U-Net built recursively from UnetSkipConnectionBlock, innermost level first.

    ``num_downs`` counts the stride-2 downsamplings; the ``num_downs - 5``
    extra inner levels all use ``ngf * 8`` channels. With ``skip == True`` the
    net is wrapped in SkipModule, which returns ``(opt.skip * x + y, y)``.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], skip=False, opt=None):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids
        self.opt = opt
        # Only input_nc == output_nc is supported.
        assert(input_nc == output_nc)

        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True, opt=opt)
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, block, norm_layer=norm_layer,
                                            use_dropout=use_dropout, opt=opt)
        # Wrap outwards, halving the channel count at each level.
        for outer, inner in ((ngf * 4, ngf * 8), (ngf * 2, ngf * 4), (ngf, ngf * 2)):
            block = UnetSkipConnectionBlock(outer, inner, block, norm_layer=norm_layer, opt=opt)
        block = UnetSkipConnectionBlock(output_nc, ngf, block, outermost=True, norm_layer=norm_layer, opt=opt)

        self.model = SkipModule(block, opt) if skip == True else block

    def forward(self, input):
        use_parallel = self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor)
        if not use_parallel:
            return self.model(input)
        return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
class SkipModule(nn.Module):
    """Add a scaled copy of the input to a submodule's output.

    ``forward`` returns ``(opt.skip * x + latent, latent)``, where
    ``latent = submodule(x)``.
    """

    def __init__(self, submodule, opt):
        super(SkipModule, self).__init__()
        self.submodule = submodule
        self.opt = opt

    def forward(self, x):
        residual = self.submodule(x)
        blended = self.opt.skip * x + residual
        return blended, residual
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample, run ``submodule``, upsample, concat with the input.

    Non-outermost levels return ``cat([model(x), x], dim=1)``; the outermost
    level returns ``model(x)`` (ending in Tanh). ``opt.use_norm == 0``
    disables the ``norm_layer`` instances; any other value enables them on
    the inner levels. The Sequential layer order defines the state_dict keys.
    """

    def __init__(self, outer_nc, inner_nc,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, opt=None):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost

        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
        normed = opt.use_norm != 0

        if outermost:
            # Input is the concatenated (inner_nc * 2) output of the level below.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # No submodule below, so the up-conv sees only inner_nc channels.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [downrelu, downconv, uprelu, upconv]
            if normed:
                layers.append(upnorm)
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [downrelu, downconv]
            if normed:
                layers.append(downnorm)
            layers += [submodule, uprelu, upconv]
            if normed:
                layers.append(upnorm)
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        if not self.outermost:
            return torch.cat([self.model(x), x], 1)
        return self.model(x)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: stacked 4x4 convs with normalisation and LeakyReLU(0.2).

    Channel multipliers go 1, 2, 4, 8 (capped at 8). Every conv-norm-act stage
    except the last uses stride 2. A final 4x4 conv produces a 1-channel patch
    map, optionally followed by Sigmoid.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = -(-(kw - 1) // 2)  # ceil((kw - 1) / 2), i.e. 2 for kw == 4
        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                  nn.LeakyReLU(0.2, True)]

        mults = [1] + [min(2 ** n, 8) for n in range(1, n_layers)] + [min(2 ** n_layers, 8)]
        last = len(mults) - 1
        for idx, (c_in, c_out) in enumerate(zip(mults, mults[1:]), start=1):
            stride = 1 if idx == last else 2
            layers += [nn.Conv2d(ndf * c_in, ndf * c_out, kernel_size=kw, stride=stride, padding=padw),
                       norm_layer(ndf * c_out),
                       nn.LeakyReLU(0.2, True)]

        layers.append(nn.Conv2d(ndf * mults[-1], 1, kernel_size=kw, stride=1, padding=padw))
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        # Runs the model directly; the data_parallel dispatch on self.gpu_ids is disabled.
        return self.model(input)
class NoNormDiscriminator(nn.Module):
    """PatchGAN discriminator without normalisation layers.

    Stacked 4x4 convs with LeakyReLU(0.2). Channel multipliers go 1, 2, 4, 8
    (capped at 8), and every stage except the last uses stride 2. A final 4x4
    conv produces a 1-channel patch map, optionally followed by Sigmoid.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
        super(NoNormDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw = 4
        padw = -(-(kw - 1) // 2)  # ceil((kw - 1) / 2), i.e. 2 for kw == 4
        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                  nn.LeakyReLU(0.2, True)]

        mults = [1] + [min(2 ** n, 8) for n in range(1, n_layers)] + [min(2 ** n_layers, 8)]
        last = len(mults) - 1
        for idx, (c_in, c_out) in enumerate(zip(mults, mults[1:]), start=1):
            stride = 1 if idx == last else 2
            layers += [nn.Conv2d(ndf * c_in, ndf * c_out, kernel_size=kw, stride=stride, padding=padw),
                       nn.LeakyReLU(0.2, True)]

        layers.append(nn.Conv2d(ndf * mults[-1], 1, kernel_size=kw, stride=1, padding=padw))
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        # Runs the model directly; the data_parallel dispatch on self.gpu_ids is disabled.
        return self.model(input)
class FCDiscriminator(nn.Module):
    """Normalisation-free conv stack followed by a fully connected head.

    The conv part matches NoNormDiscriminator and yields a 1-channel map. That
    map is flattened per sample and fed to ``nn.Linear(k, 1)``, with k = 7*7
    when ``patch`` is set and 13*13 otherwise. The input size must therefore
    produce exactly that many elements; e.g. a 32x32 input with n_layers=3
    gives a 7x7 map. TODO confirm the intended input sizes.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[], patch=False):
        super(FCDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        self.use_sigmoid = use_sigmoid
        kw = 4
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                nn.LeakyReLU(0.2, True)
            ]
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        if patch:
            self.linear = nn.Linear(7 * 7, 1)
        else:
            self.linear = nn.Linear(13 * 13, 1)
        if use_sigmoid:
            self.sigmoid = nn.Sigmoid()
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Return one score per sample, shape (N, 1); sigmoid-squashed if enabled."""
        batchsize = input.size(0)
        output = self.model(input).view(batchsize, -1)
        output = self.linear(output)
        if self.use_sigmoid:
            # A debug print("sigmoid") that fired on every call was removed here.
            output = self.sigmoid(output)
        return output
class Unet_resize_conv(nn.Module):
    """U-Net generator whose decoder upsamples with bilinear resize + 3x3 conv.

    Encoder: five stages of two 3x3 convs (32, 64, 128, 256, 512 channels),
    each conv followed by LeakyReLU(0.2) and, when opt.use_norm == 1, by a
    BatchNorm2d (SynBN2d if opt.syn_norm). Stages 1-4 end in a 2x pooling
    (AvgPool2d if opt.use_avgpool == 1, else MaxPool2d), so the bottleneck is
    1/16 of the (padded) input resolution.
    Decoder: each level bilinearly upsamples by 2, applies a 3x3 conv (deconvN)
    that halves the channels, concatenates the matching encoder feature and
    runs two more 3x3 convs. conv10 (1x1) maps the final 32 channels to 3.

    With opt.self_attention, conv1_1 takes 4 input channels (input concatenated
    with gray), and the bottleneck and skip features are multiplied by `gray`
    max-pooled to the matching resolution.

    Attribute names (conv*_*, bn*_*, deconv*) become state_dict keys, so
    renaming them breaks loading of saved weights.
    """
    def __init__(self, opt, skip):
        # opt: options object; the attributes read here and in forward are
        #   self_attention, use_norm, syn_norm, use_avgpool, tanh, times_residual,
        #   linear_add, latent_threshold, latent_norm, linear and skip.
        # skip: when truthy, forward adds input * opt.skip to the latent and
        #   returns (output, latent) instead of just output.
        super(Unet_resize_conv, self).__init__()

        self.opt = opt
        self.skip = skip
        p = 1
        # self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
        if opt.self_attention:
            # 4 input channels: the image concatenated with the gray map.
            self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p)
            # self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
            # Parameter-free poolings that bring `gray` to each encoder scale.
            self.downsample_1 = nn.MaxPool2d(2)
            self.downsample_2 = nn.MaxPool2d(2)
            self.downsample_3 = nn.MaxPool2d(2)
            self.downsample_4 = nn.MaxPool2d(2)
        else:
            self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p)
        # Encoder stage 1 (32 channels).
        self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn1_2 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        self.max_pool1 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Encoder stage 2 (64 channels).
        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p)
        self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn2_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.max_pool2 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Encoder stage 3 (128 channels).
        self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p)
        self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn3_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.max_pool3 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Encoder stage 4 (256 channels).
        self.conv4_1 = nn.Conv2d(128, 256, 3, padding=p)
        self.LReLU4_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv4_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn4_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.max_pool4 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2)

        # Bottleneck (512 channels, no pooling afterwards).
        self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p)
        self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_1 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p)
        self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn5_2 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512)

        # Decoder: deconvN is a plain 3x3 conv applied after bilinear upsampling
        # in forward (the transposed-conv variant is left commented out).
        # self.deconv5 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.deconv5 = nn.Conv2d(512, 256, 3, padding=p)
        # 512 in = 256 (deconv5) + 256 (conv4 skip).
        self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p)
        self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)
        self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p)
        self.LReLU6_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn6_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256)

        # self.deconv6 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.deconv6 = nn.Conv2d(256, 128, 3, padding=p)
        self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p)
        self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)
        self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p)
        self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn7_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128)

        # self.deconv7 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.deconv7 = nn.Conv2d(128, 64, 3, padding=p)
        self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p)
        self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)
        self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p)
        self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn8_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64)

        # self.deconv8 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.deconv8 = nn.Conv2d(64, 32, 3, padding=p)
        self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p)
        self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True)
        if self.opt.use_norm == 1:
            self.bn9_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32)
        # conv9_2 has no norm layer in either mode.
        self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p)
        self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True)

        # 1x1 projection to the 3-channel output.
        self.conv10 = nn.Conv2d(32, 3, 1)
        if self.opt.tanh:
            self.tanh = nn.Tanh()

    def depth_to_space(self, input, block_size):
        # Appears to rearrange channel blocks of size block_size**2 into
        # block_size x block_size spatial blocks (inverse of space-to-depth).
        # NOTE(review): not called anywhere visible (its only call in forward is
        # commented out), and it relies on the deprecated Tensor.resize.
        block_size_sq = block_size*block_size
        output = input.permute(0, 2, 3, 1)
        (batch_size, d_height, d_width, d_depth) = output.size()
        s_depth = int(d_depth / block_size_sq)
        s_width = int(d_width * block_size)
        s_height = int(d_height * block_size)
        t_1 = output.resize(batch_size, d_height, d_width, block_size_sq, s_depth)
        spl = t_1.split(block_size, 3)
        stack = [t_t.resize(batch_size, d_height, s_width, s_depth) for t_t in spl]
        output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).resize(batch_size, s_height, s_width, s_depth)
        output = output.permute(0, 3, 1, 2)
        return output

    def forward(self, input, gray):
        """Run the generator.

        Args:
            input: image batch indexed (N, C, H, W). conv1_1 takes 3 channels,
                or 4 together with `gray` when opt.self_attention is set.
            gray: guidance map. It is concatenated to the input and multiplied
                into features when opt.self_attention is set, and multiplies
                the latent when opt.times_residual is set. Assumed to have the
                same H and W as `input` (its pad amounts overwrite the input's)
                and to broadcast against 3-channel tensors — TODO confirm.

        Returns:
            `output` if self.skip is falsy, otherwise `(output, latent)`.
        """
        # Inputs wider than 2200 px are processed at half resolution and the
        # output is upsampled back at the end.
        flag = 0
        if input.size()[3] > 2200:
            avg = nn.AvgPool2d(2)
            input = avg(input)
            gray = avg(gray)
            flag = 1
            # pass
        # Pad H and W to multiples of 16 so the four 2x poolings divide evenly.
        input, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(input)
        gray, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(gray)
        if self.opt.self_attention:
            # gray_k matches the resolution of encoder stage k.
            gray_2 = self.downsample_1(gray)
            gray_3 = self.downsample_2(gray_2)
            gray_4 = self.downsample_3(gray_3)
            gray_5 = self.downsample_4(gray_4)
        if self.opt.use_norm == 1:
            # Normalised path: each conv -> in-place LeakyReLU -> BatchNorm.
            if self.opt.self_attention:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1))))
                # x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            else:
                x = self.bn1_1(self.LReLU1_1(self.conv1_1(input)))
            conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x)))
            x = self.max_pool1(conv1)

            x = self.bn2_1(self.LReLU2_1(self.conv2_1(x)))
            conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x)))
            x = self.max_pool2(conv2)

            x = self.bn3_1(self.LReLU3_1(self.conv3_1(x)))
            conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x)))
            x = self.max_pool3(conv3)

            x = self.bn4_1(self.LReLU4_1(self.conv4_1(x)))
            conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x)))
            x = self.max_pool4(conv4)

            x = self.bn5_1(self.LReLU5_1(self.conv5_1(x)))
            # Attention: scale bottleneck features by the 1/16-scale gray map.
            x = x*gray_5 if self.opt.self_attention else x
            conv5 = self.bn5_2(self.LReLU5_2(self.conv5_2(x)))

            # Decoder level 6: resize-conv, then concat with (attention-scaled) conv4.
            # NOTE(review): F.upsample is a deprecated alias of F.interpolate.
            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4*gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6)))
            conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x)))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3*gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7)))
            conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x)))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2*gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8)))
            conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x)))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1*gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9)))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent*gray

            # output = self.depth_to_space(conv10, 2)
            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    # Optionally clamp (ReLU) or min-max normalise the latent,
                    # min-max normalise the input to [0, 1], add, then map to [-1, 1].
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    input = (input - torch.min(input))/(torch.max(input) - torch.min(input))
                    output = latent + input*self.opt.skip
                    output = output*2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    output = latent + input*self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                # Rescale so the largest absolute value is 1.
                output = output/torch.max(torch.abs(output))

        elif self.opt.use_norm == 0:
            # Same network and post-processing as above, without BatchNorm.
            if self.opt.self_attention:
                x = self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1)))
            else:
                x = self.LReLU1_1(self.conv1_1(input))
            conv1 = self.LReLU1_2(self.conv1_2(x))
            x = self.max_pool1(conv1)

            x = self.LReLU2_1(self.conv2_1(x))
            conv2 = self.LReLU2_2(self.conv2_2(x))
            x = self.max_pool2(conv2)

            x = self.LReLU3_1(self.conv3_1(x))
            conv3 = self.LReLU3_2(self.conv3_2(x))
            x = self.max_pool3(conv3)

            x = self.LReLU4_1(self.conv4_1(x))
            conv4 = self.LReLU4_2(self.conv4_2(x))
            x = self.max_pool4(conv4)

            x = self.LReLU5_1(self.conv5_1(x))
            x = x*gray_5 if self.opt.self_attention else x
            conv5 = self.LReLU5_2(self.conv5_2(x))

            conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear')
            conv4 = conv4*gray_4 if self.opt.self_attention else conv4
            up6 = torch.cat([self.deconv5(conv5), conv4], 1)
            x = self.LReLU6_1(self.conv6_1(up6))
            conv6 = self.LReLU6_2(self.conv6_2(x))

            conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear')
            conv3 = conv3*gray_3 if self.opt.self_attention else conv3
            up7 = torch.cat([self.deconv6(conv6), conv3], 1)
            x = self.LReLU7_1(self.conv7_1(up7))
            conv7 = self.LReLU7_2(self.conv7_2(x))

            conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear')
            conv2 = conv2*gray_2 if self.opt.self_attention else conv2
            up8 = torch.cat([self.deconv7(conv7), conv2], 1)
            x = self.LReLU8_1(self.conv8_1(up8))
            conv8 = self.LReLU8_2(self.conv8_2(x))

            conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear')
            conv1 = conv1*gray if self.opt.self_attention else conv1
            up9 = torch.cat([self.deconv8(conv8), conv1], 1)
            x = self.LReLU9_1(self.conv9_1(up9))
            conv9 = self.LReLU9_2(self.conv9_2(x))

            latent = self.conv10(conv9)

            if self.opt.times_residual:
                latent = latent*gray

            if self.opt.tanh:
                latent = self.tanh(latent)
            if self.skip:
                if self.opt.linear_add:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    input = (input - torch.min(input))/(torch.max(input) - torch.min(input))
                    output = latent + input*self.opt.skip
                    output = output*2 - 1
                else:
                    if self.opt.latent_threshold:
                        latent = F.relu(latent)
                    elif self.opt.latent_norm:
                        latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent))
                    output = latent + input*self.opt.skip
            else:
                output = latent

            if self.opt.linear:
                output = output/torch.max(torch.abs(output))

        # NOTE(review): if opt.use_norm is neither 0 nor 1, `output` and `latent`
        # were never assigned and the next line raises UnboundLocalError.
        # Crop the padding added by pad_tensor.
        output = pad_tensor_back(output, pad_left, pad_right, pad_top, pad_bottom)
        latent = pad_tensor_back(latent, pad_left, pad_right, pad_top, pad_bottom)
        gray = pad_tensor_back(gray, pad_left, pad_right, pad_top, pad_bottom)
        if flag == 1:
            # NOTE(review): `latent` is not upsampled, so for wide inputs it is
            # returned at half the output's resolution; the upsampled `gray` is
            # not used afterwards.
            output = F.upsample(output, scale_factor=2, mode='bilinear')
            gray = F.upsample(gray, scale_factor=2, mode='bilinear')
        if self.skip:
            return output, latent
        else:
            return output
class DnCNN(nn.Module):
    """DnCNN-style residual network: ``forward(x) = x + dncnn(x)``.

    Structure: conv + ReLU, then ``depth - 2`` blocks of conv [+ BatchNorm] +
    ReLU, then a conv back to ``image_channels``. Conv weights get orthogonal
    init, and BatchNorm is initialised to weight 1 / bias 0.

    Args:
        opt: accepted for interface compatibility; unused.
        depth: total number of conv layers.
        n_channels: width of the hidden layers.
        image_channels: input/output channel count.
        use_bnorm: insert BatchNorm2d after each hidden conv. Previously this was
            ignored and BatchNorm was always inserted. Hidden convs carry a
            bias only when BatchNorm is off; with BatchNorm the bias would be
            redundant.
        kernel_size: conv kernel size. Previously it was overridden to 3. It
            should be odd, so that the "same" padding keeps H and W unchanged
            for the residual addition.
    """

    def __init__(self, opt=None, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        padding = kernel_size // 2
        layers = []
        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth - 2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=not use_bnorm))
            if use_bnorm:
                layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        # Residual formulation: the body predicts a correction added to x.
        return x + self.dncnn(x)

    def _initialize_weights(self):
        # The per-conv debug print('init weight') was removed.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
class Vgg16(nn.Module):
    """The 13 conv layers of VGG-16, used to extract one feature map chosen by ``opt``."""

    def __init__(self):
        super(Vgg16, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

    def forward(self, X, opt):
        """Return the feature map named by ``opt.vgg_choose``.

        Recognised names: 'conv4_3', 'relu4_1', 'relu4_2', 'relu4_3',
        'conv5_3', 'relu5_1', 'relu5_2'. Any other value ('relu5_3',
        'maxpool', 'no_maxpool', ...) returns relu5_3. 'no_maxpool' also skips
        the pooling after stage 3 and before stage 5. ``opt.vgg_maxpooling``
        toggles the pooling before stage 5.
        """
        h = F.relu(self.conv1_1(X), inplace=True)
        h = F.relu(self.conv1_2(h), inplace=True)
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv2_1(h), inplace=True)
        h = F.relu(self.conv2_2(h), inplace=True)
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv3_1(h), inplace=True)
        h = F.relu(self.conv3_2(h), inplace=True)
        h = F.relu(self.conv3_3(h), inplace=True)
        if opt.vgg_choose != "no_maxpool":
            h = F.max_pool2d(h, kernel_size=2, stride=2)

        relu4_1 = F.relu(self.conv4_1(h), inplace=True)
        relu4_2 = F.relu(self.conv4_2(relu4_1), inplace=True)
        conv4_3 = self.conv4_3(relu4_2)
        # Not in-place: an in-place ReLU here overwrote conv4_3, so 'conv4_3'
        # used to return post-activation values.
        relu4_3 = F.relu(conv4_3)

        # Stage-4 features do not need conv5, so return early.
        stage4 = {"conv4_3": conv4_3, "relu4_1": relu4_1,
                  "relu4_2": relu4_2, "relu4_3": relu4_3}
        if opt.vgg_choose in stage4:
            return stage4[opt.vgg_choose]

        h = relu4_3
        if opt.vgg_choose != "no_maxpool":
            if opt.vgg_maxpooling:
                h = F.max_pool2d(h, kernel_size=2, stride=2)
        relu5_1 = F.relu(self.conv5_1(h), inplace=True)
        relu5_2 = F.relu(self.conv5_2(relu5_1), inplace=True)
        conv5_3 = self.conv5_3(relu5_2)
        relu5_3 = F.relu(conv5_3)  # not in-place, same reason as conv4_3

        if opt.vgg_choose == "conv5_3":
            return conv5_3
        if opt.vgg_choose == "relu5_1":
            return relu5_1
        if opt.vgg_choose == "relu5_2":
            return relu5_2
        # Explicit fallback: the old test `== "relu5_3" or "maxpool"` was always
        # true, so every remaining choice (including "no_maxpool") got relu5_3.
        return relu5_3
def vgg_preprocess(batch, opt):
    """Turn a [-1, 1] RGB batch into the BGR, [0, 255] input the VGG trunk expects.

    Channels along dim 1 are reordered RGB -> BGR, values are mapped from
    [-1, 1] to [0, 255], and when ``opt.vgg_mean`` is truthy the per-channel
    BGR means (103.939, 116.779, 123.680) are subtracted.

    Args:
        batch: 4-D tensor whose dim 1 has exactly three channels in R, G, B
            order (presumably NCHW -- TODO confirm with callers).
        opt: options object; only ``opt.vgg_mean`` is read.

    Returns:
        A new tensor on the same device as ``batch``; ``batch`` itself is not
        modified.
    """
    r, g, b = torch.chunk(batch, 3, dim=1)
    batch = torch.cat((b, g, r), dim=1)  # RGB -> BGR
    batch = (batch + 1) * 255 * 0.5  # [-1, 1] -> [0, 255]
    if opt.vgg_mean:
        # Build the mean on batch's own device/dtype and broadcast it.  The
        # old `type(batch.data)(size)` yields a plain CPU torch.Tensor on
        # PyTorch >= 0.4, so the subtraction failed for CUDA batches.
        mean = batch.new_tensor([103.939, 116.779, 123.680]).view(1, 3, 1, 1)
        batch = batch - mean
    return batch
class PerceptualLoss(nn.Module):
    """Mean squared distance between VGG feature maps of two images.

    Both images go through ``vgg_preprocess`` and then the supplied network,
    which is called as ``vgg(x, opt)``.  Unless ``opt.no_vgg_instance`` is
    truthy, both feature maps are instance-normalised before comparison.
    """

    def __init__(self, opt):
        super(PerceptualLoss, self).__init__()
        self.opt = opt
        # Parameter-free instance norm; 512 presumably matches the channel
        # count of the selected VGG layer -- TODO confirm.
        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

    def compute_vgg_loss(self, vgg, img, target):
        """Return the scalar MSE between the VGG features of ``img`` and ``target``.

        Args:
            vgg: callable invoked as ``vgg(x, opt)``; presumably the module
                returned by ``load_vgg16`` -- TODO confirm.
            img: image batch in the range ``vgg_preprocess`` expects.
            target: reference batch in the same format as ``img``.
        """
        opt = self.opt
        img_in, target_in = vgg_preprocess(img, opt), vgg_preprocess(target, opt)
        img_feat = vgg(img_in, opt)
        target_feat = vgg(target_in, opt)
        if not opt.no_vgg_instance:
            img_feat = self.instancenorm(img_feat)
            target_feat = self.instancenorm(target_feat)
        return torch.mean((img_feat - target_feat) ** 2)
def load_vgg16(model_dir, gpu_ids):
    """Build ``Vgg16``, load ``<model_dir>/vgg16.weight`` into it and wrap it in DataParallel.

    Uses the model from
    https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py
    The ``vgg16.weight`` state dict was originally produced by converting the
    Lua model https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 with the
    since-removed ``torch.utils.serialization.load_lua``; that conversion code
    is no longer run here, so the file must already exist.

    Args:
        model_dir: directory holding ``vgg16.weight``; it is created
            (including missing parents) if absent.
        gpu_ids: CUDA device indices; the model is placed on ``gpu_ids[0]``
            and replicated over all of them by DataParallel.

    Returns:
        ``torch.nn.DataParallel`` wrapping the loaded ``Vgg16``.
    """
    # makedirs(exist_ok=True) also creates missing parents and does not race
    # with a concurrent creator the way an exists()/mkdir() pair does.
    os.makedirs(model_dir, exist_ok=True)
    vgg = Vgg16()
    vgg.cuda(device=gpu_ids[0])
    # Deserialize on CPU regardless of the device the checkpoint was saved
    # from; load_state_dict then copies the values into the GPU parameters.
    state = torch.load(os.path.join(model_dir, 'vgg16.weight'), map_location='cpu')
    vgg.load_state_dict(state)
    return torch.nn.DataParallel(vgg, gpu_ids)
class FCN32s(nn.Module):
    """FCN-32s segmentation network: VGG16 conv trunk, conv "fc" head, stride-32 upsampling.

    Submodules are registered in the order conv1_1, relu1_1, conv1_2,
    relu1_2, pool1, ..., pool5, fc6, relu6, drop6, fc7, relu7, drop7,
    score_fr, upscore.  These names are the state-dict keys that ``load_fcn``
    loads into.
    """

    def __init__(self, n_class=21):
        super(FCN32s, self).__init__()
        in_ch = 3
        # Stage number -> output channels of each 3x3 conv in that stage.
        # Each stage registers conv/relu pairs followed by a 2x2 pool.
        for stage, widths in ((1, (64, 64)), (2, (128, 128)),
                              (3, (256, 256, 256)), (4, (512, 512, 512)),
                              (5, (512, 512, 512))):
            for idx, out_ch in enumerate(widths, start=1):
                # conv1_1 pads by 100 (as in the reference FCN design),
                # presumably so the upsampled map is large enough for the
                # crop in forward(); every other conv keeps size with pad 1.
                pad = 100 if (stage, idx) == (1, 1) else 1
                setattr(self, 'conv%d_%d' % (stage, idx),
                        nn.Conv2d(in_ch, out_ch, 3, padding=pad))
                setattr(self, 'relu%d_%d' % (stage, idx), nn.ReLU(inplace=True))
                in_ch = out_ch
            # ceil_mode keeps partial windows; five such pools give stride 32.
            setattr(self, 'pool%d' % stage,
                    nn.MaxPool2d(2, stride=2, ceil_mode=True))

        # Fully-convolutional replacement of the VGG classifier layers.
        self.fc6 = nn.Conv2d(512, 4096, 7)
        self.relu6 = nn.ReLU(inplace=True)
        self.drop6 = nn.Dropout2d()

        self.fc7 = nn.Conv2d(4096, 4096, 1)
        self.relu7 = nn.ReLU(inplace=True)
        self.drop7 = nn.Dropout2d()

        self.score_fr = nn.Conv2d(4096, n_class, 1)
        self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32,
                                          bias=False)

    def _initialize_weights(self):
        """Zero every Conv2d and fill every ConvTranspose2d from ``get_upsampling_weight``.

        NOTE(review): nothing in this class calls this method, and it relies
        on a module-level ``get_upsampling_weight`` helper (presumably a
        bilinear kernel) -- confirm that helper exists before calling.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.zero_()
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.ConvTranspose2d):
                k_h, k_w = module.kernel_size
                assert k_h == k_w
                module.weight.data.copy_(get_upsampling_weight(
                    module.in_channels, module.out_channels, k_h))

    def forward(self, x):
        """Return per-class score maps cropped to the spatial size of ``x``."""
        h = x
        for stage, n_convs in ((1, 2), (2, 2), (3, 3), (4, 3), (5, 3)):
            for idx in range(1, n_convs + 1):
                conv = getattr(self, 'conv%d_%d' % (stage, idx))
                act = getattr(self, 'relu%d_%d' % (stage, idx))
                h = act(conv(h))
            h = getattr(self, 'pool%d' % stage)(h)

        h = self.drop6(self.relu6(self.fc6(h)))
        h = self.drop7(self.relu7(self.fc7(h)))
        h = self.upscore(self.score_fr(h))

        # Crop the upsampled map back to the input's height/width at offset 19.
        height, width = x.size()[2], x.size()[3]
        return h[:, :, 19:19 + height, 19:19 + width].contiguous()
def load_fcn(model_dir):
    """Build ``FCN32s``, load ``<model_dir>/fcn32s_from_caffe.pth`` into it and move it to the GPU.

    Args:
        model_dir: directory containing ``fcn32s_from_caffe.pth``.

    Returns:
        The loaded ``FCN32s`` on the current default CUDA device.
    """
    fcn = FCN32s()
    # Deserialize on CPU so device tags stored in the checkpoint (e.g. a GPU
    # index absent on this machine) cannot break loading; the model is
    # moved to CUDA afterwards anyway.
    state = torch.load(os.path.join(model_dir, 'fcn32s_from_caffe.pth'), map_location='cpu')
    fcn.load_state_dict(state)
    fcn.cuda()
    return fcn
class SemanticLoss(nn.Module):
    """Mean squared distance between instance-normalised FCN outputs of two images.

    Both images are passed through ``vgg_preprocess`` before the FCN.
    """

    def __init__(self, opt):
        super(SemanticLoss, self).__init__()
        self.opt = opt
        # 21 equals FCN32s' default n_class; presumably the FCN passed to
        # compute_fcn_loss uses that default -- TODO confirm.
        self.instancenorm = nn.InstanceNorm2d(21, affine=False)

    def compute_fcn_loss(self, fcn, img, target):
        """Return the scalar MSE between normalised ``fcn`` outputs for ``img`` and ``target``.

        Args:
            fcn: callable invoked as ``fcn(x)``; presumably the model from
                ``load_fcn`` -- TODO confirm.
            img: image batch in the range ``vgg_preprocess`` expects.
            target: reference batch in the same format as ``img``.
        """
        norm = self.instancenorm
        img_in = vgg_preprocess(img, self.opt)
        target_in = vgg_preprocess(target, self.opt)
        img_scores = fcn(img_in)
        target_scores = fcn(target_in)
        diff = norm(img_scores) - norm(target_scores)
        return torch.mean(diff ** 2)