First upload

pull/10/head
kritiksoman 4 years ago
parent ba244f12d8
commit ccd3c980ca

@@ -0,0 +1,100 @@ data/aligned_dataset.py
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset, make_dataset_test
from PIL import Image
import torch
import numpy as np
class AlignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
### input A (label maps)
if opt.isTrain or opt.use_encoded_image:
dir_A = '_A' if self.opt.label_nc == 0 else '_label'
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
self.A_paths = sorted(make_dataset(self.dir_A))
self.AR_paths = make_dataset(self.dir_A)
### input A inter 1 (label maps)
if opt.isTrain or opt.use_encoded_image:
dir_A_inter_1 = '_label_inter_1'
self.dir_A_inter_1 = os.path.join(opt.dataroot, opt.phase + dir_A_inter_1)
self.A_paths_inter_1 = sorted(make_dataset(self.dir_A_inter_1))
### input A inter 2 (label maps)
if opt.isTrain or opt.use_encoded_image:
dir_A_inter_2 = '_label_inter_2'
self.dir_A_inter_2 = os.path.join(opt.dataroot, opt.phase + dir_A_inter_2)
self.A_paths_inter_2 = sorted(make_dataset(self.dir_A_inter_2))
### input A test (label maps)
if not (opt.isTrain or opt.use_encoded_image):
dir_A = '_A' if self.opt.label_nc == 0 else '_label'
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
self.A_paths = sorted(make_dataset_test(self.dir_A))
dir_AR = '_AR' if self.opt.label_nc == 0 else '_labelref'
self.dir_AR = os.path.join(opt.dataroot, opt.phase + dir_AR)
self.AR_paths = sorted(make_dataset_test(self.dir_AR))
### input B (real images)
dir_B = '_B' if self.opt.label_nc == 0 else '_img'
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
self.B_paths = sorted(make_dataset(self.dir_B))
self.BR_paths = sorted(make_dataset(self.dir_B))
self.dataset_size = len(self.A_paths)
def __getitem__(self, index):
### input A (label maps)
A_path = self.A_paths[index]
AR_path = self.AR_paths[index]
A = Image.open(A_path)
AR = Image.open(AR_path)
if self.opt.isTrain:
A_path_inter_1 = self.A_paths_inter_1[index]
A_path_inter_2 = self.A_paths_inter_2[index]
A_inter_1 = Image.open(A_path_inter_1)
A_inter_2 = Image.open(A_path_inter_2)
params = get_params(self.opt, A.size)
if self.opt.label_nc == 0:
transform_A = get_transform(self.opt, params)
A_tensor = transform_A(A.convert('RGB'))
if self.opt.isTrain:
A_inter_1_tensor = transform_A(A_inter_1.convert('RGB'))
A_inter_2_tensor = transform_A(A_inter_2.convert('RGB'))
AR_tensor = transform_A(AR.convert('RGB'))
else:
transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
A_tensor = transform_A(A) * 255.0
if self.opt.isTrain:
A_inter_1_tensor = transform_A(A_inter_1) * 255.0
A_inter_2_tensor = transform_A(A_inter_2) * 255.0
AR_tensor = transform_A(AR) * 255.0
B_tensor = inst_tensor = feat_tensor = 0
### input B (real images)
B_path = self.B_paths[index]
BR_path = self.BR_paths[index]
B = Image.open(B_path).convert('RGB')
BR = Image.open(BR_path).convert('RGB')
transform_B = get_transform(self.opt, params)
B_tensor = transform_B(B)
BR_tensor = transform_B(BR)
if self.opt.isTrain:
input_dict = {'inter_label_1': A_inter_1_tensor, 'label': A_tensor, 'inter_label_2': A_inter_2_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor, 'path': A_path, 'path_ref': AR_path}
else:
input_dict = {'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor, 'path': A_path, 'path_ref': AR_path}
return input_dict
def __len__(self):
return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize
def name(self):
return 'AlignedDataset'

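# --- Illustrative sketch (not part of the commit): what one training sample
# from AlignedDataset looks like. Assumes train_label/, train_label_inter_1/,
# train_label_inter_2/ and train_img/ under opt.dataroot with index-named
# files (0.png, 0.jpg, ...), and that train_options.py sets isTrain = True.
from options.train_options import TrainOptions
from data.aligned_dataset import AlignedDataset

opt = TrainOptions().parse()
dataset = AlignedDataset()
dataset.initialize(opt)
sample = dataset[0]
# training keys: 'inter_label_1', 'label', 'inter_label_2', 'label_ref',
#                'image', 'image_ref', 'path', 'path_ref'
print(sample['label'].shape, sample['image'].shape)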
@@ -0,0 +1,14 @@ data/base_data_loader.py
class BaseDataLoader():
def __init__(self):
pass
    def initialize(self, opt):
        self.opt = opt

    def load_data(self):   # fixed: the original omitted `self`
        return None

@@ -0,0 +1,97 @@ data/base_dataset.py
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.resize_or_crop == 'resize_and_crop':
new_h = new_w = opt.loadSize
elif opt.resize_or_crop == 'scale_width_and_crop':
new_w = opt.loadSize
new_h = opt.loadSize * h // w
x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
#flip = random.random() > 0.5
flip = 0
return {'crop_pos': (x, y), 'flip': flip}
def get_transform(opt, params, method=Image.BICUBIC, normalize=True, normalize_mask=False):
transform_list = []
if 'resize' in opt.resize_or_crop:
osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Resize(osize, method))   # torchvision renamed Scale to Resize
elif 'scale_width' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
if 'crop' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
if opt.resize_or_crop == 'none':
base = float(2 ** opt.n_downsample_global)
if opt.netG == 'local':
base *= (2 ** opt.n_local_enhancers)
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
if normalize_mask:
transform_list += [transforms.Normalize((0, 0, 0),
(1 / 255., 1 / 255., 1 / 255.))]
return transforms.Compose(transform_list)
def normalize():
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img

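# --- Illustrative sketch of the transform pipeline above. The Namespace stands
# in for parsed options and only carries the fields that get_params /
# get_transform actually read.
from argparse import Namespace
from PIL import Image
from data.base_dataset import get_params, get_transform

opt = Namespace(resize_or_crop='resize_and_crop', loadSize=512, fineSize=512,
                isTrain=False, no_flip=True)
img = Image.new('RGB', (600, 400))
params = get_params(opt, img.size)      # shared crop position (flip is disabled above)
to_tensor = get_transform(opt, params)  # resize -> crop -> ToTensor -> Normalize
print(to_tensor(img).shape)             # torch.Size([3, 512, 512])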
@@ -0,0 +1,31 @@ data/custom_dataset_data_loader.py
import torch.utils.data
from data.base_data_loader import BaseDataLoader
def CreateDataset(opt):
dataset = None
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))
def load_data(self):
return self.dataloader
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)

@@ -0,0 +1,7 @@ data/data_loader.py
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader

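# --- Illustrative sketch: end-to-end loading through the wrapper above,
# assuming the directory layout and options described earlier.
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader

opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)
print('#training samples = %d' % len(data_loader))
for i, data in enumerate(data_loader.load_data()):
    labels, images = data['label'], data['image']   # the dict built in AlignedDataset
    break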
@@ -0,0 +1,82 @@ data/image_folder.py
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
f = dir.split('/')[-1].split('_')[-1]
print (dir, f)
    num_files = len([name for name in os.listdir(dir) if os.path.isfile(os.path.join(dir, name))])
    for i in range(num_files):
if f == 'label' or f == '1' or f == '2':
img = str(i) + '.png'
else:
img = str(i) + '.jpg'
path = os.path.join(dir, img)
#print(path)
images.append(path)
return images
def make_dataset_test(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
f = dir.split('/')[-1].split('_')[-1]
    num_files = len([name for name in os.listdir(dir) if os.path.isfile(os.path.join(dir, name))])
    for i in range(num_files):
if f == 'label' or f == 'labelref':
img = str(i) + '.png'
else:
img = str(i) + '.jpg'
path = os.path.join(dir, img)
#print(path)
images.append(path)
return images
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img
def __len__(self):
return len(self.imgs)

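# --- Illustrative note: unlike the torchvision original, make_dataset does not
# scan for arbitrary image files; it assumes index-named files (0.png, 1.png, ...
# for *_label/*_1/*_2 folders, 0.jpg, ... otherwise) and returns them in order.
from data.image_folder import make_dataset
paths = make_dataset('../Data_preprocessing/train_label')   # dataroot default from base_options.py
print(paths[:3])   # ['../Data_preprocessing/train_label/0.png', ...]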
@@ -0,0 +1,94 @@ models/base_model.py
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import os
import torch
import sys
class BaseModel(torch.nn.Module):
def name(self):
return 'BaseModel'
def initialize(self, opt):
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
def set_input(self, input):
self.input = input
def forward(self):
pass
# used in test time, no backprop
def test(self):
pass
def get_image_paths(self):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
return self.input
def get_current_errors(self):
return {}
def save(self, label):
pass
# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if len(gpu_ids) and torch.cuda.is_available():
network.cuda()
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label, save_dir=''):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
print (save_filename)
if not save_dir:
save_dir = self.save_dir
save_path = os.path.join(save_dir, save_filename)
if not os.path.isfile(save_path):
            print('%s does not exist yet!' % save_path)
            if network_label == 'G':
                raise RuntimeError('Generator must exist!')   # fixed: raising a plain string is a TypeError in Python 3
else:
#network.load_state_dict(torch.load(save_path))
try:
network.load_state_dict(torch.load(save_path))
except:
pretrained_dict = torch.load(save_path)
model_dict = network.state_dict()
try:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
network.load_state_dict(pretrained_dict)
if self.opt.verbose:
print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
except:
print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
for k, v in pretrained_dict.items():
if v.size() == model_dict[k].size():
model_dict[k] = v
if sys.version_info >= (3,0):
not_initialized = set()
else:
from sets import Set
not_initialized = Set()
for k, v in model_dict.items():
if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
not_initialized.add(k.split('.')[0])
print(sorted(not_initialized))
network.load_state_dict(model_dict)
    def update_learning_rate(self):   # fixed: the original omitted `self`
        pass

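# --- Illustrative sketch of the checkpoint naming convention used above:
# '<epoch>_net_<label>.pth' under opt.checkpoints_dir/opt.name.
import os
save_dir = os.path.join('./checkpoints', 'label2face_512p')   # defaults from base_options.py
print(os.path.join(save_dir, '%s_net_%s.pth' % ('latest', 'G')))
# -> ./checkpoints/label2face_512p/latest_net_G.pth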
@@ -0,0 +1,20 @@ models/models.py
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import torch
def create_model(opt):
if opt.model == 'pix2pixHD':
from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
if opt.isTrain:
model = Pix2PixHDModel()
else:
model = InferenceModel()
model.initialize(opt)
if opt.verbose:
print("model [%s] was created" % (model.name()))
if opt.isTrain and len(opt.gpu_ids):
model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
return model

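# --- Illustrative sketch: building the model from parsed options. At test time
# create_model returns an InferenceModel and skips the DataParallel wrapper.
# The models.models module path follows the pix2pixHD layout this repo is based on.
from options.test_options import TestOptions
from models.models import create_model

opt = TestOptions().parse(save=False)
model = create_model(opt)   # InferenceModel, since opt.isTrain is False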
@@ -0,0 +1,818 @@ models/networks.py
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
###############################################################################
# Functions
###############################################################################
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv2d') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
n_blocks_local=3, norm='instance', gpu_ids=[]):
norm_layer = get_norm_layer(norm_type=norm)
if netG == 'global':
netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
elif netG == 'local':
netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
n_local_enhancers, n_blocks_local, norm_layer)
    else:
        raise NotImplementedError('generator [%s] is not implemented' % netG)
print(netG)
# if len(gpu_ids) > 0:
# assert(torch.cuda.is_available())
# netG.cuda(gpu_ids[0])
netG.apply(weights_init)
return netG
def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
norm_layer = get_norm_layer(norm_type=norm)
netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
print(netD)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netD.cuda(gpu_ids[0])
netD.apply(weights_init)
return netD
def define_VAE(input_nc, gpu_ids=[]):
netVAE = VAE(19, 32, 32, 1024)
print(netVAE)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netVAE.cuda(gpu_ids[0])
return netVAE
def define_B(input_nc, output_nc, ngf, n_downsample_global=3, n_blocks_global=3, norm='instance', gpu_ids=[]):
norm_layer = get_norm_layer(norm_type=norm)
netB = BlendGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
print(netB)
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
netB.cuda(gpu_ids[0])
netB.apply(weights_init)
return netB
def print_network(net):
if isinstance(net, list):
net = net[0]
num_params = 0
for param in net.parameters():
num_params += param.numel()
print(net)
print('Total number of parameters: %d' % num_params)
##############################################################################
# Losses
##############################################################################
class GANLoss(nn.Module):
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_var = None
self.fake_label_var = None
self.Tensor = tensor
if use_lsgan:
self.loss = nn.MSELoss()
else:
self.loss = nn.BCELoss()
def get_target_tensor(self, input, target_is_real):
target_tensor = None
if target_is_real:
create_label = ((self.real_label_var is None) or
(self.real_label_var.numel() != input.numel()))
if create_label:
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
self.real_label_var = Variable(real_tensor, requires_grad=False)
target_tensor = self.real_label_var
else:
create_label = ((self.fake_label_var is None) or
(self.fake_label_var.numel() != input.numel()))
if create_label:
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
target_tensor = self.fake_label_var
return target_tensor
def __call__(self, input, target_is_real):
if isinstance(input[0], list):
loss = 0
for input_i in input:
pred = input_i[-1]
target_tensor = self.get_target_tensor(pred, target_is_real)
loss += self.loss(pred, target_tensor)
return loss
else:
target_tensor = self.get_target_tensor(input[-1], target_is_real)
return self.loss(input[-1], target_tensor)
class VGGLoss(nn.Module):
def __init__(self, gpu_ids):
super(VGGLoss, self).__init__()
self.vgg = Vgg19().cuda()
self.criterion = nn.L1Loss()
self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
def forward(self, x, y):
x_vgg, y_vgg = self.vgg(x), self.vgg(y)
loss = 0
for i in range(len(x_vgg)):
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
return loss
##############################################################################
# Generator
##############################################################################
class GlobalGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
padding_type='reflect'):
assert(n_blocks >= 0)
super(GlobalGenerator, self).__init__()
activation = nn.ReLU(True)
model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
### downsample
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf * mult * 2), activation]
### resnet blocks
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, norm_type='adain', padding_type=padding_type)]
### upsample
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
norm_layer(int(ngf * mult / 2)), activation]
model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
self.model = nn.Sequential(*model)
# style encoder
self.enc_style = StyleEncoder(5, 3, 16, self.get_num_adain_params(self.model), norm='none', activ='relu', pad_type='reflect')
# label encoder
self.enc_label = LabelEncoder(5, 19, 16, 64, norm='none', activ='relu', pad_type='reflect')
def assign_adain_params(self, adain_params, model):
# assign the adain_params to the AdaIN layers in model
for m in model.modules():
if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
mean = adain_params[:, :m.num_features]
std = adain_params[:, m.num_features:2*m.num_features]
m.bias = mean.contiguous().view(-1)
m.weight = std.contiguous().view(-1)
if adain_params.size(1) > 2*m.num_features:
adain_params = adain_params[:, 2*m.num_features:]
def get_num_adain_params(self, model):
# return the number of AdaIN parameters needed by the model
num_adain_params = 0
for m in model.modules():
if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
num_adain_params += 2*m.num_features
return num_adain_params
def forward(self, input, input_ref, image_ref):
fea1, fea2 = self.enc_label(input_ref)
adain_params = self.enc_style((image_ref, fea1, fea2))
self.assign_adain_params(adain_params, self.model)
return self.model(input)
class BlendGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=3, norm_layer=nn.BatchNorm2d,
padding_type='reflect'):
assert(n_blocks >= 0)
super(BlendGenerator, self).__init__()
activation = nn.ReLU(True)
model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
### downsample
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
norm_layer(ngf * mult * 2), activation]
### resnet blocks
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, norm_type='in', padding_type=padding_type)]
### upsample
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
norm_layer(int(ngf * mult / 2)), activation]
model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Sigmoid()]
self.model = nn.Sequential(*model)
def forward(self, input1, input2):
m = self.model(torch.cat([input1, input2], 1))
return input1 * m + input2 * (1-m), m
# Define the Multiscale Discriminator.
class MultiscaleDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
use_sigmoid=False, num_D=3, getIntermFeat=False):
super(MultiscaleDiscriminator, self).__init__()
self.num_D = num_D
self.n_layers = n_layers
self.getIntermFeat = getIntermFeat
for i in range(num_D):
netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
if getIntermFeat:
for j in range(n_layers+2):
setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
else:
setattr(self, 'layer'+str(i), netD.model)
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
def singleD_forward(self, model, input):
if self.getIntermFeat:
result = [input]
for i in range(len(model)):
result.append(model[i](result[-1]))
return result[1:]
else:
return [model(input)]
def forward(self, input):
num_D = self.num_D
result = []
input_downsampled = input
for i in range(num_D):
if self.getIntermFeat:
model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
else:
model = getattr(self, 'layer'+str(num_D-1-i))
result.append(self.singleD_forward(model, input_downsampled))
if i != (num_D-1):
input_downsampled = self.downsample(input_downsampled)
return result
# Define the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
super(NLayerDiscriminator, self).__init__()
self.getIntermFeat = getIntermFeat
self.n_layers = n_layers
kw = 4
padw = int(np.ceil((kw-1.0)/2))
sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
nf = ndf
for n in range(1, n_layers):
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
norm_layer(nf), nn.LeakyReLU(0.2, True)
]]
nf_prev = nf
nf = min(nf * 2, 512)
sequence += [[
nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
norm_layer(nf),
nn.LeakyReLU(0.2, True)
]]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
if use_sigmoid:
sequence += [[nn.Sigmoid()]]
if getIntermFeat:
for n in range(len(sequence)):
setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
else:
sequence_stream = []
for n in range(len(sequence)):
sequence_stream += sequence[n]
self.model = nn.Sequential(*sequence_stream)
def forward(self, input):
if self.getIntermFeat:
res = [input]
for n in range(self.n_layers+2):
model = getattr(self, 'model'+str(n))
res.append(model(res[-1]))
return res[1:]
else:
return self.model(input)
from torchvision import models
class Vgg19(torch.nn.Module):
def __init__(self, requires_grad=False):
super(Vgg19, self).__init__()
vgg_pretrained_features = models.vgg19(pretrained=True).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
for x in range(2):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(2, 7):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(7, 12):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(12, 21):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(21, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h_relu1 = self.slice1(X)
h_relu2 = self.slice2(h_relu1)
h_relu3 = self.slice3(h_relu2)
h_relu4 = self.slice4(h_relu3)
h_relu5 = self.slice5(h_relu4)
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
return out
# Define the MaskVAE
class VAE(nn.Module):
def __init__(self, nc, ngf, ndf, latent_variable_size):
super(VAE, self).__init__()
#self.cuda = True
self.nc = nc
self.ngf = ngf
self.ndf = ndf
self.latent_variable_size = latent_variable_size
# encoder
self.e1 = nn.Conv2d(nc, ndf, 4, 2, 1)
self.bn1 = nn.BatchNorm2d(ndf)
self.e2 = nn.Conv2d(ndf, ndf*2, 4, 2, 1)
self.bn2 = nn.BatchNorm2d(ndf*2)
self.e3 = nn.Conv2d(ndf*2, ndf*4, 4, 2, 1)
self.bn3 = nn.BatchNorm2d(ndf*4)
self.e4 = nn.Conv2d(ndf*4, ndf*8, 4, 2, 1)
self.bn4 = nn.BatchNorm2d(ndf*8)
self.e5 = nn.Conv2d(ndf*8, ndf*16, 4, 2, 1)
self.bn5 = nn.BatchNorm2d(ndf*16)
self.e6 = nn.Conv2d(ndf*16, ndf*32, 4, 2, 1)
self.bn6 = nn.BatchNorm2d(ndf*32)
self.e7 = nn.Conv2d(ndf*32, ndf*64, 4, 2, 1)
self.bn7 = nn.BatchNorm2d(ndf*64)
self.fc1 = nn.Linear(ndf*64*4*4, latent_variable_size)
self.fc2 = nn.Linear(ndf*64*4*4, latent_variable_size)
# decoder
self.d1 = nn.Linear(latent_variable_size, ngf*64*4*4)
self.up1 = nn.UpsamplingNearest2d(scale_factor=2)
self.pd1 = nn.ReplicationPad2d(1)
self.d2 = nn.Conv2d(ngf*64, ngf*32, 3, 1)
self.bn8 = nn.BatchNorm2d(ngf*32, 1.e-3)
self.up2 = nn.UpsamplingNearest2d(scale_factor=2)
self.pd2 = nn.ReplicationPad2d(1)
self.d3 = nn.Conv2d(ngf*32, ngf*16, 3, 1)
self.bn9 = nn.BatchNorm2d(ngf*16, 1.e-3)
self.up3 = nn.UpsamplingNearest2d(scale_factor=2)
self.pd3 = nn.ReplicationPad2d(1)
self.d4 = nn.Conv2d(ngf*16, ngf*8, 3, 1)
self.bn10 = nn.BatchNorm2d(ngf*8, 1.e-3)
self.up4 = nn.UpsamplingNearest2d(scale_factor=2)
self.pd4 = nn.ReplicationPad2d(1)
self.d5 = nn.Conv2d(ngf*8, ngf*4, 3, 1)
self.bn11 = nn.BatchNorm2d(ngf*4, 1.e-3)
self.up5 = nn.UpsamplingNearest2d(scale_factor=2)
self.pd5 = nn.ReplicationPad2d(1)
self.d6 = nn.Conv2d(ngf*4, ngf*2, 3, 1)
self.bn12 = nn.BatchNorm2d(ngf*2, 1.e-3)
self.up6 = nn.UpsamplingNearest2d(scale_factor=2)
self.pd6 = nn.ReplicationPad2d(1)
self.d7 = nn.Conv2d(ngf*2, ngf, 3, 1)
self.bn13 = nn.BatchNorm2d(ngf, 1.e-3)
self.up7 = nn.UpsamplingNearest2d(scale_factor=2)
self.pd7 = nn.ReplicationPad2d(1)
self.d8 = nn.Conv2d(ngf, nc, 3, 1)
self.leakyrelu = nn.LeakyReLU(0.2)
self.relu = nn.ReLU()
#self.sigmoid = nn.Sigmoid()
self.maxpool = nn.MaxPool2d((2, 2), (2, 2))
def encode(self, x):
h1 = self.leakyrelu(self.bn1(self.e1(x)))
h2 = self.leakyrelu(self.bn2(self.e2(h1)))
h3 = self.leakyrelu(self.bn3(self.e3(h2)))
h4 = self.leakyrelu(self.bn4(self.e4(h3)))
h5 = self.leakyrelu(self.bn5(self.e5(h4)))
h6 = self.leakyrelu(self.bn6(self.e6(h5)))
h7 = self.leakyrelu(self.bn7(self.e7(h6)))
h7 = h7.view(-1, self.ndf*64*4*4)
return self.fc1(h7), self.fc2(h7)
    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        # sample on the same device as std; the original hard-coded
        # torch.cuda.FloatTensor, which crashes on CPU-only machines
        eps = Variable(torch.randn_like(std))
        return eps.mul(std).add_(mu)
def decode(self, z):
h1 = self.relu(self.d1(z))
h1 = h1.view(-1, self.ngf*64, 4, 4)
h2 = self.leakyrelu(self.bn8(self.d2(self.pd1(self.up1(h1)))))
h3 = self.leakyrelu(self.bn9(self.d3(self.pd2(self.up2(h2)))))
h4 = self.leakyrelu(self.bn10(self.d4(self.pd3(self.up3(h3)))))
h5 = self.leakyrelu(self.bn11(self.d5(self.pd4(self.up4(h4)))))
h6 = self.leakyrelu(self.bn12(self.d6(self.pd5(self.up5(h5)))))
h7 = self.leakyrelu(self.bn13(self.d7(self.pd6(self.up6(h6)))))
return self.d8(self.pd7(self.up7(h7)))
def get_latent_var(self, x):
mu, logvar = self.encode(x)
z = self.reparametrize(mu, logvar)
return z, mu, logvar.mul(0.5).exp_()
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparametrize(mu, logvar)
res = self.decode(z)
return res, x, mu, logvar
# style encode part
class StyleEncoder(nn.Module):
def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
super(StyleEncoder, self).__init__()
self.model = []
self.model_middle = []
self.model_last = []
self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
for i in range(2):
self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
dim *= 2
for i in range(n_downsample - 2):
self.model_middle += [ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
self.model_last += [nn.AdaptiveAvgPool2d(1)] # global average pooling
self.model_last += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
self.model = nn.Sequential(*self.model)
self.model_middle = nn.Sequential(*self.model_middle)
self.model_last = nn.Sequential(*self.model_last)
self.output_dim = dim
self.sft1 = SFTLayer()
self.sft2 = SFTLayer()
def forward(self, x):
fea = self.model(x[0])
fea = self.sft1((fea, x[1]))
fea = self.model_middle(fea)
fea = self.sft2((fea, x[2]))
return self.model_last(fea)
# label encode part
class LabelEncoder(nn.Module):
def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
super(LabelEncoder, self).__init__()
self.model = []
self.model_last = [nn.ReLU()]
self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
dim *= 2
self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation='none', pad_type=pad_type)]
dim *= 2
for i in range(n_downsample - 3):
self.model_last += [ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
self.model_last += [ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation='none', pad_type=pad_type)]
self.model = nn.Sequential(*self.model)
self.model_last = nn.Sequential(*self.model_last)
self.output_dim = dim
def forward(self, x):
fea = self.model(x)
return fea, self.model_last(fea)
# Define the basic block
class ConvBlock(nn.Module):
def __init__(self, input_dim ,output_dim, kernel_size, stride,
padding=0, norm='none', activation='relu', pad_type='zero'):
super(ConvBlock, self).__init__()
self.use_bias = True
# initialize padding
if pad_type == 'reflect':
self.pad = nn.ReflectionPad2d(padding)
elif pad_type == 'replicate':
self.pad = nn.ReplicationPad2d(padding)
elif pad_type == 'zero':
self.pad = nn.ZeroPad2d(padding)
else:
assert 0, "Unsupported padding type: {}".format(pad_type)
# initialize normalization
norm_dim = output_dim
if norm == 'bn':
self.norm = nn.BatchNorm2d(norm_dim)
elif norm == 'in':
#self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
self.norm = nn.InstanceNorm2d(norm_dim)
elif norm == 'ln':
self.norm = LayerNorm(norm_dim)
elif norm == 'adain':
self.norm = AdaptiveInstanceNorm2d(norm_dim)
elif norm == 'none' or norm == 'sn':
self.norm = None
else:
assert 0, "Unsupported normalization: {}".format(norm)
# initialize activation
if activation == 'relu':
self.activation = nn.ReLU(inplace=True)
elif activation == 'lrelu':
self.activation = nn.LeakyReLU(0.2, inplace=True)
elif activation == 'prelu':
self.activation = nn.PReLU()
elif activation == 'selu':
self.activation = nn.SELU(inplace=True)
elif activation == 'tanh':
self.activation = nn.Tanh()
elif activation == 'none':
self.activation = None
else:
assert 0, "Unsupported activation: {}".format(activation)
# initialize convolution
if norm == 'sn':
self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
else:
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
def forward(self, x):
x = self.conv(self.pad(x))
if self.norm:
x = self.norm(x)
if self.activation:
x = self.activation(x)
return x
class LinearBlock(nn.Module):
def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
super(LinearBlock, self).__init__()
use_bias = True
# initialize fully connected layer
if norm == 'sn':
self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
else:
self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
# initialize normalization
norm_dim = output_dim
if norm == 'bn':
self.norm = nn.BatchNorm1d(norm_dim)
elif norm == 'in':
self.norm = nn.InstanceNorm1d(norm_dim)
elif norm == 'ln':
self.norm = LayerNorm(norm_dim)
elif norm == 'none' or norm == 'sn':
self.norm = None
else:
assert 0, "Unsupported normalization: {}".format(norm)
# initialize activation
if activation == 'relu':
self.activation = nn.ReLU(inplace=True)
elif activation == 'lrelu':
self.activation = nn.LeakyReLU(0.2, inplace=True)
elif activation == 'prelu':
self.activation = nn.PReLU()
elif activation == 'selu':
self.activation = nn.SELU(inplace=True)
elif activation == 'tanh':
self.activation = nn.Tanh()
elif activation == 'none':
self.activation = None
else:
assert 0, "Unsupported activation: {}".format(activation)
def forward(self, x):
out = self.fc(x)
if self.norm:
out = self.norm(out)
if self.activation:
out = self.activation(out)
return out
# Define a resnet block
class ResnetBlock(nn.Module):
def __init__(self, dim, norm_type, padding_type, use_dropout=False):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, norm_type, padding_type, use_dropout)
def build_conv_block(self, dim, norm_type, padding_type, use_dropout):
conv_block = []
conv_block += [ConvBlock(dim ,dim, 3, 1, 1, norm=norm_type, activation='relu', pad_type=padding_type)]
conv_block += [ConvBlock(dim ,dim, 3, 1, 1, norm=norm_type, activation='none', pad_type=padding_type)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
class SFTLayer(nn.Module):
def __init__(self):
super(SFTLayer, self).__init__()
self.SFT_scale_conv1 = nn.Conv2d(64, 64, 1)
self.SFT_scale_conv2 = nn.Conv2d(64, 64, 1)
self.SFT_shift_conv1 = nn.Conv2d(64, 64, 1)
self.SFT_shift_conv2 = nn.Conv2d(64, 64, 1)
def forward(self, x):
scale = self.SFT_scale_conv2(F.leaky_relu(self.SFT_scale_conv1(x[1]), 0.1, inplace=True))
shift = self.SFT_shift_conv2(F.leaky_relu(self.SFT_shift_conv1(x[1]), 0.1, inplace=True))
return x[0] * scale + shift
class ConvBlock_SFT(nn.Module):
def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        super(ConvBlock_SFT, self).__init__()   # fixed: the original named a non-existent ResnetBlock_SFT
self.sft1 = SFTLayer()
self.conv1 = ConvBlock(dim ,dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type)
def forward(self, x):
fea = self.sft1((x[0], x[1]))
fea = F.relu(self.conv1(fea), inplace=True)
return (x[0] + fea, x[1])
class ConvBlock_SFT_last(nn.Module):
def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        super(ConvBlock_SFT_last, self).__init__()   # fixed: the original named a non-existent ResnetBlock_SFT_last
self.sft1 = SFTLayer()
self.conv1 = ConvBlock(dim ,dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type)
def forward(self, x):
fea = self.sft1((x[0], x[1]))
fea = F.relu(self.conv1(fea), inplace=True)
return x[0] + fea
# Definition of normalization layer
class AdaptiveInstanceNorm2d(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(AdaptiveInstanceNorm2d, self).__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
# weight and bias are dynamically assigned
self.weight = None
self.bias = None
# just dummy buffers, not used
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
def forward(self, x):
assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
b, c = x.size(0), x.size(1)
running_mean = self.running_mean.repeat(b)
running_var = self.running_var.repeat(b)
# Apply instance norm
x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
out = F.batch_norm(
x_reshaped, running_mean, running_var, self.weight, self.bias,
True, self.momentum, self.eps)
return out.view(b, c, *x.size()[2:])
def __repr__(self):
return self.__class__.__name__ + '(' + str(self.num_features) + ')'
class LayerNorm(nn.Module):
def __init__(self, num_features, eps=1e-5, affine=True):
super(LayerNorm, self).__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
if self.affine:
self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
self.beta = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
shape = [-1] + [1] * (x.dim() - 1)
# print(x.size())
if x.size(0) == 1:
# These two lines run much faster in pytorch 0.4 than the two lines listed below.
mean = x.view(-1).mean().view(*shape)
std = x.view(-1).std().view(*shape)
else:
mean = x.view(x.size(0), -1).mean(1).view(*shape)
std = x.view(x.size(0), -1).std(1).view(*shape)
x = (x - mean) / (std + self.eps)
if self.affine:
shape = [1, -1] + [1] * (x.dim() - 2)
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
return x
def l2normalize(v, eps=1e-12):
return v / (v.norm() + eps)
class SpectralNorm(nn.Module):
"""
Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
"""
def __init__(self, module, name='weight', power_iterations=1):
super(SpectralNorm, self).__init__()
self.module = module
self.name = name
self.power_iterations = power_iterations
if not self._made_params():
self._make_params()
def _update_u_v(self):
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
height = w.data.shape[0]
for _ in range(self.power_iterations):
v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
sigma = u.dot(w.view(height, -1).mv(v))
setattr(self.module, self.name, w / sigma.expand_as(w))
def _made_params(self):
try:
u = getattr(self.module, self.name + "_u")
v = getattr(self.module, self.name + "_v")
w = getattr(self.module, self.name + "_bar")
return True
except AttributeError:
return False
def _make_params(self):
w = getattr(self.module, self.name)
height = w.data.shape[0]
width = w.view(height, -1).data.shape[1]
u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
u.data = l2normalize(u.data)
v.data = l2normalize(v.data)
w_bar = nn.Parameter(w.data)
del self.module._parameters[self.name]
self.module.register_parameter(self.name + "_u", u)
self.module.register_parameter(self.name + "_v", v)
self.module.register_parameter(self.name + "_bar", w_bar)
def forward(self, *args):
self._update_u_v()
return self.module.forward(*args)

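# --- Illustrative sketch: the multiscale discriminator returns one list of
# (intermediate) outputs per scale, and GANLoss reduces over all of them.
# 22 input channels = label_nc (19) + output_nc (3), as wired in the model file.
import torch
from models.networks import define_D, GANLoss

netD = define_D(input_nc=22, ndf=64, n_layers_D=3, num_D=2, getIntermFeat=True)
criterion = GANLoss(use_lsgan=True, tensor=torch.FloatTensor)
pred = netD(torch.randn(1, 22, 256, 256))   # list of num_D lists of feature maps
print(criterion(pred, True))                # scalar LSGAN loss against the real target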
@@ -0,0 +1,326 @@ models/pix2pixHD_model.py
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
def generate_discrete_label(inputs, label_nc):
pred_batch = []
size = inputs.size()
for input in inputs:
input = input.view(1, label_nc, size[2], size[3])
pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
pred_batch.append(pred)
pred_batch = np.array(pred_batch)
pred_batch = torch.from_numpy(pred_batch)
label_map = []
for p in pred_batch:
        p = p.view(1, 512, 512)   # assumes 512x512 label maps (the repo's default loadSize/fineSize)
label_map.append(p)
label_map = torch.stack(label_map, 0)
size = label_map.size()
oneHot_size = (size[0], label_nc, size[2], size[3])
if torch.cuda.is_available():
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
else:
input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long(), 1.0)
return input_label
class Pix2PixHDModel(BaseModel):
def name(self):
return 'Pix2PixHDModel'
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
        flags = (True, use_gan_feat_loss, use_vgg_loss, True, use_gan_feat_loss, use_vgg_loss, True, True, True)   # one flag per loss below; the original had a stray tenth flag
def loss_filter(g_gan, g_gan_feat, g_vgg, gb_gan, gb_gan_feat, gb_vgg, d_real, d_fake, d_blend):
return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,gb_gan,gb_gan_feat,gb_vgg,d_real,d_fake,d_blend),flags) if f]
return loss_filter
def initialize(self, opt):
BaseModel.initialize(self, opt)
if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
torch.backends.cudnn.benchmark = True
self.isTrain = opt.isTrain
input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
##### define networks
# Generator network
netG_input_nc = input_nc
# Main Generator
self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
# Discriminator network
if self.isTrain:
use_sigmoid = opt.no_lsgan
netD_input_nc = input_nc + opt.output_nc
netB_input_nc = opt.output_nc * 2
self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
self.netB = networks.define_B(netB_input_nc, opt.output_nc, 32, 3, 3, opt.norm, gpu_ids=self.gpu_ids)
if self.opt.verbose:
print('---------- Networks initialized -------------')
# load networks
if not self.isTrain or opt.continue_train or opt.load_pretrain:
pretrained_path = '' if not self.isTrain else opt.load_pretrain
print (pretrained_path)
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
if self.isTrain:
self.load_network(self.netB, 'B', opt.which_epoch, pretrained_path)
self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
# set loss functions and optimizers
if self.isTrain:
if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
self.fake_pool = ImagePool(opt.pool_size)
self.old_lr = opt.lr
# define loss functions
self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
self.criterionFeat = torch.nn.L1Loss()
if not opt.no_vgg_loss:
self.criterionVGG = networks.VGGLoss(self.gpu_ids)
# Names so we can breakout loss
self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','GB_GAN','GB_GAN_Feat','GB_VGG','D_real','D_fake','D_blend')
# initialize optimizers
# optimizer G
if opt.niter_fix_global > 0:
import sys
if sys.version_info >= (3,0):
finetune_list = set()
else:
from sets import Set
finetune_list = Set()
params_dict = dict(self.netG.named_parameters())
params = []
for key, value in params_dict.items():
if key.startswith('model' + str(opt.n_local_enhancers)):
params += [value]
finetune_list.add(key.split('.')[0])
print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
print('The layers that are finetuned are ', sorted(finetune_list))
else:
params = list(self.netG.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
# optimizer D
params = list(self.netD.parameters())
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
# optimizer G + B
params = list(self.netG.parameters()) + list(self.netB.parameters())
self.optimizer_GB = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
def encode_input(self, inter_label_map_1, label_map, inter_label_map_2, real_image, label_map_ref, real_image_ref, infer=False):
if self.opt.label_nc == 0:
if torch.cuda.is_available():
input_label = label_map.data.cuda()
inter_label_1 = inter_label_map_1.data.cuda()
inter_label_2 = inter_label_map_2.data.cuda()
input_label_ref = label_map_ref.data.cuda()
else:
input_label = label_map.data
inter_label_1 = inter_label_map_1.data
inter_label_2 = inter_label_map_2.data
input_label_ref = label_map_ref.data
else:
# create one-hot vector for label map
size = label_map.size()
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
if torch.cuda.is_available():
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
inter_label_1 = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
inter_label_1 = inter_label_1.scatter_(1, inter_label_map_1.data.long().cuda(), 1.0)
inter_label_2 = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
inter_label_2 = inter_label_2.scatter_(1, inter_label_map_2.data.long().cuda(), 1.0)
input_label_ref = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label_ref = input_label_ref.scatter_(1, label_map_ref.data.long().cuda(), 1.0)
else:
input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long(), 1.0)
inter_label_1 = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
inter_label_1 = inter_label_1.scatter_(1, inter_label_map_1.data.long(), 1.0)
inter_label_2 = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
inter_label_2 = inter_label_2.scatter_(1, inter_label_map_2.data.long(), 1.0)
input_label_ref = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label_ref = input_label_ref.scatter_(1, label_map_ref.data.long(), 1.0)
if self.opt.data_type == 16:
input_label = input_label.half()
inter_label_1 = inter_label_1.half()
inter_label_2 = inter_label_2.half()
input_label_ref = input_label_ref.half()
input_label = Variable(input_label, volatile=infer)
inter_label_1 = Variable(inter_label_1, volatile=infer)
inter_label_2 = Variable(inter_label_2, volatile=infer)
input_label_ref = Variable(input_label_ref, volatile=infer)
if torch.cuda.is_available():
real_image = Variable(real_image.data.cuda())
real_image_ref = Variable(real_image_ref.data.cuda())
else:
real_image = Variable(real_image.data)
real_image_ref = Variable(real_image_ref.data)
return inter_label_1, input_label, inter_label_2, real_image, input_label_ref, real_image_ref
def encode_input_test(self, label_map, label_map_ref, real_image_ref, infer=False):
if self.opt.label_nc == 0:
if torch.cuda.is_available():
input_label = label_map.data.cuda()
input_label_ref = label_map_ref.data.cuda()
else:
input_label = label_map.data
input_label_ref = label_map_ref.data
else:
# create one-hot vector for label map
size = label_map.size()
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
if torch.cuda.is_available():
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
input_label_ref = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label_ref = input_label_ref.scatter_(1, label_map_ref.data.long().cuda(), 1.0)
real_image_ref = Variable(real_image_ref.data.cuda())
else:
input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long(), 1.0)
input_label_ref = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label_ref = input_label_ref.scatter_(1, label_map_ref.data.long(), 1.0)
real_image_ref = Variable(real_image_ref.data)
if self.opt.data_type == 16:
input_label = input_label.half()
input_label_ref = input_label_ref.half()
input_label = Variable(input_label, volatile=infer)
input_label_ref = Variable(input_label_ref, volatile=infer)
return input_label, input_label_ref, real_image_ref
def discriminate(self, input_label, test_image, use_pool=False):
input_concat = torch.cat((input_label, test_image.detach()), dim=1)
if use_pool:
fake_query = self.fake_pool.query(input_concat)
return self.netD.forward(fake_query)
else:
return self.netD.forward(input_concat)
def forward(self, inter_label_1, label, inter_label_2, image, label_ref, image_ref, infer=False):
# Encode Inputs
inter_label_1, input_label, inter_label_2, real_image, input_label_ref, real_image_ref = self.encode_input(inter_label_1, label, inter_label_2, image, label_ref, image_ref)
fake_inter_1 = self.netG.forward(inter_label_1, input_label, real_image)
fake_image = self.netG.forward(input_label, input_label, real_image)
fake_inter_2 = self.netG.forward(inter_label_2, input_label, real_image)
blend_image, alpha = self.netB.forward(fake_inter_1, fake_inter_2)
# Fake Detection and Loss
pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
pred_blend_pool = self.discriminate(input_label, blend_image, use_pool=True)
loss_D_blend = self.criterionGAN(pred_blend_pool, False)
# Real Detection and Loss
pred_real = self.discriminate(input_label, real_image)
loss_D_real = self.criterionGAN(pred_real, True)
# GAN loss (Fake Passability Loss)
pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
loss_G_GAN = self.criterionGAN(pred_fake, True)
pred_blend = self.netD.forward(torch.cat((input_label, blend_image), dim=1))
loss_GB_GAN = self.criterionGAN(pred_blend, True)
# GAN feature matching loss
loss_G_GAN_Feat = 0
loss_GB_GAN_Feat = 0
if not self.opt.no_ganFeat_loss:
feat_weights = 4.0 / (self.opt.n_layers_D + 1)
D_weights = 1.0 / self.opt.num_D
for i in range(self.opt.num_D):
for j in range(len(pred_fake[i])-1):
loss_G_GAN_Feat += D_weights * feat_weights * \
self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
loss_GB_GAN_Feat += D_weights * feat_weights * \
self.criterionFeat(pred_blend[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
# VGG feature matching loss
loss_G_VGG = 0
loss_GB_VGG = 0
if not self.opt.no_vgg_loss:
loss_G_VGG += self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
loss_GB_VGG += self.criterionVGG(blend_image, real_image) * self.opt.lambda_feat
        # Only return the generated images when `infer` is set, to save memory
return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_GB_GAN, loss_GB_GAN_Feat, loss_GB_VGG, loss_D_real, loss_D_fake, loss_D_blend ), None if not infer else fake_inter_1, fake_image, fake_inter_2, blend_image, alpha, real_image, inter_label_1, input_label, inter_label_2 ]
def inference(self, label, label_ref, image_ref):
# Encode Inputs
image_ref = Variable(image_ref)
input_label, input_label_ref, real_image_ref = self.encode_input_test(Variable(label), Variable(label_ref), image_ref, infer=True)
if torch.__version__.startswith('0.4'):
with torch.no_grad():
fake_image = self.netG.forward(input_label, input_label_ref, real_image_ref)
else:
fake_image = self.netG.forward(input_label, input_label_ref, real_image_ref)
return fake_image
def save(self, which_epoch):
self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
self.save_network(self.netB, 'B', which_epoch, self.gpu_ids)
def update_fixed_params(self):
# after fixing the global generator for a number of iterations, also start finetuning it
params = list(self.netG.parameters())
        if getattr(self, 'gen_features', False):   # netE/gen_features only exist in the encoder variant of pix2pixHD
            params += list(self.netE.parameters())
self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
if self.opt.verbose:
print('------------ Now also finetuning global generator -----------')
def update_learning_rate(self):
lrd = self.opt.lr / self.opt.niter_decay
lr = self.old_lr - lrd
for param_group in self.optimizer_D.param_groups:
param_group['lr'] = lr
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = lr
if self.opt.verbose:
print('update learning rate: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
class InferenceModel(Pix2PixHDModel):
def forward(self, inp):
label = inp
return self.inference(label)

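# --- Illustrative sketch of the test-time path. Label maps are (1, 1, H, W)
# index tensors in [0, label_nc); sizes follow the 512x512 defaults. Module
# paths follow the pix2pixHD layout this repo is based on.
import torch
from options.test_options import TestOptions
from models.models import create_model

opt = TestOptions().parse(save=False)
model = create_model(opt)                 # InferenceModel
label = torch.zeros(1, 1, 512, 512)       # segmentation to synthesize from
label_ref = torch.zeros(1, 1, 512, 512)   # reference segmentation
image_ref = torch.randn(1, 3, 512, 512)   # reference image providing style
fake = model.inference(label, label_ref, image_ref)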
@@ -0,0 +1,89 @@ options/base_options.py
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import argparse
import os
from util import util
import torch
class BaseOptions():
def __init__(self):
self.parser = argparse.ArgumentParser()
self.initialized = False
def initialize(self):
# experiment specifics
self.parser.add_argument('--name', type=str, default='label2face_512p', help='name of the experiment. It decides where to store samples and models')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use')
self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')
# input/output sizes
self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
self.parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size')
self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
self.parser.add_argument('--label_nc', type=int, default=19, help='# of input label channels')
self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
# for setting inputs
self.parser.add_argument('--dataroot', type=str, default='../Data_preprocessing/')
self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
# for displays
self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
# for generator
self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
self.parser.add_argument('--n_blocks_global', type=int, default=4, help='number of residual blocks in the global generator network')
self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
        self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs during which only the outermost local enhancer is trained')
self.initialized = True
def parse(self, save=True):
if not self.initialized:
self.initialize()
self.opt = self.parser.parse_args()
self.opt.isTrain = self.isTrain # train or test
str_ids = self.opt.gpu_ids.split(',')
self.opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
self.opt.gpu_ids.append(id)
# set gpu ids
# if len(self.opt.gpu_ids) > 0:
# torch.cuda.set_device(self.opt.gpu_ids[0])
args = vars(self.opt)
print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
# save to the disk
expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
util.mkdirs(expr_dir)
if save and not self.opt.continue_train:
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write('------------ Options -------------\n')
for k, v in sorted(args.items()):
opt_file.write('%s: %s\n' % (str(k), str(v)))
opt_file.write('-------------- End ----------------\n')
return self.opt

@ -0,0 +1,19 @@
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
from .base_options import BaseOptions
class TestOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--how_many', type=int, default=1000, help='how many test images to run')
self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
self.isTrain = False
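# A minimal usage sketch (not part of the original file; assumes BaseOptions
# defines defaults for --name and --gpu_ids, as in pix2pixHD). parse() is
# inherited from BaseOptions; save=False skips writing opt.txt, which would
# otherwise touch --continue_train, an option only TrainOptions defines.
if __name__ == '__main__':
    opt = TestOptions().parse(save=False)
    assert opt.isTrain is False  # set above in initialize()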

@ -0,0 +1,36 @@
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
# for displays
self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
# for training
self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
self.parser.add_argument('--load_pretrain', type=str, default='./checkpoints/label2face_512p', help='load the pretrained model from the specified location')
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
self.parser.add_argument('--lr', type=float, default=0.00005, help='initial learning rate for adam')
# for discriminators
self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
        self.parser.add_argument('--no_lsgan', action='store_true', help='if specified, do *not* use least squares GAN; fall back to the vanilla GAN objective')
self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
self.isTrain = True

@ -0,0 +1,31 @@
import random
import torch
from torch.autograd import Variable
class ImagePool():
def __init__(self, pool_size):
self.pool_size = pool_size
if self.pool_size > 0:
self.num_imgs = 0
self.images = []
def query(self, images):
if self.pool_size == 0:
return images
return_images = []
for image in images.data:
image = torch.unsqueeze(image, 0)
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1
self.images.append(image)
return_images.append(image)
else:
p = random.uniform(0, 1)
if p > 0.5:
random_id = random.randint(0, self.pool_size-1)
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
else:
return_images.append(image)
return_images = Variable(torch.cat(return_images, 0))
return return_images
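# Usage sketch (appended for illustration; names are placeholders). With a
# non-zero pool size, query() returns a batch in which roughly half of the
# images may be swapped for previously generated ones, which helps stabilize
# the discriminator.
if __name__ == '__main__':
    pool = ImagePool(pool_size=50)
    fake = Variable(torch.randn(4, 3, 64, 64))  # a batch of generated images
    mixed = pool.query(fake)
    print(mixed.size())  # torch.Size([4, 3, 64, 64])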

@ -0,0 +1,107 @@
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
if isinstance(image_tensor, list):
image_numpy = []
for i in range(len(image_tensor)):
image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
return image_numpy
image_numpy = image_tensor.cpu().float().numpy()
image_numpy = (image_numpy + 1) / 2.0
image_numpy = np.clip(image_numpy, 0, 1)
if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
image_numpy = image_numpy[:,:,0]
return image_numpy
# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8):
if n_label == 0:
return tensor2im(label_tensor, imtype)
label_tensor = label_tensor.cpu().float()
if label_tensor.size()[0] > 1:
label_tensor = label_tensor.max(0, keepdim=True)[1]
label_tensor = Colorize(n_label)(label_tensor)
#label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
label_numpy = label_tensor.numpy()
label_numpy = label_numpy / 255.0
return label_numpy
def save_image(image_numpy, image_path):
image_pil = Image.fromarray(image_numpy)
image_pil.save(image_path)
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
###############################################################################
# Code from
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
# Modified so it complies with the Cityscapes label map colors
###############################################################################
def uint82bin(n, count=8):
"""returns the binary of integer n, count refers to amount of bits"""
return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
def labelcolormap(N):
if N == 35: # cityscape
cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
(128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
(180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
(107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
dtype=np.uint8)
else:
cmap = np.zeros((N, 3), dtype=np.uint8)
for i in range(N):
r, g, b = 0, 0, 0
id = i
for j in range(7):
str_id = uint82bin(id)
r = r ^ (np.uint8(str_id[-1]) << (7-j))
g = g ^ (np.uint8(str_id[-2]) << (7-j))
b = b ^ (np.uint8(str_id[-3]) << (7-j))
id = id >> 3
cmap[i, 0] = r
cmap[i, 1] = g
cmap[i, 2] = b
return cmap
class Colorize(object):
def __init__(self, n=35):
self.cmap = labelcolormap(n)
self.cmap = torch.from_numpy(self.cmap[:n])
def __call__(self, gray_image):
size = gray_image.size()
color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
for label in range(0, len(self.cmap)):
mask = (label == gray_image[0]).cpu()
color_image[0][mask] = self.cmap[label][0]
color_image[1][mask] = self.cmap[label][1]
color_image[2][mask] = self.cmap[label][2]
return color_image
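# Illustrative sketch (not in the original file): colorize a random 19-class
# label map, matching the --label_nc default used elsewhere in this repo.
if __name__ == '__main__':
    n_label = 19
    gray = torch.randint(0, n_label, (1, 8, 8)).long()  # (1, H, W) label map
    color = Colorize(n_label)(gray)                     # (3, H, W) ByteTensor
    print(color.size())  # torch.Size([3, 8, 8])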

@ -0,0 +1,131 @@
# DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better
Code for this paper [DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better](https://arxiv.org/abs/1908.03826)
Orest Kupyn, Tetiana Martyniuk, Junru Wu, Zhangyang Wang
In ICCV 2019
## Overview
We present a new end-to-end generative adversarial network (GAN) for single image motion deblurring, named
DeblurGAN-v2, which considerably boosts state-of-the-art deblurring efficiency, quality, and flexibility. DeblurGAN-v2
is based on a relativistic conditional GAN with a double-scale discriminator. For the first time, we introduce the
Feature Pyramid Network into deblurring, as a core building block in the generator of DeblurGAN-v2. It can flexibly
work with a wide range of backbones to strike a balance between performance and efficiency. Plugging in
sophisticated backbones (e.g., Inception-ResNet-v2) leads to solid state-of-the-art deblurring. Meanwhile,
with lightweight backbones (e.g., MobileNet and its variants), DeblurGAN-v2 runs 10-100 times faster than
its nearest competitors while maintaining close to state-of-the-art results, making real-time
video deblurring a realistic option. We demonstrate that DeblurGAN-v2 achieves very competitive performance on several popular
benchmarks, in terms of deblurring quality (both objective and subjective) as well as efficiency. We also
show that the architecture is effective for general image restoration tasks.
<!---We also study the effect of DeblurGAN-v2 on the task of general image restoration - enhancement of images degraded
jointly by noise, blur, compression, etc. The picture below shows the visual quality superiority of DeblurGAN-v2 with
Inception-ResNet-v2 backbone over DeblurGAN. It is drawn from our new synthesized Restore Dataset
(refer to Datasets subsection below).-->
![](./doc_images/kohler_visual.png)
![](./doc_images/restore_visual.png)
![](./doc_images/gopro_table.png)
![](./doc_images/lai_table.png)
<!---![](./doc_images/dvd_table.png)-->
<!---![](./doc_images/kohler_table.png)-->
## DeblurGAN-v2 Architecture
![](./doc_images/pipeline.jpg)
<!---Our architecture consists of an FPN backbone from which we take five final feature maps of different scales as the
output. Those features are later up-sampled to the same 1/4 input size and concatenated into one tensor which contains
the semantic information on different levels. We additionally add two upsampling and convolutional layers at the end of
the network to restore the original image size and reduce artifacts. We also introduce a direct skip connection from
the input to the output, so that the learning focuses on the residue. The input images are normalized to \[-1, 1\].
We also use a **tanh** activation layer to keep the output in the same range.-->
<!---The new FPN-embeded architecture is agnostic to the choice of feature extractor backbones. With this plug-and-play
property, we are entitled with the flexibility to navigate through the spectrum of accuracy and efficiency.
By default, we choose ImageNet-pretrained backbones to convey more semantic-related features.-->
## Datasets
The datasets for training can be downloaded via the links below:
- [DVD](https://drive.google.com/file/d/1bpj9pCcZR_6-AHb5aNnev5lILQbH8GMZ/view)
- [GoPro](https://drive.google.com/file/d/1KStHiZn5TNm2mo3OLZLjnRvd0vVFCI0W/view)
- [NFS](https://drive.google.com/file/d/1Ut7qbQOrsTZCUJA_mJLptRMipD8sJzjy/view)
## Training
#### Command
```python train.py```
The training script loads its configuration from config/config.yaml.
#### Tensorboard visualization
![](./doc_images/tensorboard2.png)
## Testing
To test on a single image,
```python predict.py IMAGE_NAME.jpg```
By default, Predictor loads the pretrained weights 'best_fpn.h5'; this can be changed via the 'weights_path' argument in the code. The default assumes the fpn_inception backbone. To test a checkpoint trained with a different backbone, also set ['model']['g_name'] in config/config.yaml accordingly.
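For programmatic use, a rough sketch follows; `Predictor` is defined in predict.py (not shown in this diff), so the constructor and call signature below are assumptions based on the description above, not a verified API:
```
# Hypothetical usage -- 'weights_path' and the __call__ signature are assumed.
import cv2
from predict import Predictor

predictor = Predictor(weights_path='best_fpn.h5')
img = cv2.imread('IMAGE_NAME.jpg')
deblurred = predictor(img, None)
cv2.imwrite('deblurred.jpg', deblurred)
```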
## Pre-trained models
<table align="center">
<tr>
<th>Dataset</th>
<th>G Model</th>
<th>D Model</th>
<th>Loss Type</th>
<th>PSNR/ SSIM</th>
<th>Link</th>
</tr>
<tr>
<td rowspan="3">GoPro Test Dataset</td>
<td>InceptionResNet-v2</td>
<td>double_gan</td>
<td>ragan-ls</td>
<td>29.55/ 0.934</td>
<td><a href="">https://drive.google.com/open?id=1UXcsRVW-6KF23_TNzxw-xC0SzaMfXOaR</a></td>
</tr>
<tr>
<td>MobileNet</td>
<td>double_gan</td>
<td>ragan-ls</td>
<td>28.17/ 0.925</td>
<td><a href="">https://drive.google.com/open?id=1JhnT4BBeKBBSLqTo6UsJ13HeBXevarrU</a></td>
</tr>
<tr>
<td>MobileNet-DSC</td>
<td>double_gan</td>
<td>ragan-ls</td>
<td>28.03/ 0.922</td>
<td><a href=""></a></td>
</tr>
</table>
## Parent Repository
The code was taken from <a href="https://github.com/KupynOrest/RestoreGAN">https://github.com/KupynOrest/RestoreGAN</a>. This repository contains flexible pipelines for different image restoration tasks.
## Citation
If you use this code for your research, please cite our paper.
```
@InProceedings{Kupyn_2019_ICCV,
  author = {Orest Kupyn and Tetiana Martyniuk and Junru Wu and Zhangyang Wang},
  title = {DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better},
  booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
  month = {Oct},
  year = {2019}
}
```

@ -0,0 +1,99 @@
import torch
import copy
class GANFactory:
factories = {}
def __init__(self):
pass
def add_factory(gan_id, model_factory):
        GANFactory.factories[gan_id] = model_factory
add_factory = staticmethod(add_factory)
# A Template Method:
def create_model(gan_id, net_d=None, criterion=None):
if gan_id not in GANFactory.factories:
GANFactory.factories[gan_id] = \
eval(gan_id + '.Factory()')
return GANFactory.factories[gan_id].create(net_d, criterion)
create_model = staticmethod(create_model)
class GANTrainer(object):
def __init__(self, net_d, criterion):
self.net_d = net_d
self.criterion = criterion
def loss_d(self, pred, gt):
pass
def loss_g(self, pred, gt):
pass
def get_params(self):
pass
class NoGAN(GANTrainer):
def __init__(self, net_d, criterion):
GANTrainer.__init__(self, net_d, criterion)
def loss_d(self, pred, gt):
return [0]
def loss_g(self, pred, gt):
return 0
def get_params(self):
return [torch.nn.Parameter(torch.Tensor(1))]
class Factory:
@staticmethod
def create(net_d, criterion): return NoGAN(net_d, criterion)
class SingleGAN(GANTrainer):
def __init__(self, net_d, criterion):
GANTrainer.__init__(self, net_d, criterion)
self.net_d = self.net_d.cuda()
def loss_d(self, pred, gt):
return self.criterion(self.net_d, pred, gt)
def loss_g(self, pred, gt):
return self.criterion.get_g_loss(self.net_d, pred, gt)
def get_params(self):
return self.net_d.parameters()
class Factory:
@staticmethod
def create(net_d, criterion): return SingleGAN(net_d, criterion)
class DoubleGAN(GANTrainer):
def __init__(self, net_d, criterion):
GANTrainer.__init__(self, net_d, criterion)
self.patch_d = net_d['patch'].cuda()
self.full_d = net_d['full'].cuda()
self.full_criterion = copy.deepcopy(criterion)
def loss_d(self, pred, gt):
return (self.criterion(self.patch_d, pred, gt) + self.full_criterion(self.full_d, pred, gt)) / 2
def loss_g(self, pred, gt):
return (self.criterion.get_g_loss(self.patch_d, pred, gt) + self.full_criterion.get_g_loss(self.full_d, pred,
gt)) / 2
def get_params(self):
return list(self.patch_d.parameters()) + list(self.full_d.parameters())
class Factory:
@staticmethod
def create(net_d, criterion): return DoubleGAN(net_d, criterion)
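# Usage sketch (appended for illustration). create_model resolves the id by
# eval'ing '<gan_id>.Factory()' in this module, so the string must name one of
# the trainer classes above. NoGAN is used here because it needs no CUDA.
if __name__ == '__main__':
    trainer = GANFactory.create_model('NoGAN')
    print(trainer.loss_g(None, None))  # 0 -- adversarial term disabled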

@ -0,0 +1,93 @@
from typing import List
import albumentations as albu
def get_transforms(size: int, scope: str = 'geometric', crop: str = 'random'):
augs = {'strong': albu.Compose([albu.HorizontalFlip(),
albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4),
albu.ElasticTransform(),
albu.OpticalDistortion(),
albu.OneOf([
albu.CLAHE(clip_limit=2),
albu.IAASharpen(),
albu.IAAEmboss(),
albu.RandomBrightnessContrast(),
albu.RandomGamma()
], p=0.5),
albu.OneOf([
albu.RGBShift(),
albu.HueSaturationValue(),
], p=0.5),
]),
'weak': albu.Compose([albu.HorizontalFlip(),
]),
'geometric': albu.OneOf([albu.HorizontalFlip(always_apply=True),
albu.ShiftScaleRotate(always_apply=True),
albu.Transpose(always_apply=True),
albu.OpticalDistortion(always_apply=True),
albu.ElasticTransform(always_apply=True),
])
}
aug_fn = augs[scope]
crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
'center': albu.CenterCrop(size, size, always_apply=True)}[crop]
pad = albu.PadIfNeeded(size, size)
pipeline = albu.Compose([aug_fn, crop_fn, pad], additional_targets={'target': 'image'})
def process(a, b):
r = pipeline(image=a, target=b)
return r['image'], r['target']
return process
def get_normalize():
normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
normalize = albu.Compose([normalize], additional_targets={'target': 'image'})
def process(a, b):
r = normalize(image=a, target=b)
return r['image'], r['target']
return process
def _resolve_aug_fn(name):
d = {
'cutout': albu.Cutout,
'rgb_shift': albu.RGBShift,
'hsv_shift': albu.HueSaturationValue,
'motion_blur': albu.MotionBlur,
'median_blur': albu.MedianBlur,
'snow': albu.RandomSnow,
'shadow': albu.RandomShadow,
'fog': albu.RandomFog,
'brightness_contrast': albu.RandomBrightnessContrast,
'gamma': albu.RandomGamma,
'sun_flare': albu.RandomSunFlare,
'sharpen': albu.IAASharpen,
'jpeg': albu.JpegCompression,
'gray': albu.ToGray,
# ToDo: pixelize
# ToDo: partial gray
}
return d[name]
def get_corrupt_function(config):
augs = []
for aug_params in config:
name = aug_params.pop('name')
cls = _resolve_aug_fn(name)
prob = aug_params.pop('prob') if 'prob' in aug_params else .5
augs.append(cls(p=prob, **aug_params))
augs = albu.OneOf(augs)
def process(x):
return augs(image=x)['image']
return process
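# Usage sketch (appended for illustration; assumes an albumentations version
# contemporary with this code, where IAASharpen/JpegCompression still exist).
if __name__ == '__main__':
    import numpy as np
    transform = get_transforms(size=256, scope='geometric', crop='random')
    normalize = get_normalize()
    corrupt = get_corrupt_function([{'name': 'motion_blur'},
                                    {'name': 'jpeg', 'quality_lower': 70, 'quality_upper': 90}])
    sharp = np.random.randint(0, 255, (300, 300, 3), dtype=np.uint8)
    a, b = transform(sharp, sharp.copy())  # same geometric augs on both sides
    a = corrupt(a)                         # degrade only the input side
    a, b = normalize(a, b)
    print(a.shape, b.shape)                # (256, 256, 3) float arrays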

Binary file not shown.

@ -0,0 +1,68 @@
---
project: deblur_gan
experiment_desc: fpn
train:
files_a: &FILES_A /datasets/my_dataset/**/*.jpg
files_b: *FILES_A
size: &SIZE 256
crop: random
preload: &PRELOAD false
preload_size: &PRELOAD_SIZE 0
bounds: [0, .9]
scope: geometric
corrupt: &CORRUPT
- name: cutout
prob: 0.5
num_holes: 3
max_h_size: 25
max_w_size: 25
- name: jpeg
quality_lower: 70
quality_upper: 90
- name: motion_blur
- name: median_blur
- name: gamma
- name: rgb_shift
- name: hsv_shift
- name: sharpen
val:
files_a: *FILES_A
files_b: *FILES_A
size: *SIZE
scope: geometric
crop: center
preload: *PRELOAD
preload_size: *PRELOAD_SIZE
bounds: [.9, 1]
corrupt: *CORRUPT
phase: train
warmup_num: 3
model:
g_name: fpn_inception
blocks: 9
d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scale
d_layers: 3
content_loss: perceptual
adv_lambda: 0.001
disc_loss: wgan-gp
learn_residual: True
norm_layer: instance
dropout: True
num_epochs: 200
train_batches_per_epoch: 1000
val_batches_per_epoch: 100
batch_size: 1
image_size: [256, 256]
optimizer:
name: adam
lr: 0.0001
scheduler:
name: linear
start_epoch: 50
min_lr: 0.0000001

@ -0,0 +1,142 @@
import os
from copy import deepcopy
from functools import partial
from glob import glob
from hashlib import sha1
from typing import Callable, Iterable, Optional, Tuple
import cv2
import numpy as np
from glog import logger
from joblib import Parallel, cpu_count, delayed
from skimage.io import imread
from torch.utils.data import Dataset
from tqdm import tqdm
import aug
def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
data = list(data)
buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)
lower_bound, upper_bound = [x * n_buckets for x in bounds]
msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
if salt:
msg += f'; salt is {salt}'
if verbose:
logger.info(msg)
return np.array([sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound])
def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
path_a, path_b = x
names = ''.join(map(os.path.basename, (path_a, path_b)))
return sha1(f'{names}_{salt}'.encode()).hexdigest()
def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
hashes = map(partial(hash_fn, salt=salt), data)
return np.array([int(x, 16) % n_buckets for x in hashes])
def _read_img(x: str):
img = cv2.imread(x)
if img is None:
logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
img = imread(x)
return img
class PairedDataset(Dataset):
def __init__(self,
files_a: Tuple[str],
files_b: Tuple[str],
transform_fn: Callable,
normalize_fn: Callable,
corrupt_fn: Optional[Callable] = None,
preload: bool = True,
preload_size: Optional[int] = 0,
verbose=True):
assert len(files_a) == len(files_b)
self.preload = preload
self.data_a = files_a
self.data_b = files_b
self.verbose = verbose
self.corrupt_fn = corrupt_fn
self.transform_fn = transform_fn
self.normalize_fn = normalize_fn
logger.info(f'Dataset has been created with {len(self.data_a)} samples')
if preload:
preload_fn = partial(self._bulk_preload, preload_size=preload_size)
if files_a == files_b:
self.data_a = self.data_b = preload_fn(self.data_a)
else:
self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
self.preload = True
def _bulk_preload(self, data: Iterable[str], preload_size: int):
jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)
@staticmethod
def _preload(x: str, preload_size: int):
img = _read_img(x)
if preload_size:
h, w, *_ = img.shape
h_scale = preload_size / h
w_scale = preload_size / w
scale = max(h_scale, w_scale)
img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
return img
def _preprocess(self, img, res):
def transpose(x):
return np.transpose(x, (2, 0, 1))
return map(transpose, self.normalize_fn(img, res))
def __len__(self):
return len(self.data_a)
def __getitem__(self, idx):
a, b = self.data_a[idx], self.data_b[idx]
if not self.preload:
a, b = map(_read_img, (a, b))
a, b = self.transform_fn(a, b)
if self.corrupt_fn is not None:
a = self.corrupt_fn(a)
a, b = self._preprocess(a, b)
return {'a': a, 'b': b}
@staticmethod
def from_config(config):
config = deepcopy(config)
files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
normalize_fn = aug.get_normalize()
corrupt_fn = aug.get_corrupt_function(config['corrupt'])
hash_fn = hash_from_paths
# ToDo: add more hash functions
verbose = config.get('verbose', True)
data = subsample(data=zip(files_a, files_b),
bounds=config.get('bounds', (0, 1)),
hash_fn=hash_fn,
verbose=verbose)
files_a, files_b = map(list, zip(*data))
return PairedDataset(files_a=files_a,
files_b=files_b,
preload=config['preload'],
preload_size=config['preload_size'],
corrupt_fn=corrupt_fn,
normalize_fn=normalize_fn,
transform_fn=transform_fn,
verbose=verbose)
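# Usage sketch (appended for illustration): build the dataset the same way the
# training script presumably does, from the 'train' section of config/config.yaml.
# Requires the glob paths in the config to point at real images.
if __name__ == '__main__':
    import yaml
    from torch.utils.data import DataLoader
    with open('config/config.yaml') as f:
        config = yaml.safe_load(f)
    dataset = PairedDataset.from_config(config['train'])
    loader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)
    batch = next(iter(loader))
    print(batch['a'].shape, batch['b'].shape)  # (N, 3, 256, 256)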

@ -0,0 +1,56 @@
import logging
from collections import defaultdict
import numpy as np
from tensorboardX import SummaryWriter
WINDOW_SIZE = 100
class MetricCounter:
def __init__(self, exp_name):
self.writer = SummaryWriter(exp_name)
logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG)
self.metrics = defaultdict(list)
self.images = defaultdict(list)
self.best_metric = 0
def add_image(self, x: np.ndarray, tag: str):
self.images[tag].append(x)
def clear(self):
self.metrics = defaultdict(list)
self.images = defaultdict(list)
def add_losses(self, l_G, l_content, l_D=0):
for name, value in zip(('G_loss', 'G_loss_content', 'G_loss_adv', 'D_loss'),
(l_G, l_content, l_G - l_content, l_D)):
self.metrics[name].append(value)
def add_metrics(self, psnr, ssim):
for name, value in zip(('PSNR', 'SSIM'),
(psnr, ssim)):
self.metrics[name].append(value)
def loss_message(self):
metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k in ('G_loss', 'PSNR', 'SSIM'))
return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics))
def write_to_tensorboard(self, epoch_num, validation=False):
scalar_prefix = 'Validation' if validation else 'Train'
for tag in ('G_loss', 'D_loss', 'G_loss_adv', 'G_loss_content', 'SSIM', 'PSNR'):
self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num)
for tag in self.images:
imgs = self.images[tag]
if imgs:
imgs = np.array(imgs)
self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC',
global_step=epoch_num)
self.images[tag] = []
def update_best_model(self):
cur_metric = np.mean(self.metrics['PSNR'])
if self.best_metric < cur_metric:
self.best_metric = cur_metric
return True
return False
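# Usage sketch (appended for illustration; 'fpn_experiment' is a placeholder
# experiment name).
if __name__ == '__main__':
    mc = MetricCounter('fpn_experiment')
    mc.add_losses(l_G=1.5, l_content=1.0, l_D=0.7)  # G_loss_adv is derived as l_G - l_content
    mc.add_metrics(psnr=28.0, ssim=0.92)
    print(mc.loss_message())        # e.g. G_loss=1.5000; PSNR=28.0000; SSIM=0.9200
    mc.write_to_tensorboard(epoch_num=0)
    print(mc.update_best_model())   # True on the first call, since best_metric starts at 0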

@ -0,0 +1,135 @@
import torch
import torch.nn as nn
from torchvision.models import resnet50, densenet121, densenet201
class FPNSegHead(nn.Module):
def __init__(self, num_in, num_mid, num_out):
super().__init__()
self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = nn.functional.relu(self.block0(x), inplace=True)
x = nn.functional.relu(self.block1(x), inplace=True)
return x
class FPNDense(nn.Module):
def __init__(self, output_ch=3, num_filters=128, num_filters_fpn=256, pretrained=True):
super().__init__()
# Feature Pyramid Network (FPN) with four feature maps of resolutions
# 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained)
# The segmentation heads on top of the FPN
self.head1 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
self.head2 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
self.head3 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
self.head4 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
self.smooth = nn.Sequential(
nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
nn.BatchNorm2d(num_filters),
nn.ReLU(),
)
self.smooth2 = nn.Sequential(
nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
nn.BatchNorm2d(num_filters // 2),
nn.ReLU(),
)
self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
def forward(self, x):
map0, map1, map2, map3, map4 = self.fpn(x)
map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
smoothed = self.smooth2(smoothed + map0)
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
        final = self.final(smoothed)
        # residual formulation, as in the other FPN generators in this repo
        res = torch.tanh(final) + x
        return torch.clamp(res, min=-1, max=1)
class FPN(nn.Module):
def __init__(self, num_filters=256, pretrained=True):
"""Creates an `FPN` instance for feature extraction.
Args:
num_filters: the number of filters in each output pyramid level
pretrained: use ImageNet pre-trained backbone feature extractor
"""
super().__init__()
self.features = densenet121(pretrained=pretrained).features
self.enc0 = nn.Sequential(self.features.conv0,
self.features.norm0,
self.features.relu0)
self.pool0 = self.features.pool0
self.enc1 = self.features.denseblock1 # 256
self.enc2 = self.features.denseblock2 # 512
self.enc3 = self.features.denseblock3 # 1024
self.enc4 = self.features.denseblock4 # 2048
self.norm = self.features.norm5 # 2048
self.tr1 = self.features.transition1 # 256
self.tr2 = self.features.transition2 # 512
self.tr3 = self.features.transition3 # 1024
self.lateral4 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False)
self.lateral3 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False)
self.lateral2 = nn.Conv2d(512, num_filters, kernel_size=1, bias=False)
self.lateral1 = nn.Conv2d(256, num_filters, kernel_size=1, bias=False)
self.lateral0 = nn.Conv2d(64, num_filters // 2, kernel_size=1, bias=False)
def forward(self, x):
        # Bottom-up pathway, from the DenseNet-121 backbone
        enc0 = self.enc0(x)
        pooled = self.pool0(enc0)
        enc1 = self.enc1(pooled) # 256
        tr1 = self.tr1(enc1)
        enc2 = self.enc2(tr1) # 512
        tr2 = self.tr2(enc2)
        enc3 = self.enc3(tr2) # 1024
        tr3 = self.tr3(enc3)
        enc4 = self.enc4(tr3) # 1024
enc4 = self.norm(enc4)
# Lateral connections
lateral4 = self.lateral4(enc4)
lateral3 = self.lateral3(enc3)
lateral2 = self.lateral2(enc2)
lateral1 = self.lateral1(enc1)
lateral0 = self.lateral0(enc0)
# Top-down pathway
map4 = lateral4
map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")
map2 = lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest")
map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")
return lateral0, map1, map2, map3, map4

@ -0,0 +1,167 @@
import torch
import torch.nn as nn
from pretrainedmodels import inceptionresnetv2  # required: FPN below instantiates this backbone
# from torchsummary import summary
import torch.nn.functional as F
class FPNHead(nn.Module):
def __init__(self, num_in, num_mid, num_out):
super(FPNHead,self).__init__()
self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = nn.functional.relu(self.block0(x), inplace=True)
x = nn.functional.relu(self.block1(x), inplace=True)
return x
class ConvBlock(nn.Module):
def __init__(self, num_in, num_out, norm_layer):
super().__init__()
self.block = nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
norm_layer(num_out),
nn.ReLU(inplace=True))
def forward(self, x):
x = self.block(x)
return x
class FPNInception(nn.Module):
def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256):
super(FPNInception,self).__init__()
# Feature Pyramid Network (FPN) with four feature maps of resolutions
# 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)
# The segmentation heads on top of the FPN
self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.smooth = nn.Sequential(
nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(),
)
self.smooth2 = nn.Sequential(
nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
norm_layer(num_filters // 2),
nn.ReLU(),
)
self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
def unfreeze(self):
self.fpn.unfreeze()
def forward(self, x):
map0, map1, map2, map3, map4 = self.fpn(x)
map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
smoothed = self.smooth2(smoothed + map0)
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
final = self.final(smoothed)
res = torch.tanh(final) + x
        return torch.clamp(res, min=-1, max=1)
class FPN(nn.Module):
def __init__(self, norm_layer, num_filters=256):
"""Creates an `FPN` instance for feature extraction.
Args:
num_filters: the number of filters in each output pyramid level
pretrained: use ImageNet pre-trained backbone feature extractor
"""
super(FPN,self).__init__()
self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
# self.inception = torch.load('inceptionresnetv2-520b38e4.pth')
self.enc0 = self.inception.conv2d_1a
self.enc1 = nn.Sequential(
self.inception.conv2d_2a,
self.inception.conv2d_2b,
self.inception.maxpool_3a,
) # 64
self.enc2 = nn.Sequential(
self.inception.conv2d_3b,
self.inception.conv2d_4a,
self.inception.maxpool_5a,
) # 192
self.enc3 = nn.Sequential(
self.inception.mixed_5b,
self.inception.repeat,
self.inception.mixed_6a,
) # 1088
self.enc4 = nn.Sequential(
self.inception.repeat_1,
self.inception.mixed_7a,
) #2080
self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(inplace=True))
self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(inplace=True))
self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(inplace=True))
self.pad = nn.ReflectionPad2d(1)
self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)
for param in self.inception.parameters():
param.requires_grad = False
def unfreeze(self):
for param in self.inception.parameters():
param.requires_grad = True
def forward(self, x):
        # Bottom-up pathway, from the Inception-ResNet-v2 backbone
        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0) # 64
        enc2 = self.enc2(enc1) # 192
        enc3 = self.enc3(enc2) # 1088
        enc4 = self.enc4(enc3) # 2080
# Lateral connections
lateral4 = self.pad(self.lateral4(enc4))
lateral3 = self.pad(self.lateral3(enc3))
lateral2 = self.lateral2(enc2)
lateral1 = self.pad(self.lateral1(enc1))
lateral0 = self.lateral0(enc0)
# Top-down pathway
        pad = (1, 2, 1, 2)   # reflect-pad H and W: 1 before, 2 after (reconciles odd sizes from valid convs)
        pad1 = (0, 1, 0, 1)  # reflect-pad H and W: 0 before, 1 after
map4 = lateral4
map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest"))
map2 = self.td2(F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest"))
map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest"))
return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4

@ -0,0 +1,160 @@
import torch
import torch.nn as nn
from pretrainedmodels import inceptionresnetv2
from torchsummary import summary
import torch.nn.functional as F
class FPNHead(nn.Module):
def __init__(self, num_in, num_mid, num_out):
super().__init__()
self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = nn.functional.relu(self.block0(x), inplace=True)
x = nn.functional.relu(self.block1(x), inplace=True)
return x
class ConvBlock(nn.Module):
def __init__(self, num_in, num_out, norm_layer):
super().__init__()
self.block = nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
norm_layer(num_out),
nn.ReLU(inplace=True))
def forward(self, x):
x = self.block(x)
return x
class FPNInceptionSimple(nn.Module):
def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256):
super().__init__()
# Feature Pyramid Network (FPN) with four feature maps of resolutions
# 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)
# The segmentation heads on top of the FPN
self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.smooth = nn.Sequential(
nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(),
)
self.smooth2 = nn.Sequential(
nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
norm_layer(num_filters // 2),
nn.ReLU(),
)
self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
def unfreeze(self):
self.fpn.unfreeze()
def forward(self, x):
map0, map1, map2, map3, map4 = self.fpn(x)
map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
smoothed = self.smooth2(smoothed + map0)
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
final = self.final(smoothed)
res = torch.tanh(final) + x
        return torch.clamp(res, min=-1, max=1)
class FPN(nn.Module):
def __init__(self, norm_layer, num_filters=256):
"""Creates an `FPN` instance for feature extraction.
Args:
num_filters: the number of filters in each output pyramid level
pretrained: use ImageNet pre-trained backbone feature extractor
"""
super().__init__()
self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
self.enc0 = self.inception.conv2d_1a
self.enc1 = nn.Sequential(
self.inception.conv2d_2a,
self.inception.conv2d_2b,
self.inception.maxpool_3a,
) # 64
self.enc2 = nn.Sequential(
self.inception.conv2d_3b,
self.inception.conv2d_4a,
self.inception.maxpool_5a,
) # 192
self.enc3 = nn.Sequential(
self.inception.mixed_5b,
self.inception.repeat,
self.inception.mixed_6a,
) # 1088
self.enc4 = nn.Sequential(
self.inception.repeat_1,
self.inception.mixed_7a,
) #2080
self.pad = nn.ReflectionPad2d(1)
self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)
for param in self.inception.parameters():
param.requires_grad = False
def unfreeze(self):
for param in self.inception.parameters():
param.requires_grad = True
def forward(self, x):
        # Bottom-up pathway, from the Inception-ResNet-v2 backbone
        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0) # 64
        enc2 = self.enc2(enc1) # 192
        enc3 = self.enc3(enc2) # 1088
        enc4 = self.enc4(enc3) # 2080
# Lateral connections
lateral4 = self.pad(self.lateral4(enc4))
lateral3 = self.pad(self.lateral3(enc3))
lateral2 = self.lateral2(enc2)
lateral1 = self.pad(self.lateral1(enc1))
lateral0 = self.lateral0(enc0)
# Top-down pathway
        pad = (1, 2, 1, 2)   # reflect-pad H and W: 1 before, 2 after (reconciles odd sizes from valid convs)
        pad1 = (0, 1, 0, 1)  # reflect-pad H and W: 0 before, 1 after
map4 = lateral4
map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")
map2 = F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest")
map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")
return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4

@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from mobilenet_v2 import MobileNetV2
class FPNHead(nn.Module):
def __init__(self, num_in, num_mid, num_out):
super().__init__()
self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = nn.functional.relu(self.block0(x), inplace=True)
x = nn.functional.relu(self.block1(x), inplace=True)
return x
class FPNMobileNet(nn.Module):
def __init__(self, norm_layer, output_ch=3, num_filters=64, num_filters_fpn=128, pretrained=True):
super().__init__()
# Feature Pyramid Network (FPN) with four feature maps of resolutions
# 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
self.fpn = FPN(num_filters=num_filters_fpn, norm_layer = norm_layer, pretrained=pretrained)
# The segmentation heads on top of the FPN
self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.smooth = nn.Sequential(
nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(),
)
self.smooth2 = nn.Sequential(
nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
norm_layer(num_filters // 2),
nn.ReLU(),
)
self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
def unfreeze(self):
self.fpn.unfreeze()
def forward(self, x):
map0, map1, map2, map3, map4 = self.fpn(x)
map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
smoothed = self.smooth2(smoothed + map0)
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
final = self.final(smoothed)
res = torch.tanh(final) + x
return torch.clamp(res, min=-1, max=1)
class FPN(nn.Module):
def __init__(self, norm_layer, num_filters=128, pretrained=True):
"""Creates an `FPN` instance for feature extraction.
Args:
num_filters: the number of filters in each output pyramid level
pretrained: use ImageNet pre-trained backbone feature extractor
"""
super().__init__()
net = MobileNetV2(n_class=1000)
if pretrained:
#Load weights into the project directory
state_dict = torch.load('mobilenetv2.pth.tar') # add map_location='cpu' if no gpu
net.load_state_dict(state_dict)
self.features = net.features
self.enc0 = nn.Sequential(*self.features[0:2])
self.enc1 = nn.Sequential(*self.features[2:4])
self.enc2 = nn.Sequential(*self.features[4:7])
self.enc3 = nn.Sequential(*self.features[7:11])
self.enc4 = nn.Sequential(*self.features[11:16])
self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(inplace=True))
self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(inplace=True))
self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(inplace=True))
self.lateral4 = nn.Conv2d(160, num_filters, kernel_size=1, bias=False)
self.lateral3 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
self.lateral2 = nn.Conv2d(32, num_filters, kernel_size=1, bias=False)
self.lateral1 = nn.Conv2d(24, num_filters, kernel_size=1, bias=False)
self.lateral0 = nn.Conv2d(16, num_filters // 2, kernel_size=1, bias=False)
for param in self.features.parameters():
param.requires_grad = False
def unfreeze(self):
for param in self.features.parameters():
param.requires_grad = True
def forward(self, x):
        # Bottom-up pathway, from the MobileNetV2 backbone
        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0) # 24
        enc2 = self.enc2(enc1) # 32
        enc3 = self.enc3(enc2) # 64
        enc4 = self.enc4(enc3) # 160
# Lateral connections
lateral4 = self.lateral4(enc4)
lateral3 = self.lateral3(enc3)
lateral2 = self.lateral2(enc2)
lateral1 = self.lateral1(enc1)
lateral0 = self.lateral0(enc0)
# Top-down pathway
map4 = lateral4
map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest"))
map2 = self.td2(lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest"))
map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest"))
return lateral0, map1, map2, map3, map4

@ -0,0 +1,147 @@
import torch
import torch.nn as nn
from models.mobilenet_v2 import MobileNetV2
class FPNHead(nn.Module):
def __init__(self, num_in, num_mid, num_out):
super().__init__()
self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = nn.functional.relu(self.block0(x), inplace=True)
x = nn.functional.relu(self.block1(x), inplace=True)
return x
class FPNMobileNet(nn.Module):
def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=128, pretrained=True):
super().__init__()
# Feature Pyramid Network (FPN) with four feature maps of resolutions
# 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
        self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer, pretrained=pretrained)
# The segmentation heads on top of the FPN
self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)
self.smooth = nn.Sequential(
nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(),
)
self.smooth2 = nn.Sequential(
nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
norm_layer(num_filters // 2),
nn.ReLU(),
)
self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
def unfreeze(self):
self.fpn.unfreeze()
def forward(self, x):
map0, map1, map2, map3, map4 = self.fpn(x)
map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
smoothed = self.smooth2(smoothed + map0)
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
final = self.final(smoothed)
res = torch.tanh(final) + x
return torch.clamp(res, min=-1, max=1)
class FPN(nn.Module):
    def __init__(self, norm_layer, num_filters=128, pretrained=True):
"""Creates an `FPN` instance for feature extraction.
Args:
num_filters: the number of filters in each output pyramid level
pretrained: use ImageNet pre-trained backbone feature extractor
"""
super().__init__()
net = MobileNetV2(n_class=1000)
if pretrained:
#Load weights into the project directory
state_dict = torch.load('mobilenetv2.pth.tar') # add map_location='cpu' if no gpu
net.load_state_dict(state_dict)
self.features = net.features
self.enc0 = nn.Sequential(*self.features[0:2])
self.enc1 = nn.Sequential(*self.features[2:4])
self.enc2 = nn.Sequential(*self.features[4:7])
self.enc3 = nn.Sequential(*self.features[7:11])
self.enc4 = nn.Sequential(*self.features[11:16])
self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(inplace=True))
self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(inplace=True))
self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
norm_layer(num_filters),
nn.ReLU(inplace=True))
self.lateral4 = nn.Conv2d(160, num_filters, kernel_size=1, bias=False)
self.lateral3 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
self.lateral2 = nn.Conv2d(32, num_filters, kernel_size=1, bias=False)
self.lateral1 = nn.Conv2d(24, num_filters, kernel_size=1, bias=False)
self.lateral0 = nn.Conv2d(16, num_filters // 2, kernel_size=1, bias=False)
for param in self.features.parameters():
param.requires_grad = False
def unfreeze(self):
for param in self.features.parameters():
param.requires_grad = True
def forward(self, x):
        # Bottom-up pathway, from the MobileNetV2 backbone
        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0) # 24
        enc2 = self.enc2(enc1) # 32
        enc3 = self.enc3(enc2) # 64
        enc4 = self.enc4(enc3) # 160
# Lateral connections
lateral4 = self.lateral4(enc4)
lateral3 = self.lateral3(enc3)
lateral2 = self.lateral2(enc2)
lateral1 = self.lateral1(enc1)
lateral0 = self.lateral0(enc0)
# Top-down pathway
map4 = lateral4
map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest"))
map2 = self.td2(lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest"))
map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest"))
return lateral0, map1, map2, map3, map4

@ -0,0 +1,300 @@
import torch
import torch.autograd as autograd
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from util.image_pool import ImagePool
###############################################################################
# Functions
###############################################################################
class ContentLoss():
def initialize(self, loss):
self.criterion = loss
def get_loss(self, fakeIm, realIm):
return self.criterion(fakeIm, realIm)
def __call__(self, fakeIm, realIm):
return self.get_loss(fakeIm, realIm)
class PerceptualLoss():
def contentFunc(self):
conv_3_3_layer = 14
cnn = models.vgg19(pretrained=True).features
cnn = cnn.cuda()
model = nn.Sequential()
model = model.cuda()
model = model.eval()
for i, layer in enumerate(list(cnn)):
model.add_module(str(i), layer)
if i == conv_3_3_layer:
break
return model
def initialize(self, loss):
with torch.no_grad():
self.criterion = loss
self.contentFunc = self.contentFunc()
self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def get_loss(self, fakeIm, realIm):
fakeIm = (fakeIm + 1) / 2.0
realIm = (realIm + 1) / 2.0
fakeIm[0, :, :, :] = self.transform(fakeIm[0, :, :, :])
realIm[0, :, :, :] = self.transform(realIm[0, :, :, :])
f_fake = self.contentFunc.forward(fakeIm)
f_real = self.contentFunc.forward(realIm)
f_real_no_grad = f_real.detach()
loss = self.criterion(f_fake, f_real_no_grad)
return 0.006 * torch.mean(loss) + 0.5 * nn.MSELoss()(fakeIm, realIm)
def __call__(self, fakeIm, realIm):
return self.get_loss(fakeIm, realIm)
class GANLoss(nn.Module):
def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_var = None
self.fake_label_var = None
self.Tensor = tensor
if use_l1:
self.loss = nn.L1Loss()
else:
self.loss = nn.BCEWithLogitsLoss()
def get_target_tensor(self, input, target_is_real):
if target_is_real:
create_label = ((self.real_label_var is None) or
(self.real_label_var.numel() != input.numel()))
if create_label:
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
self.real_label_var = Variable(real_tensor, requires_grad=False)
target_tensor = self.real_label_var
else:
create_label = ((self.fake_label_var is None) or
(self.fake_label_var.numel() != input.numel()))
if create_label:
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
target_tensor = self.fake_label_var
return target_tensor.cuda()
def __call__(self, input, target_is_real):
target_tensor = self.get_target_tensor(input, target_is_real)
return self.loss(input, target_tensor)
class DiscLoss(nn.Module):
def name(self):
return 'DiscLoss'
def __init__(self):
super(DiscLoss, self).__init__()
self.criterionGAN = GANLoss(use_l1=False)
self.fake_AB_pool = ImagePool(50)
def get_g_loss(self, net, fakeB, realB):
# First, G(A) should fake the discriminator
pred_fake = net.forward(fakeB)
return self.criterionGAN(pred_fake, 1)
def get_loss(self, net, fakeB, realB):
# Fake
# stop backprop to the generator by detaching fake_B
# Generated Image Disc Output should be close to zero
self.pred_fake = net.forward(fakeB.detach())
self.loss_D_fake = self.criterionGAN(self.pred_fake, 0)
# Real
self.pred_real = net.forward(realB)
self.loss_D_real = self.criterionGAN(self.pred_real, 1)
# Combined loss
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
return self.loss_D
def __call__(self, net, fakeB, realB):
return self.get_loss(net, fakeB, realB)
class RelativisticDiscLoss(nn.Module):
def name(self):
return 'RelativisticDiscLoss'
def __init__(self):
super(RelativisticDiscLoss, self).__init__()
self.criterionGAN = GANLoss(use_l1=False)
self.fake_pool = ImagePool(50) # create image buffer to store previously generated images
self.real_pool = ImagePool(50)
def get_g_loss(self, net, fakeB, realB):
# First, G(A) should fake the discriminator
self.pred_fake = net.forward(fakeB)
# Real
self.pred_real = net.forward(realB)
errG = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 0) +
self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 1)) / 2
return errG
def get_loss(self, net, fakeB, realB):
# Fake
# stop backprop to the generator by detaching fake_B
# Generated Image Disc Output should be close to zero
self.fake_B = fakeB.detach()
self.real_B = realB
self.pred_fake = net.forward(fakeB.detach())
self.fake_pool.add(self.pred_fake)
# Real
self.pred_real = net.forward(realB)
self.real_pool.add(self.pred_real)
# Combined loss
self.loss_D = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 1) +
self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 0)) / 2
return self.loss_D
def __call__(self, net, fakeB, realB):
return self.get_loss(net, fakeB, realB)
class RelativisticDiscLossLS(nn.Module):
def name(self):
return 'RelativisticDiscLossLS'
def __init__(self):
super(RelativisticDiscLossLS, self).__init__()
self.criterionGAN = GANLoss(use_l1=True)
self.fake_pool = ImagePool(50) # create image buffer to store previously generated images
self.real_pool = ImagePool(50)
def get_g_loss(self, net, fakeB, realB):
# First, G(A) should fake the discriminator
self.pred_fake = net.forward(fakeB)
# Real
self.pred_real = net.forward(realB)
errG = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) + 1) ** 2) +
torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) - 1) ** 2)) / 2
return errG
def get_loss(self, net, fakeB, realB):
# Fake
# stop backprop to the generator by detaching fake_B
# Generated Image Disc Output should be close to zero
self.fake_B = fakeB.detach()
self.real_B = realB
self.pred_fake = net.forward(fakeB.detach())
self.fake_pool.add(self.pred_fake)
# Real
self.pred_real = net.forward(realB)
self.real_pool.add(self.pred_real)
# Combined loss
self.loss_D = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) - 1) ** 2) +
torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) + 1) ** 2)) / 2
return self.loss_D
def __call__(self, net, fakeB, realB):
return self.get_loss(net, fakeB, realB)
class DiscLossLS(DiscLoss):
def name(self):
return 'DiscLossLS'
def __init__(self):
super(DiscLossLS, self).__init__()
self.criterionGAN = GANLoss(use_l1=True)
def get_g_loss(self, net, fakeB, realB):
        return DiscLoss.get_g_loss(self, net, fakeB, realB)
def get_loss(self, net, fakeB, realB):
return DiscLoss.get_loss(self, net, fakeB, realB)
class DiscLossWGANGP(DiscLossLS):
def name(self):
return 'DiscLossWGAN-GP'
def __init__(self):
super(DiscLossWGANGP, self).__init__()
self.LAMBDA = 10
def get_g_loss(self, net, fakeB, realB):
# First, G(A) should fake the discriminator
self.D_fake = net.forward(fakeB)
return -self.D_fake.mean()
def calc_gradient_penalty(self, netD, real_data, fake_data):
alpha = torch.rand(1, 1)
alpha = alpha.expand(real_data.size())
alpha = alpha.cuda()
interpolates = alpha * real_data + ((1 - alpha) * fake_data)
interpolates = interpolates.cuda()
interpolates = Variable(interpolates, requires_grad=True)
disc_interpolates = netD.forward(interpolates)
gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
create_graph=True, retain_graph=True, only_inputs=True)[0]
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
return gradient_penalty
def get_loss(self, net, fakeB, realB):
self.D_fake = net.forward(fakeB.detach())
self.D_fake = self.D_fake.mean()
# Real
self.D_real = net.forward(realB)
self.D_real = self.D_real.mean()
# Combined loss
self.loss_D = self.D_fake - self.D_real
gradient_penalty = self.calc_gradient_penalty(net, realB.data, fakeB.data)
return self.loss_D + gradient_penalty
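# WGAN-GP in brief: the penalty above is lambda * E[(||grad_xhat D(xhat)||_2 - 1)^2]
# with xhat = alpha * x_real + (1 - alpha) * x_fake (Gulrajani et al., 2017).
# Minimal usage sketch with hypothetical tensors (CUDA is assumed, exactly as in
# calc_gradient_penalty above):
#
#   loss_fn = DiscLossWGANGP()
#   loss_d = loss_fn.get_loss(netD, fake, real)    # critic loss plus gradient penalty
#   loss_g = loss_fn.get_g_loss(netD, fake, real)  # -E[D(G(z))]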
def get_loss(model):
if model['content_loss'] == 'perceptual':
content_loss = PerceptualLoss()
content_loss.initialize(nn.MSELoss())
elif model['content_loss'] == 'l1':
content_loss = ContentLoss()
content_loss.initialize(nn.L1Loss())
else:
raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])
if model['disc_loss'] == 'wgan-gp':
disc_loss = DiscLossWGANGP()
elif model['disc_loss'] == 'lsgan':
disc_loss = DiscLossLS()
elif model['disc_loss'] == 'gan':
disc_loss = DiscLoss()
elif model['disc_loss'] == 'ragan':
disc_loss = RelativisticDiscLoss()
elif model['disc_loss'] == 'ragan-ls':
disc_loss = RelativisticDiscLossLS()
else:
raise ValueError("GAN Loss [%s] not recognized." % model['disc_loss'])
return content_loss, disc_loss
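# Usage sketch: both losses are built from the 'model' section of the training
# config (key names as consumed above):
#
#   content_loss, disc_loss = get_loss({'content_loss': 'perceptual',
#                                       'disc_loss': 'wgan-gp'})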

@ -0,0 +1,126 @@
import torch.nn as nn
import math
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
)
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.ReLU6(inplace=True)
)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = round(inp * expand_ratio)
self.use_res_connect = self.stride == 1 and inp == oup
if expand_ratio == 1:
self.conv = nn.Sequential(
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
else:
self.conv = nn.Sequential(
# pw
nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# dw
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(inplace=True),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
)
def forward(self, x):
if self.use_res_connect:
return x + self.conv(x)
else:
return self.conv(x)
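# Shape sketch for one inverted residual block (expand -> depthwise -> project):
# with inp=24, oup=24, stride=1, expand_ratio=6 the channel path is
# 24 -> 144 (1x1 pw) -> 144 (3x3 dw, groups=144) -> 24 (1x1 pw-linear),
# and since stride == 1 and inp == oup the input is added back (use_res_connect).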
class MobileNetV2(nn.Module):
def __init__(self, n_class=1000, input_size=224, width_mult=1.):
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
inverted_residual_setting = [
# t: expansion factor, c: output channels, n: number of blocks, s: stride of first block
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# building first layer
assert input_size % 32 == 0
input_channel = int(input_channel * width_mult)
self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
self.features = [conv_bn(3, input_channel, 2)]
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = int(c * width_mult)
for i in range(n):
if i == 0:
self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
else:
self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
input_channel = output_channel
# building last several layers
self.features.append(conv_1x1_bn(input_channel, self.last_channel))
# make it nn.Sequential
self.features = nn.Sequential(*self.features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, n_class),
)
self._initialize_weights()
def forward(self, x):
x = self.features(x)
x = x.mean(3).mean(2)
x = self.classifier(x)
return x
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
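# Minimal smoke test (assumes `import torch`, which this module does not do itself):
#   net = MobileNetV2(n_class=1000, input_size=224, width_mult=1.0)
#   out = net(torch.randn(1, 3, 224, 224))   # -> torch.Size([1, 1000])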

@ -0,0 +1,35 @@
import numpy as np
import torch.nn as nn
from skimage.measure import compare_ssim as SSIM
from util.metrics import PSNR
class DeblurModel(nn.Module):
def __init__(self):
super(DeblurModel, self).__init__()
def get_input(self, data):
img = data['a']
inputs = img
targets = data['b']
inputs, targets = inputs.cuda(), targets.cuda()
return inputs, targets
def tensor2im(self, image_tensor, imtype=np.uint8):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
return image_numpy.astype(imtype)
def get_images_and_metrics(self, inp, output, target):
inp = self.tensor2im(inp)
fake = self.tensor2im(output.data)
real = self.tensor2im(target.data)
psnr = PSNR(fake, real)
ssim = SSIM(fake, real, multichannel=True)
vis_img = np.hstack((inp, fake, real))
return psnr, ssim, vis_img
def get_model(model_config):
return DeblurModel()
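# Usage sketch with hypothetical NCHW tensors scaled to [-1, 1]:
#   model = get_model({})  # the config argument is currently unused
#   psnr, ssim, vis = model.get_images_and_metrics(inp, output, target)
# `vis` is a uint8 H x 3W x 3 strip of input | prediction | ground truth.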

@ -0,0 +1,330 @@
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import numpy as np
from fpn_mobilenet import FPNMobileNet
from fpn_inception import FPNInception
# from fpn_inception_simple import FPNInceptionSimple
from unet_seresnext import UNetSEResNext
from fpn_densenet import FPNDense
###############################################################################
# Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
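# get_norm_layer returns a functools.partial, so the result is used like a class:
#   norm = get_norm_layer('instance')
#   layer = norm(64)  # nn.InstanceNorm2d(64, affine=False, track_running_stats=True)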
##############################################################################
# Classes
##############################################################################
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, use_parallel=True, learn_residual=True, padding_type='reflect'):
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
self.use_parallel = use_parallel
self.learn_residual = learn_residual
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2**n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
output = self.model(input)
if self.learn_residual:
output = input + output
output = torch.clamp(output, min=-1, max=1)
return output
# Define a resnet block
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim),
nn.ReLU(True)]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
class DiscriminatorTail(nn.Module):
def __init__(self, nf_mult, n_layers, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
super(DiscriminatorTail, self).__init__()
self.use_parallel = use_parallel
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4
padw = int(np.ceil((kw-1)/2))
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence = [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
self.model = nn.Sequential(*sequence)
def forward(self, input):
return self.model(input)
class MultiScaleDiscriminator(nn.Module):
def __init__(self, input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
super(MultiScaleDiscriminator, self).__init__()
self.use_parallel = use_parallel
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4
padw = int(np.ceil((kw-1)/2))
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)
]
nf_mult = 1
for n in range(1, 3):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
self.scale_one = nn.Sequential(*sequence)
self.first_tail = DiscriminatorTail(nf_mult=nf_mult, n_layers=3)
nf_mult_prev = 4
nf_mult = 8
self.scale_two = nn.Sequential(
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True))
nf_mult_prev = nf_mult
self.second_tail = DiscriminatorTail(nf_mult=nf_mult, n_layers=4)
self.scale_three = nn.Sequential(
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True))
self.third_tail = DiscriminatorTail(nf_mult=nf_mult, n_layers=5)
def forward(self, input):
x = self.scale_one(input)
x_1 = self.first_tail(x)
x = self.scale_two(x)
x_2 = self.second_tail(x)
x = self.scale_three(x)
x = self.third_tail(x)
return [x_1, x_2, x]
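# The three heads score the image at increasing receptive fields: scale_one's
# features feed first_tail directly, then are downsampled twice more for the
# second and third tails, so the returned list is [fine, medium, coarse] patch maps.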
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_parallel=True):
super(NLayerDiscriminator, self).__init__()
self.use_parallel = use_parallel
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4
padw = int(np.ceil((kw-1)/2))
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)
]
nf_mult = 1
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2**n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2**n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
if use_sigmoid:
sequence += [nn.Sigmoid()]
self.model = nn.Sequential(*sequence)
def forward(self, input):
return self.model(input)
def get_fullD(model_config):
model_d = NLayerDiscriminator(n_layers=5,
norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
use_sigmoid=False)
return model_d
def get_generator(model_config):
generator_name = model_config['g_name']
if generator_name == 'resnet':
model_g = ResnetGenerator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
use_dropout=model_config['dropout'],
n_blocks=model_config['blocks'],
learn_residual=model_config['learn_residual'])
elif generator_name == 'fpn_mobilenet':
model_g = FPNMobileNet(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
elif generator_name == 'fpn_inception':
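# note: instead of constructing FPNInception here, this port loads a full model
# object that was serialized once with torch.save (see the commented lines below)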
# model_g = FPNInception(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
# torch.save(model_g, 'mymodel.pth')
model_g = torch.load('mymodel.pth')
elif generator_name == 'fpn_inception_simple':
model_g = FPNInceptionSimple(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))  # requires the fpn_inception_simple import, currently commented out above
elif generator_name == 'fpn_dense':
model_g = FPNDense()
elif generator_name == 'unet_seresnext':
model_g = UNetSEResNext(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
pretrained=model_config['pretrained'])
else:
raise ValueError("Generator Network [%s] not recognized." % generator_name)
return nn.DataParallel(model_g)
def get_generator_new(weights_path):
model_g = torch.load(weights_path+'mymodel.pth')
return nn.DataParallel(model_g)
def get_discriminator(model_config):
discriminator_name = model_config['d_name']
if discriminator_name == 'no_gan':
model_d = None
elif discriminator_name == 'patch_gan':
model_d = NLayerDiscriminator(n_layers=model_config['d_layers'],
norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
use_sigmoid=False)
model_d = nn.DataParallel(model_d)
elif discriminator_name == 'double_gan':
patch_gan = NLayerDiscriminator(n_layers=model_config['d_layers'],
norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
use_sigmoid=False)
patch_gan = nn.DataParallel(patch_gan)
full_gan = get_fullD(model_config)
full_gan = nn.DataParallel(full_gan)
model_d = {'patch': patch_gan,
'full': full_gan}
elif discriminator_name == 'multi_scale':
model_d = MultiScaleDiscriminator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
model_d = nn.DataParallel(model_d)
else:
raise ValueError("Discriminator Network [%s] not recognized." % discriminator_name)
return model_d
def get_nets(model_config):
return get_generator(model_config), get_discriminator(model_config)
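# Usage sketch: the same 'model' config section drives both factories, e.g.
#   netG, netD = get_nets({'g_name': 'fpn_mobilenet', 'd_name': 'double_gan',
#                          'norm_layer': 'instance', 'd_layers': 3})
# (keys shown are the ones consumed above; 'double_gan' yields a dict holding a
# patch discriminator and a full-image discriminator, each in nn.DataParallel).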

@ -0,0 +1,430 @@
from __future__ import print_function, division, absolute_import
from collections import OrderedDict
import math
import torch.nn as nn
from torch.utils import model_zoo
__all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
'se_resnext50_32x4d', 'se_resnext101_32x4d']
pretrained_settings = {
'senet154': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnet50': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnet101': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnet152': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnext50_32x4d': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
'se_resnext101_32x4d': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
'input_space': 'RGB',
'input_size': [3, 224, 224],
'input_range': [0, 1],
'mean': [0.485, 0.456, 0.406],
'std': [0.229, 0.224, 0.225],
'num_classes': 1000
}
},
}
class SEModule(nn.Module):
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
padding=0)
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
padding=0)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
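# Squeeze-and-Excitation in brief: global-average-pool to 1x1 ("squeeze"), a
# bottleneck 1x1-conv pair (channels -> channels/reduction -> channels) ending in
# a sigmoid ("excitation"), then rescale the input channel-wise; e.g. with
# channels=256 and reduction=16 the attention path is 256 -> 16 -> 256.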
class Bottleneck(nn.Module):
"""
Base class for bottlenecks that implements `forward()` method.
"""
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out = self.se_module(out) + residual
out = self.relu(out)
return out
class SEBottleneck(Bottleneck):
"""
Bottleneck for SENet154.
"""
expansion = 4
def __init__(self, inplanes, planes, groups, reduction, stride=1,
downsample=None):
super(SEBottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1)
self.bn1 = nn.InstanceNorm2d(planes * 2, affine=False)
self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
stride=stride, padding=1, groups=groups)
self.bn2 = nn.InstanceNorm2d(planes * 4, affine=False)
self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1)
self.bn3 = nn.InstanceNorm2d(planes * 4, affine=False)
self.relu = nn.ReLU(inplace=True)
self.se_module = SEModule(planes * 4, reduction=reduction)
self.downsample = downsample
self.stride = stride
class SEResNetBottleneck(Bottleneck):
"""
ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
implementation and uses `stride=stride` in `conv1` and not in `conv2`
(the latter is used in the torchvision implementation of ResNet).
"""
expansion = 4
def __init__(self, inplanes, planes, groups, reduction, stride=1,
downsample=None):
super(SEResNetBottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1,
stride=stride)
self.bn1 = nn.InstanceNorm2d(planes, affine=False)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
groups=groups)
self.bn2 = nn.InstanceNorm2d(planes, affine=False)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1)
self.bn3 = nn.InstanceNorm2d(planes * 4, affine=False)
self.relu = nn.ReLU(inplace=True)
self.se_module = SEModule(planes * 4, reduction=reduction)
self.downsample = downsample
self.stride = stride
class SEResNeXtBottleneck(Bottleneck):
"""
ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
"""
expansion = 4
def __init__(self, inplanes, planes, groups, reduction, stride=1,
downsample=None, base_width=4):
super(SEResNeXtBottleneck, self).__init__()
width = math.floor(planes * (base_width / 64)) * groups
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1,
stride=1)
self.bn1 = nn.InstanceNorm2d(width, affine=False)
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
padding=1, groups=groups)
self.bn2 = nn.InstanceNorm2d(width, affine=False)
self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1)
self.bn3 = nn.InstanceNorm2d(planes * 4, affine=False)
self.relu = nn.ReLU(inplace=True)
self.se_module = SEModule(planes * 4, reduction=reduction)
self.downsample = downsample
self.stride = stride
class SENet(nn.Module):
def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
inplanes=128, input_3x3=True, downsample_kernel_size=3,
downsample_padding=1, num_classes=1000):
"""
Parameters
----------
block (nn.Module): Bottleneck class.
- For SENet154: SEBottleneck
- For SE-ResNet models: SEResNetBottleneck
- For SE-ResNeXt models: SEResNeXtBottleneck
layers (list of ints): Number of residual blocks for 4 layers of the
network (layer1...layer4).
groups (int): Number of groups for the 3x3 convolution in each
bottleneck block.
- For SENet154: 64
- For SE-ResNet models: 1
- For SE-ResNeXt models: 32
reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
- For all models: 16
dropout_p (float or None): Drop probability for the Dropout layer.
If `None` the Dropout layer is not used.
- For SENet154: 0.2
- For SE-ResNet models: None
- For SE-ResNeXt models: None
inplanes (int): Number of input channels for layer1.
- For SENet154: 128
- For SE-ResNet models: 64
- For SE-ResNeXt models: 64
input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
a single 7x7 convolution in layer0.
- For SENet154: True
- For SE-ResNet models: False
- For SE-ResNeXt models: False
downsample_kernel_size (int): Kernel size for downsampling convolutions
in layer2, layer3 and layer4.
- For SENet154: 3
- For SE-ResNet models: 1
- For SE-ResNeXt models: 1
downsample_padding (int): Padding for downsampling convolutions in
layer2, layer3 and layer4.
- For SENet154: 1
- For SE-ResNet models: 0
- For SE-ResNeXt models: 0
num_classes (int): Number of outputs in `last_linear` layer.
- For all models: 1000
"""
super(SENet, self).__init__()
self.inplanes = inplanes
if input_3x3:
layer0_modules = [
('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1)),
('bn1', nn.InstanceNorm2d(64, affine=False)),
('relu1', nn.ReLU(inplace=True)),
('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1)),
('bn2', nn.InstanceNorm2d(64, affine=False)),
('relu2', nn.ReLU(inplace=True)),
('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1)),
('bn3', nn.InstanceNorm2d(inplanes, affine=False)),
('relu3', nn.ReLU(inplace=True)),
]
else:
layer0_modules = [
('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
padding=3)),
('bn1', nn.InstanceNorm2d(inplanes, affine=False)),
('relu1', nn.ReLU(inplace=True)),
]
# To preserve compatibility with Caffe weights `ceil_mode=True`
# is used instead of `padding=1`.
layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
ceil_mode=True)))
self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
self.layer1 = self._make_layer(
block,
planes=64,
blocks=layers[0],
groups=groups,
reduction=reduction,
downsample_kernel_size=1,
downsample_padding=0
)
self.layer2 = self._make_layer(
block,
planes=128,
blocks=layers[1],
stride=2,
groups=groups,
reduction=reduction,
downsample_kernel_size=downsample_kernel_size,
downsample_padding=downsample_padding
)
self.layer3 = self._make_layer(
block,
planes=256,
blocks=layers[2],
stride=2,
groups=groups,
reduction=reduction,
downsample_kernel_size=downsample_kernel_size,
downsample_padding=downsample_padding
)
self.layer4 = self._make_layer(
block,
planes=512,
blocks=layers[3],
stride=2,
groups=groups,
reduction=reduction,
downsample_kernel_size=downsample_kernel_size,
downsample_padding=downsample_padding
)
self.avg_pool = nn.AvgPool2d(7, stride=1)
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
downsample_kernel_size=1, downsample_padding=0):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=downsample_kernel_size, stride=stride,
padding=downsample_padding),
nn.InstanceNorm2d(planes * block.expansion, affine=False),
)
layers = []
layers.append(block(self.inplanes, planes, groups, reduction, stride,
downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, groups, reduction))
return nn.Sequential(*layers)
def features(self, x):
x = self.layer0(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def logits(self, x):
x = self.avg_pool(x)
if self.dropout is not None:
x = self.dropout(x)
x = x.view(x.size(0), -1)
x = self.last_linear(x)
return x
def forward(self, x):
x = self.features(x)
x = self.logits(x)
return x
def initialize_pretrained_model(model, num_classes, settings):
assert num_classes == settings['num_classes'], \
'num_classes should be {}, but is {}'.format(
settings['num_classes'], num_classes)
model.load_state_dict(model_zoo.load_url(settings['url']))
model.input_space = settings['input_space']
model.input_size = settings['input_size']
model.input_range = settings['input_range']
model.mean = settings['mean']
model.std = settings['std']
def senet154(num_classes=1000, pretrained='imagenet'):
model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
dropout_p=0.2, num_classes=num_classes)
if pretrained is not None:
settings = pretrained_settings['senet154'][pretrained]
initialize_pretrained_model(model, num_classes, settings)
return model
def se_resnet50(num_classes=1000, pretrained='imagenet'):
model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained is not None:
settings = pretrained_settings['se_resnet50'][pretrained]
initialize_pretrained_model(model, num_classes, settings)
return model
def se_resnet101(num_classes=1000, pretrained='imagenet'):
model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained is not None:
settings = pretrained_settings['se_resnet101'][pretrained]
initialize_pretrained_model(model, num_classes, settings)
return model
def se_resnet152(num_classes=1000, pretrained='imagenet'):
model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained is not None:
settings = pretrained_settings['se_resnet152'][pretrained]
initialize_pretrained_model(model, num_classes, settings)
return model
def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'):
model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
# note: unlike the other constructors in this file, `pretrained` is accepted
# here but no checkpoint is loaded
return model
def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
if pretrained is not None:
settings = pretrained_settings['se_resnext101_32x4d'][pretrained]
initialize_pretrained_model(model, num_classes, settings)
return model

@ -0,0 +1,153 @@
import torch
from torch import nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
from torch.nn import Sequential
from collections import OrderedDict
import torchvision
from torch.nn import functional as F
from senet import se_resnext50_32x4d
def conv3x3(in_, out):
return nn.Conv2d(in_, out, 3, padding=1)
class ConvRelu(nn.Module):
def __init__(self, in_, out):
super(ConvRelu, self).__init__()
self.conv = conv3x3(in_, out)
self.activation = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.activation(x)
return x
class UNetSEResNext(nn.Module):
def __init__(self, num_classes=3, num_filters=32,
pretrained=True, is_deconv=True):
super().__init__()
self.num_classes = num_classes
pretrain = 'imagenet' if pretrained is True else None
self.encoder = se_resnext50_32x4d(num_classes=1000, pretrained=pretrain)
bottom_channel_nr = 2048
self.conv1 = self.encoder.layer0
#self.se_e1 = SCSEBlock(64)
self.conv2 = self.encoder.layer1
#self.se_e2 = SCSEBlock(64 * 4)
self.conv3 = self.encoder.layer2
#self.se_e3 = SCSEBlock(128 * 4)
self.conv4 = self.encoder.layer3
#self.se_e4 = SCSEBlock(256 * 4)
self.conv5 = self.encoder.layer4
#self.se_e5 = SCSEBlock(512 * 4)
self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)
self.dec5 = DecoderBlockV(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 2, is_deconv)
#self.se_d5 = SCSEBlock(num_filters * 2)
self.dec4 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 2, num_filters * 8, num_filters * 2, is_deconv)
#self.se_d4 = SCSEBlock(num_filters * 2)
self.dec3 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 2, num_filters * 4, num_filters * 2, is_deconv)
#self.se_d3 = SCSEBlock(num_filters * 2)
self.dec2 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2, num_filters * 2, is_deconv)
#self.se_d2 = SCSEBlock(num_filters * 2)
self.dec1 = DecoderBlockV(num_filters * 2, num_filters, num_filters * 2, is_deconv)
#self.se_d1 = SCSEBlock(num_filters * 2)
self.dec0 = ConvRelu(num_filters * 10, num_filters * 2)
self.final = nn.Conv2d(num_filters * 2, num_classes, kernel_size=1)
def forward(self, x):
conv1 = self.conv1(x)
#conv1 = self.se_e1(conv1)
conv2 = self.conv2(conv1)
#conv2 = self.se_e2(conv2)
conv3 = self.conv3(conv2)
#conv3 = self.se_e3(conv3)
conv4 = self.conv4(conv3)
#conv4 = self.se_e4(conv4)
conv5 = self.conv5(conv4)
#conv5 = self.se_e5(conv5)
center = self.center(conv5)
dec5 = self.dec5(torch.cat([center, conv5], 1))
#dec5 = self.se_d5(dec5)
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
#dec4 = self.se_d4(dec4)
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
#dec3 = self.se_d3(dec3)
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
#dec2 = self.se_d2(dec2)
dec1 = self.dec1(dec2)
#dec1 = self.se_d1(dec1)
f = torch.cat((
dec1,
F.upsample(dec2, scale_factor=2, mode='bilinear', align_corners=False),
F.upsample(dec3, scale_factor=4, mode='bilinear', align_corners=False),
F.upsample(dec4, scale_factor=8, mode='bilinear', align_corners=False),
F.upsample(dec5, scale_factor=16, mode='bilinear', align_corners=False),
), 1)
dec0 = self.dec0(f)
return self.final(dec0)
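# The concatenation above forms a hypercolumn: dec2..dec5 are bilinearly
# upsampled back to dec1's resolution (factors 2, 4, 8, 16) and stacked with
# dec1, giving 5 * (num_filters * 2) = num_filters * 10 channels, which is
# exactly what dec0 expects.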
class DecoderBlockV(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
super(DecoderBlockV, self).__init__()
self.in_channels = in_channels
if is_deconv:
self.block = nn.Sequential(
ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
padding=1),
nn.InstanceNorm2d(out_channels, affine=False),
nn.ReLU(inplace=True)
)
else:
self.block = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
ConvRelu(in_channels, middle_channels),
ConvRelu(middle_channels, out_channels),
)
def forward(self, x):
return self.block(x)
class DecoderCenter(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
super(DecoderCenter, self).__init__()
self.in_channels = in_channels
if is_deconv:
"""
Paramaters for Deconvolution were chosen to avoid artifacts, following
link https://distill.pub/2016/deconv-checkerboard/
"""
self.block = nn.Sequential(
ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
padding=1),
nn.InstanceNorm2d(out_channels, affine=False),
nn.ReLU(inplace=True)
)
else:
self.block = nn.Sequential(
ConvRelu(in_channels, middle_channels),
ConvRelu(middle_channels, out_channels)
)
def forward(self, x):
return self.block(x)

@ -0,0 +1,108 @@
import os
from glob import glob
# from typing import Optional
import cv2
import numpy as np
import torch
import yaml
from fire import Fire
from tqdm import tqdm
from aug import get_normalize
from models.networks import get_generator
class Predictor:
def __init__(self, weights_path, model_name=''):
with open('config/config.yaml') as cfg:
config = yaml.safe_load(cfg)
model = get_generator(model_name or config['model'])
model.load_state_dict(torch.load(weights_path, map_location=lambda storage, loc: storage)['model'])
if torch.cuda.is_available():
self.model = model.cuda()
else:
self.model = model
self.model.train(True)
# GAN inference runs in train mode so the norm layers use the current batch
# statistics; this is intentional, not a bug
self.normalize_fn = get_normalize()
@staticmethod
def _array_to_batch(x):
x = np.transpose(x, (2, 0, 1))
x = np.expand_dims(x, 0)
return torch.from_numpy(x)
def _preprocess(self, x, mask):
x, _ = self.normalize_fn(x, x)
if mask is None:
mask = np.ones_like(x, dtype=np.float32)
else:
mask = np.round(mask.astype('float32') / 255)
h, w, _ = x.shape
block_size = 32
min_height = (h // block_size + 1) * block_size
min_width = (w // block_size + 1) * block_size
pad_params = {'mode': 'constant',
'constant_values': 0,
'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
}
x = np.pad(x, **pad_params)
mask = np.pad(mask, **pad_params)
return map(self._array_to_batch, (x, mask)), h, w
@staticmethod
def _postprocess(x):
x, = x  # unpack the single image from the batch (batch size is 1 here)
x = x.detach().cpu().float().numpy()
x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
return x.astype('uint8')
def __call__(self, img, mask, ignore_mask=True):
(img, mask), h, w = self._preprocess(img, mask)
with torch.no_grad():
if torch.cuda.is_available():
inputs = [img.cuda()]
else:
inputs = [img]
if not ignore_mask:
inputs += [mask]
pred = self.model(*inputs)
return self._postprocess(pred)[:h, :w, :]
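# _preprocess pads H and W up to the next multiple of 32 so every downsampling
# stage divides evenly, and __call__ crops the prediction back to the original
# h x w. Usage sketch with a hypothetical RGB uint8 array `img`:
#   predictor = Predictor(weights_path='best_fpn.h5')
#   restored = predictor(img, mask=None)  # uint8 RGB, same size as img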
def sorted_glob(pattern):
return sorted(glob(pattern))
def main(img_pattern,
         mask_pattern=None,
         weights_path='best_fpn.h5',
         out_dir='submit/',
         side_by_side=False):
imgs = sorted_glob(img_pattern)
masks = sorted_glob(mask_pattern) if mask_pattern is not None else [None for _ in imgs]
pairs = zip(imgs, masks)
names = sorted([os.path.basename(x) for x in glob(img_pattern)])
predictor = Predictor(weights_path=weights_path)
os.makedirs(out_dir, exist_ok=True)
for name, pair in tqdm(zip(names, pairs), total=len(names)):
f_img, f_mask = pair
img = cv2.imread(f_img)
mask = cv2.imread(f_mask) if f_mask is not None else None
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
pred = predictor(img, mask)
if side_by_side:
pred = np.hstack((img, pred))
pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(out_dir, name),
pred)
if __name__ == '__main__':
Fire(main)

@ -0,0 +1,67 @@
from models.networks import get_generator_new
from aug import get_normalize
import torch
import numpy as np
config = {
    'project': 'deblur_gan',
    'warmup_num': 3,
    'optimizer': {'lr': 0.0001, 'name': 'adam'},
    'val': {
        'preload': False,
        'bounds': [0.9, 1],
        'crop': 'center',
        'files_b': '/datasets/my_dataset/**/*.jpg',
        'files_a': '/datasets/my_dataset/**/*.jpg',
        'scope': 'geometric',
        'corrupt': [
            {'num_holes': 3, 'max_w_size': 25, 'max_h_size': 25, 'name': 'cutout', 'prob': 0.5},
            {'quality_lower': 70, 'name': 'jpeg', 'quality_upper': 90},
            {'name': 'motion_blur'},
            {'name': 'median_blur'},
            {'name': 'gamma'},
            {'name': 'rgb_shift'},
            {'name': 'hsv_shift'},
            {'name': 'sharpen'},
        ],
        'preload_size': 0,
        'size': 256,
    },
    'val_batches_per_epoch': 100,
    'num_epochs': 200,
    'batch_size': 1,
    'experiment_desc': 'fpn',
    'train_batches_per_epoch': 1000,
    'train': {
        'preload': False,
        'bounds': [0, 0.9],
        'crop': 'random',
        'files_b': '/datasets/my_dataset/**/*.jpg',
        'files_a': '/datasets/my_dataset/**/*.jpg',
        'preload_size': 0,
        'corrupt': [
            {'num_holes': 3, 'max_w_size': 25, 'max_h_size': 25, 'name': 'cutout', 'prob': 0.5},
            {'quality_lower': 70, 'name': 'jpeg', 'quality_upper': 90},
            {'name': 'motion_blur'},
            {'name': 'median_blur'},
            {'name': 'gamma'},
            {'name': 'rgb_shift'},
            {'name': 'hsv_shift'},
            {'name': 'sharpen'},
        ],
        'scope': 'geometric',
        'size': 256,
    },
    'scheduler': {'min_lr': 1e-07, 'name': 'linear', 'start_epoch': 50},
    'image_size': [256, 256],
    'phase': 'train',
    'model': {
        'd_name': 'double_gan',
        'disc_loss': 'wgan-gp',
        'blocks': 9,
        'content_loss': 'perceptual',
        'adv_lambda': 0.001,
        'dropout': True,
        'g_name': 'fpn_inception',
        'd_layers': 3,
        'learn_residual': True,
        'norm_layer': 'instance',
    },
}
class Predictor:
def __init__(self, weights_path, model_name=''):
# model = get_generator(model_name or config['model'])
model = get_generator_new(weights_path[0:-11])  # strip the 11-char checkpoint filename (e.g. 'best_fpn.h5') to get its directory
model.load_state_dict(torch.load(weights_path, map_location=lambda storage, loc: storage)['model'])
if torch.cuda.is_available():
self.model = model.cuda()
else:
self.model = model
self.model.train(True)
# GAN inference runs in train mode so the norm layers use the current batch
# statistics; this is intentional, not a bug
self.normalize_fn = get_normalize()
@staticmethod
def _array_to_batch(x):
x = np.transpose(x, (2, 0, 1))
x = np.expand_dims(x, 0)
return torch.from_numpy(x)
def _preprocess(self, x, mask):
x, _ = self.normalize_fn(x, x)
if mask is None:
mask = np.ones_like(x, dtype=np.float32)
else:
mask = np.round(mask.astype('float32') / 255)
h, w, _ = x.shape
block_size = 32
min_height = (h // block_size + 1) * block_size
min_width = (w // block_size + 1) * block_size
pad_params = {'mode': 'constant',
'constant_values': 0,
'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
}
x = np.pad(x, **pad_params)
mask = np.pad(mask, **pad_params)
return map(self._array_to_batch, (x, mask)), h, w
@staticmethod
def _postprocess(x):
x, = x  # unpack the single image from the batch (batch size is 1 here)
x = x.detach().cpu().float().numpy()
x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
return x.astype('uint8')
def __call__(self, img, mask, ignore_mask=True):
(img, mask), h, w = self._preprocess(img, mask)
with torch.no_grad():
if torch.cuda.is_available():
inputs = [img.cuda()]
else:
inputs = [img]
if not ignore_mask:
inputs += [mask]
pred = self.model(*inputs)
return self._postprocess(pred)[:h, :w, :]

@ -0,0 +1,13 @@
torch==1.0.1
torchvision
pretrainedmodels
numpy
opencv-python-headless
joblib
albumentations
scikit-image
tqdm
glog
tensorboardx
fire
# this file is not ready yet

@ -0,0 +1,59 @@
import math
from torch.optim import lr_scheduler
class WarmRestart(lr_scheduler.CosineAnnealingLR):
"""This class implements Stochastic Gradient Descent with Warm Restarts(SGDR): https://arxiv.org/abs/1608.03983.
Set the learning rate of each parameter group using a cosine annealing schedule, When last_epoch=-1, sets initial lr as lr.
This can't support scheduler.step(epoch). please keep epoch=None.
"""
def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1):
"""implements SGDR
Parameters:
----------
T_max : int
Maximum number of epochs.
T_mult : int
Multiplicative factor of T_max.
eta_min : int
Minimum learning rate. Default: 0.
last_epoch : int
The index of last epoch. Default: -1.
"""
self.T_mult = T_mult
super().__init__(optimizer, T_max, eta_min, last_epoch)
def get_lr(self):
if self.last_epoch == self.T_max:
self.last_epoch = 0
self.T_max *= self.T_mult
return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for
base_lr in self.base_lrs]
class LinearDecay(lr_scheduler._LRScheduler):
"""This class implements LinearDecay
"""
def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1):
"""implements LinearDecay
Parameters:
----------
"""
self.num_epochs = num_epochs
self.start_epoch = start_epoch
self.min_lr = min_lr
super().__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch < self.start_epoch:
return self.base_lrs
return [base_lr - ((base_lr - self.min_lr) / self.num_epochs) * (self.last_epoch - self.start_epoch) for
base_lr in self.base_lrs]
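# Usage sketch (hypothetical model/optimizer; values mirror the defaults used
# elsewhere in this repo):
#   opt = torch.optim.Adam(model.parameters(), lr=1e-4)
#   sched = LinearDecay(opt, num_epochs=200, start_epoch=50, min_lr=1e-7)
#   for epoch in range(200):
#       train_one_epoch(...)
#       sched.step()
# WarmRestart instead repeats a cosine cycle, multiplying T_max by T_mult at
# each restart.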

@ -0,0 +1,3 @@
#!/usr/bin/env bash
python3 -m unittest discover $(pwd)

@ -0,0 +1,20 @@
import unittest
import numpy as np
from aug import get_transforms
class AugTest(unittest.TestCase):
@staticmethod
def make_images():
img = (np.random.rand(100, 100, 3) * 255).astype('uint8')
return img.copy(), img.copy()
def test_aug(self):
for scope in ('strong', 'weak'):
for crop in ('random', 'center'):
aug_pipeline = get_transforms(80, scope=scope, crop=crop)
a, b = self.make_images()
a, b = aug_pipeline(a, b)
np.testing.assert_allclose(a, b)

@ -0,0 +1,76 @@
import os
import unittest
from shutil import rmtree
from tempfile import mkdtemp
import cv2
import numpy as np
from torch.utils.data import DataLoader
from dataset import PairedDataset
def make_img():
return (np.random.rand(100, 100, 3) * 255).astype('uint8')
class AugTest(unittest.TestCase):
tmp_dir = mkdtemp()
raw = os.path.join(tmp_dir, 'raw')
gt = os.path.join(tmp_dir, 'gt')
def setUp(self):
for d in (self.raw, self.gt):
os.makedirs(d)
for i in range(5):
for d in (self.raw, self.gt):
img = make_img()
cv2.imwrite(os.path.join(d, f'{i}.png'), img)
def tearDown(self):
rmtree(self.tmp_dir)
def dataset_gen(self, equal=True):
base_config = {'files_a': os.path.join(self.raw, '*.png'),
'files_b': os.path.join(self.raw if equal else self.gt, '*.png'),
'size': 32,
}
for b in ([0, 1], [0, 0.9]):
for scope in ('strong', 'weak'):
for crop in ('random', 'center'):
for preload in (0, 1):
for preload_size in (0, 64):
config = base_config.copy()
config['bounds'] = b
config['scope'] = scope
config['crop'] = crop
config['preload'] = preload
config['preload_size'] = preload_size
config['verbose'] = False
dataset = PairedDataset.from_config(config)
yield dataset
def test_equal_datasets(self):
for dataset in self.dataset_gen(equal=True):
dataloader = DataLoader(dataset=dataset,
batch_size=2,
shuffle=True,
drop_last=True)
dataloader = iter(dataloader)
batch = next(dataloader)
a, b = map(lambda x: x.numpy(), map(batch.get, ('a', 'b')))
np.testing.assert_allclose(a, b)
def test_datasets(self):
for dataset in self.dataset_gen(equal=False):
dataloader = DataLoader(dataset=dataset,
batch_size=2,
shuffle=True,
drop_last=True)
dataloader = iter(dataloader)
batch = next(dataloader)
a, b = map(lambda x: x.numpy(), map(batch.get, ('a', 'b')))
assert not np.all(a == b), 'images should not be the same'

@ -0,0 +1,90 @@
from __future__ import print_function
import argparse
import numpy as np
import torch
import cv2
import yaml
import os
from torchvision import models, transforms
from torch.autograd import Variable
import shutil
import glob
import tqdm
from util.metrics import PSNR
from albumentations import Compose, CenterCrop, PadIfNeeded
from PIL import Image
from ssim.ssimlib import SSIM
from models.networks import get_generator
def get_args():
parser = argparse.ArgumentParser('Test an image')
parser.add_argument('--img_folder', required=True, help='GoPro folder')
parser.add_argument('--weights_path', required=True, help='Weights path')
return parser.parse_args()
def prepare_dirs(path):
if os.path.exists(path):
shutil.rmtree(path)
os.makedirs(path)
def get_gt_image(path):
dir, filename = os.path.split(path)
base, seq = os.path.split(dir)
base, _ = os.path.split(base)
img = cv2.cvtColor(cv2.imread(os.path.join(base, 'sharp', seq, filename)), cv2.COLOR_BGR2RGB)
return img
def test_image(model, image_path):
img_transforms = transforms.Compose([
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
size_transform = Compose([
PadIfNeeded(736, 1280)
])
crop = CenterCrop(720, 1280)
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_s = size_transform(image=img)['image']
img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
img_tensor = img_transforms(img_tensor)
with torch.no_grad():
img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
result_image = model(img_tensor)
result_image = result_image[0].cpu().float().numpy()
result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
result_image = crop(image=result_image)['image']
result_image = result_image.astype('uint8')
gt_image = get_gt_image(image_path)
_, filename = os.path.split(image_path)
psnr = PSNR(result_image, gt_image)
pilFake = Image.fromarray(result_image)
pilReal = Image.fromarray(gt_image)
ssim = SSIM(pilFake).cw_ssim_value(pilReal)
return psnr, ssim
def test(model, files):
psnr = 0
ssim = 0
for file in tqdm.tqdm(files):
cur_psnr, cur_ssim = test_image(model, file)
psnr += cur_psnr
ssim += cur_ssim
print("PSNR = {}".format(psnr / len(files)))
print("SSIM = {}".format(ssim / len(files)))
if __name__ == '__main__':
args = get_args()
with open('config/config.yaml') as cfg:
config = yaml.safe_load(cfg)
model = get_generator(config['model'])
model.load_state_dict(torch.load(args.weights_path)['model'])
model = model.cuda()
filenames = sorted(glob.glob(args.img_folder + '/test' + '/blur/**/*.png', recursive=True))
test(model, filenames)
