mirror of https://github.com/kritiksoman/GIMP-ML
First upload
parent
ba244f12d8
commit
ccd3c980ca
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,100 @@
|
||||
## Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
import os.path
|
||||
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
|
||||
from data.image_folder import make_dataset, make_dataset_test
|
||||
from PIL import Image
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
class AlignedDataset(BaseDataset):
    """Paired dataset of label maps (A) and real images (B).

    Directory layout is <dataroot>/<phase><suffix> (e.g. train_label,
    train_img); interpolated label folders *_label_inter_1/2 are read only
    in training mode.  Files are index-named (0.png, 1.jpg, ...), see
    make_dataset / make_dataset_test.
    """

    def initialize(self, opt):
        """Collect sorted file lists for labels, interpolations and images."""
        self.opt = opt
        self.root = opt.dataroot

        ### input A (label maps)
        if opt.isTrain or opt.use_encoded_image:
            dir_A = '_A' if self.opt.label_nc == 0 else '_label'
            self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
            self.A_paths = sorted(make_dataset(self.dir_A))
            # In training mode the reference labels come from the same folder.
            self.AR_paths = make_dataset(self.dir_A)

        ### input A inter 1 (label maps)
        if opt.isTrain or opt.use_encoded_image:
            dir_A_inter_1 = '_label_inter_1'
            self.dir_A_inter_1 = os.path.join(opt.dataroot, opt.phase + dir_A_inter_1)
            self.A_paths_inter_1 = sorted(make_dataset(self.dir_A_inter_1))

        ### input A inter 2 (label maps)
        if opt.isTrain or opt.use_encoded_image:
            dir_A_inter_2 = '_label_inter_2'
            self.dir_A_inter_2 = os.path.join(opt.dataroot, opt.phase + dir_A_inter_2)
            self.A_paths_inter_2 = sorted(make_dataset(self.dir_A_inter_2))

        ### input A test (label maps)
        if not (opt.isTrain or opt.use_encoded_image):
            dir_A = '_A' if self.opt.label_nc == 0 else '_label'
            self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
            self.A_paths = sorted(make_dataset_test(self.dir_A))
            # At test time the reference labels live in a separate folder.
            dir_AR = '_AR' if self.opt.label_nc == 0 else '_labelref'
            self.dir_AR = os.path.join(opt.dataroot, opt.phase + dir_AR)
            self.AR_paths = sorted(make_dataset_test(self.dir_AR))

        ### input B (real images)
        dir_B = '_B' if self.opt.label_nc == 0 else '_img'
        self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
        self.B_paths = sorted(make_dataset(self.dir_B))
        self.BR_paths = sorted(make_dataset(self.dir_B))

        self.dataset_size = len(self.A_paths)

    def __getitem__(self, index):
        """Return a dict with label / reference-label / image tensors and paths."""
        ### input A (label maps)
        A_path = self.A_paths[index]
        AR_path = self.AR_paths[index]
        A = Image.open(A_path)
        AR = Image.open(AR_path)

        if self.opt.isTrain:
            A_path_inter_1 = self.A_paths_inter_1[index]
            A_path_inter_2 = self.A_paths_inter_2[index]
            A_inter_1 = Image.open(A_path_inter_1)
            A_inter_2 = Image.open(A_path_inter_2)

        # One parameter draw so every image of the pair gets the same crop/flip.
        params = get_params(self.opt, A.size)
        if self.opt.label_nc == 0:
            # Labels treated as RGB images: bicubic resize + normalization.
            transform_A = get_transform(self.opt, params)
            A_tensor = transform_A(A.convert('RGB'))
            if self.opt.isTrain:
                A_inter_1_tensor = transform_A(A_inter_1.convert('RGB'))
                A_inter_2_tensor = transform_A(A_inter_2.convert('RGB'))
            AR_tensor = transform_A(AR.convert('RGB'))
        else:
            # Labels are class-index maps: nearest-neighbour resize, no
            # normalization; * 255.0 restores integer class values after
            # ToTensor's 1/255 scaling.
            transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
            A_tensor = transform_A(A) * 255.0
            if self.opt.isTrain:
                A_inter_1_tensor = transform_A(A_inter_1) * 255.0
                A_inter_2_tensor = transform_A(A_inter_2) * 255.0
            AR_tensor = transform_A(AR) * 255.0
        B_tensor = inst_tensor = feat_tensor = 0
        ### input B (real images)
        B_path = self.B_paths[index]
        BR_path = self.BR_paths[index]
        B = Image.open(B_path).convert('RGB')
        BR = Image.open(BR_path).convert('RGB')
        transform_B = get_transform(self.opt, params)
        B_tensor = transform_B(B)
        BR_tensor = transform_B(BR)

        if self.opt.isTrain:
            input_dict = {'inter_label_1': A_inter_1_tensor, 'label': A_tensor, 'inter_label_2': A_inter_2_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor, 'path': A_path, 'path_ref': AR_path}
        else:
            input_dict = {'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor, 'path': A_path, 'path_ref': AR_path}

        return input_dict

    def __len__(self):
        # Truncate to a whole number of batches.
        return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize

    def name(self):
        return 'AlignedDataset'
|
@ -0,0 +1,14 @@
|
||||
|
||||
class BaseDataLoader():
    """Abstract base class for project data loaders.

    Subclasses override initialize() and load_data().
    """

    def __init__(self):
        pass

    def initialize(self, opt):
        # Store the options object; subclasses build their loaders from it.
        self.opt = opt

    def load_data(self):
        # BUG FIX: the original signature was `def load_data():` (missing
        # `self`), so calling loader.load_data() raised a TypeError.
        return None
|
||||
|
||||
|
||||
|
@ -0,0 +1,97 @@
|
||||
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
class BaseDataset(data.Dataset):
    """Minimal torch Dataset shell; concrete datasets implement initialize()."""

    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        """Human-readable dataset identifier."""
        return 'BaseDataset'

    def initialize(self, opt):
        """Hook for subclasses; the base class needs no setup."""
        pass
|
||||
|
||||
def get_params(opt, size):
    """Sample per-image augmentation parameters for an image of (w, h) *size*.

    Returns {'crop_pos': (x, y), 'flip': 0}; the crop position is uniform
    over the valid range after the configured resize mode is applied.
    """
    width, height = size
    new_w, new_h = width, height
    mode = opt.resize_or_crop
    if mode == 'resize_and_crop':
        new_w = new_h = opt.loadSize
    elif mode == 'scale_width_and_crop':
        new_w = opt.loadSize
        new_h = opt.loadSize * height // width

    crop_x = random.randint(0, max(0, new_w - opt.fineSize))
    crop_y = random.randint(0, max(0, new_h - opt.fineSize))

    # Horizontal flipping is disabled (was: random.random() > 0.5).
    return {'crop_pos': (crop_x, crop_y), 'flip': 0}
|
||||
|
||||
def get_transform(opt, params, method=Image.BICUBIC, normalize=True, normalize_mask=False):
    """Compose the preprocessing pipeline implied by *opt* and sampled *params*.

    Args:
        opt: options object (resize_or_crop, loadSize, fineSize, isTrain, ...).
        params: dict from get_params with 'crop_pos' and 'flip'.
        method: PIL resampling filter (NEAREST for label maps).
        normalize: map [0, 1] tensors to [-1, 1].
        normalize_mask: undo ToTensor's 1/255 scaling so masks keep raw values.
    """
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        # BUG FIX: transforms.Scale was deprecated and later removed from
        # torchvision; Resize is the documented drop-in replacement.
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))

    if 'crop' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        # Round dimensions to a multiple of the generator's total
        # downsampling factor so strided convs line up.
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        # (x - 0.5) / 0.5: map [0, 1] -> [-1, 1].
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    if normalize_mask:
        # (x - 0) / (1/255): multiply by 255, restoring raw label values.
        transform_list += [transforms.Normalize((0, 0, 0),
                                                (1 / 255., 1 / 255., 1 / 255.))]

    return transforms.Compose(transform_list)
|
||||
|
||||
def normalize():
    """Return the standard transform mapping [0, 1] images to [-1, 1]."""
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    return transforms.Normalize(mean, std)
|
||||
|
||||
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize *img* so both sides are multiples of *base* (no-op if they are)."""
    orig_w, orig_h = img.size
    new_w = int(round(orig_w / base) * base)
    new_h = int(round(orig_h / base) * base)
    if (new_w, new_h) == (orig_w, orig_h):
        return img
    return img.resize((new_w, new_h), method)
|
||||
|
||||
def __scale_width(img, target_width, method=Image.BICUBIC):
    """Scale *img* to *target_width*, keeping the aspect ratio."""
    orig_w, orig_h = img.size
    if orig_w == target_width:
        return img
    scaled_h = int(target_width * orig_h / orig_w)
    return img.resize((target_width, scaled_h), method)
|
||||
|
||||
def __crop(img, pos, size):
    """Crop a size x size window at *pos*; returns *img* unchanged when it
    already fits inside the window."""
    img_w, img_h = img.size
    left, top = pos
    if img_w > size or img_h > size:
        return img.crop((left, top, left + size, top + size))
    return img
|
||||
|
||||
def __flip(img, flip):
    """Mirror *img* horizontally when *flip* is truthy."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
|
Binary file not shown.
@ -0,0 +1,31 @@
|
||||
import torch.utils.data
|
||||
from data.base_data_loader import BaseDataLoader
|
||||
|
||||
|
||||
def CreateDataset(opt):
    """Instantiate and initialize the only supported dataset: AlignedDataset."""
    from data.aligned_dataset import AlignedDataset
    dataset = AlignedDataset()

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset
|
||||
|
||||
class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps the project dataset in a torch DataLoader configured from opt."""

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Build the dataset and a DataLoader over it (batch size, shuffling
        and worker count all taken from opt)."""
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            # serial_batches disables shuffling (deterministic order).
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        # The DataLoader itself is the iterable handed to training code.
        return self.dataloader

    def __len__(self):
        # Cap the reported size at opt.max_dataset_size.
        return min(len(self.dataset), self.opt.max_dataset_size)
|
@ -0,0 +1,7 @@
|
||||
|
||||
def CreateDataLoader(opt):
    """Build, announce and initialize the project data loader."""
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    loader = CustomDatasetDataLoader()
    print(loader.name())
    loader.initialize(opt)
    return loader
|
Binary file not shown.
@ -0,0 +1,82 @@
|
||||
###############################################################################
|
||||
# Code from
|
||||
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
|
||||
# Modified the original code so that it also loads images from the current
|
||||
# directory as well as the subdirectories
|
||||
###############################################################################
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
# Recognised image file extensions (case variants listed explicitly).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    """True when *filename* ends in one of the recognised image extensions."""
    # str.endswith accepts a tuple of suffixes — one call instead of any().
    return filename.endswith(tuple(IMG_EXTENSIONS))
|
||||
|
||||
def make_dataset(dir):
    """Return index-named image paths 0..N-1 inside *dir*.

    N is the number of regular files in the directory.  Folders whose
    name suffix (after the last '_') is 'label', '1' or '2' hold .png
    files; all others hold .jpg files.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    suffix = dir.split('/')[-1].split('_')[-1]
    print (dir, suffix)
    file_count = len([name for name in os.listdir(dir)
                      if os.path.isfile(os.path.join(dir, name))])
    ext = '.png' if suffix in ('label', '1', '2') else '.jpg'
    return [os.path.join(dir, str(i) + ext) for i in range(file_count)]
|
||||
|
||||
def make_dataset_test(dir):
    """Like make_dataset, but for test-time folders: a 'label' or 'labelref'
    suffix selects .png names, anything else .jpg."""
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    suffix = dir.split('/')[-1].split('_')[-1]
    file_count = len([name for name in os.listdir(dir)
                      if os.path.isfile(os.path.join(dir, name))])
    ext = '.png' if suffix in ('label', 'labelref') else '.jpg'
    return [os.path.join(dir, str(i) + ext) for i in range(file_count)]
|
||||
|
||||
def default_loader(path):
    """Open *path* with PIL and force a 3-channel RGB image."""
    img = Image.open(path)
    return img.convert('RGB')
|
||||
|
||||
|
||||
class ImageFolder(data.Dataset):
    """Dataset over all index-named images under *root* (see make_dataset)."""

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        # Fail fast when the folder yields no image paths at all.
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        # When True, __getitem__ returns (img, path) instead of just img.
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load (and optionally transform) the image at *index*."""
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,94 @@
|
||||
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
import os
|
||||
import torch
|
||||
import sys
|
||||
|
||||
class BaseModel(torch.nn.Module):
    """Common scaffolding for models: option bookkeeping plus checkpoint
    save/load helpers shared by all concrete models."""

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        """Store options and derive tensor type and checkpoint directory."""
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # CUDA tensors only when GPU ids were requested.
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Save weights as <epoch>_net_<label>.pth (CPU state dict)."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        # Move the network back onto the GPU after the CPU-side save.
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, save_dir=''):
        """Load <epoch>_net_<label>.pth into *network*, tolerating extra or
        missing layers in the checkpoint.

        Raises RuntimeError when the generator ('G') checkpoint is absent.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        print (save_filename)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print('%s not exists yet!' % save_path)
            if network_label == 'G':
                # BUG FIX: the original `raise('Generator must exist!')`
                # raised a TypeError (strings are not exceptions).
                raise RuntimeError('Generator must exist!')
        else:
            try:
                network.load_state_dict(torch.load(save_path))
            # BUG FIX: bare `except:` also swallowed KeyboardInterrupt/SystemExit.
            except Exception:
                pretrained_dict = torch.load(save_path)
                model_dict = network.state_dict()
                try:
                    # Checkpoint has extra layers: keep only those the
                    # current model declares, then load strictly.
                    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                    network.load_state_dict(pretrained_dict)
                    if self.opt.verbose:
                        print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
                except Exception:
                    # Checkpoint has fewer/mismatched layers: copy what fits
                    # and report the top-level modules left uninitialized.
                    print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
                    for k, v in pretrained_dict.items():
                        if v.size() == model_dict[k].size():
                            model_dict[k] = v

                    if sys.version_info >= (3,0):
                        not_initialized = set()
                    else:
                        from sets import Set
                        not_initialized = Set()

                    for k, v in model_dict.items():
                        if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                            not_initialized.add(k.split('.')[0])

                    print(sorted(not_initialized))
                    network.load_state_dict(model_dict)

    def update_learning_rate(self):
        # BUG FIX: the original signature was `def update_learning_rate():`
        # (missing `self`), so model.update_learning_rate() raised TypeError.
        pass
|
Binary file not shown.
@ -0,0 +1,20 @@
|
||||
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
import torch
|
||||
|
||||
def create_model(opt):
    """Create, initialize and (for multi-GPU training) parallelize the model."""
    if opt.model == 'pix2pixHD':
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        model = Pix2PixHDModel() if opt.isTrain else InferenceModel()

    model.initialize(opt)
    if opt.verbose:
        print("model [%s] was created" % (model.name()))

    # Wrap in DataParallel only when training across one or more GPUs.
    if opt.isTrain and len(opt.gpu_ids):
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

    return model
|
Binary file not shown.
@ -0,0 +1,818 @@
|
||||
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import functools
|
||||
from torch.autograd import Variable
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
###############################################################################
|
||||
# Functions
|
||||
###############################################################################
|
||||
def weights_init(m):
    """Init hook for nn.Module.apply: conv weights ~ N(0, 0.02);
    batch-norm weights ~ N(1, 0.02) with zero bias."""
    layer_name = m.__class__.__name__
    if 'Conv2d' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm2d' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
|
||||
|
||||
def get_norm_layer(norm_type='instance'):
    """Map a norm name ('batch' | 'instance') to a partially-applied layer
    constructor; raises NotImplementedError for anything else."""
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
||||
|
||||
def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=[]):
    """Build the generator ('global' or 'local') and apply weight init.

    Args:
        input_nc / output_nc: input and output channel counts.
        ngf: base filter count.
        netG: architecture name, 'global' or 'local'.
        norm: normalization type passed to get_norm_layer.
        gpu_ids: unused here (GPU transfer is commented out below).
    Raises:
        NotImplementedError: for an unknown netG name.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
    elif netG == 'local':
        netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
                             n_local_enhancers, n_blocks_local, norm_layer)
    else:
        # BUG FIX: `raise('generator not implemented!')` raised a TypeError
        # (strings are not exceptions); raise a proper exception instead.
        raise NotImplementedError('generator not implemented!')
    print(netG)
    # if len(gpu_ids) > 0:
    #     assert(torch.cuda.is_available())
    #     netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
|
||||
|
||||
def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
    """Build a multiscale PatchGAN discriminator, move it to the first GPU
    when gpu_ids is non-empty, and apply the standard weight init."""
    norm_layer = get_norm_layer(norm_type=norm)
    netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat)
    print(netD)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        netD.cuda(gpu_ids[0])
    netD.apply(weights_init)
    return netD
|
||||
|
||||
def define_VAE(input_nc, gpu_ids=[]):
    """Build the mask VAE and optionally move it to the first GPU.

    NOTE(review): input_nc is ignored — the VAE is hard-wired to
    VAE(19, 32, 32, 1024); presumably 19 is the label-class count.
    Confirm against callers before relying on input_nc.
    """
    netVAE = VAE(19, 32, 32, 1024)
    print(netVAE)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        netVAE.cuda(gpu_ids[0])
    return netVAE
|
||||
|
||||
def define_B(input_nc, output_nc, ngf, n_downsample_global=3, n_blocks_global=3, norm='instance', gpu_ids=[]):
    """Build the blend generator, optionally move it to the first GPU,
    and apply the standard weight init."""
    norm_layer = get_norm_layer(norm_type=norm)
    netB = BlendGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
    print(netB)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        netB.cuda(gpu_ids[0])
    netB.apply(weights_init)
    return netB
|
||||
|
||||
def print_network(net):
    """Print a network's structure and total parameter count.

    Accepts either a module or a list of modules (only the first is used).
    """
    if isinstance(net, list):
        net = net[0]
    num_params = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % num_params)
|
||||
|
||||
##############################################################################
|
||||
# Losses
|
||||
##############################################################################
|
||||
class GANLoss(nn.Module):
    """GAN objective with cached real/fake target tensors.

    Uses MSE (LSGAN) or BCE against constant target labels; target tensors
    are rebuilt only when the prediction size changes.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Cached label tensors, lazily (re)created in get_target_tensor.
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a constant label tensor shaped like *input*, reusing the
        cached one when its element count still matches."""
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        """Compute the loss; *input* is either a multiscale discriminator
        output (list of lists of feature maps) or a single list, in which
        case only the last element (the prediction) is scored."""
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss
        else:
            target_tensor = self.get_target_tensor(input[-1], target_is_real)
            return self.loss(input[-1], target_tensor)
|
||||
|
||||
class VGGLoss(nn.Module):
    """Perceptual loss: weighted L1 distance between VGG19 feature maps."""

    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        # NOTE(review): gpu_ids is unused — the VGG is unconditionally
        # placed on CUDA, so this loss requires a GPU.
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        # Deeper feature maps get larger weights.
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            # detach(): no gradients flow through the target branch.
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss
|
||||
|
||||
##############################################################################
|
||||
# Generator
|
||||
##############################################################################
|
||||
class GlobalGenerator(nn.Module):
    """Encoder-resnet-decoder generator whose resnet blocks use AdaIN,
    with style parameters produced from a reference image/label pair."""

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert(n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        activation = nn.ReLU(True)

        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), activation]

        ### resnet blocks (AdaIN-normalized; parameters assigned in forward)
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_type='adain', padding_type=padding_type)]
        ### upsample
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                      norm_layer(int(ngf * mult / 2)), activation]
        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)

        # style encoder: emits exactly as many values as the AdaIN layers need
        self.enc_style = StyleEncoder(5, 3, 16, self.get_num_adain_params(self.model), norm='none', activ='relu', pad_type='reflect')
        # label encoder
        self.enc_label = LabelEncoder(5, 19, 16, 64, norm='none', activ='relu', pad_type='reflect')

    def assign_adain_params(self, adain_params, model):
        # assign the adain_params to the AdaIN layers in model
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                # First num_features values are the bias (mean), the next
                # num_features are the weight (std).
                mean = adain_params[:, :m.num_features]
                std = adain_params[:, m.num_features:2*m.num_features]
                m.bias = mean.contiguous().view(-1)
                m.weight = std.contiguous().view(-1)
                # Advance to the parameters of the next AdaIN layer.
                if adain_params.size(1) > 2*m.num_features:
                    adain_params = adain_params[:, 2*m.num_features:]

    def get_num_adain_params(self, model):
        # return the number of AdaIN parameters needed by the model
        num_adain_params = 0
        for m in model.modules():
            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
                num_adain_params += 2*m.num_features
        return num_adain_params

    def forward(self, input, input_ref, image_ref):
        """Generate from *input* labels, styled by (input_ref, image_ref)."""
        fea1, fea2 = self.enc_label(input_ref)
        adain_params = self.enc_style((image_ref, fea1, fea2))
        self.assign_adain_params(adain_params, self.model)
        return self.model(input)
|
||||
|
||||
class BlendGenerator(nn.Module):
    """Predicts a soft blending mask from two stacked inputs and returns
    their per-pixel convex combination along with the mask."""

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=3, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert(n_blocks >= 0)
        super(BlendGenerator, self).__init__()
        activation = nn.ReLU(True)

        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), activation]

        ### resnet blocks
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_type='in', padding_type=padding_type)]

        ### upsample
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                      norm_layer(int(ngf * mult / 2)), activation]
        # Sigmoid head: the output is a mask in [0, 1].
        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Sigmoid()]
        self.model = nn.Sequential(*model)

    def forward(self, input1, input2):
        """Return (blended image, mask) where mask weights input1."""
        m = self.model(torch.cat([input1, input2], 1))
        return input1 * m + input2 * (1-m), m
|
||||
|
||||
# Define the Multiscale Discriminator.
|
||||
class MultiscaleDiscriminator(nn.Module):
|
||||
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
|
||||
use_sigmoid=False, num_D=3, getIntermFeat=False):
|
||||
super(MultiscaleDiscriminator, self).__init__()
|
||||
self.num_D = num_D
|
||||
self.n_layers = n_layers
|
||||
self.getIntermFeat = getIntermFeat
|
||||
|
||||
for i in range(num_D):
|
||||
netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
|
||||
if getIntermFeat:
|
||||
for j in range(n_layers+2):
|
||||
setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
|
||||
else:
|
||||
setattr(self, 'layer'+str(i), netD.model)
|
||||
|
||||
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
|
||||
|
||||
def singleD_forward(self, model, input):
|
||||
if self.getIntermFeat:
|
||||
result = [input]
|
||||
for i in range(len(model)):
|
||||
result.append(model[i](result[-1]))
|
||||
return result[1:]
|
||||
else:
|
||||
return [model(input)]
|
||||
|
||||
def forward(self, input):
|
||||
num_D = self.num_D
|
||||
result = []
|
||||
input_downsampled = input
|
||||
for i in range(num_D):
|
||||
if self.getIntermFeat:
|
||||
model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
|
||||
else:
|
||||
model = getattr(self, 'layer'+str(num_D-1-i))
|
||||
result.append(self.singleD_forward(model, input_downsampled))
|
||||
if i != (num_D-1):
|
||||
input_downsampled = self.downsample(input_downsampled)
|
||||
return result
|
||||
|
||||
# Define the PatchGAN discriminator with the specified arguments.
|
||||
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: n_layers strided conv blocks followed by a
    1-channel prediction map (optionally through a sigmoid)."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw-1.0)/2))
        # sequence is a list of layer groups; each group becomes one stage.
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            # Double the channel count per stage, capped at 512.
            nf = min(nf * 2, 512)
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        # Penultimate stage uses stride 1.
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        # Final 1-channel patch prediction map.
        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        if use_sigmoid:
            sequence += [[nn.Sigmoid()]]

        if getIntermFeat:
            # Keep stages separate (model0, model1, ...) so intermediate
            # activations can be returned for feature matching.
            for n in range(len(sequence)):
                setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
        else:
            # Flatten all stages into a single Sequential.
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers+2):
                model = getattr(self, 'model'+str(n))
                res.append(model(res[-1]))
            # Per-stage activations, excluding the raw input.
            return res[1:]
        else:
            return self.model(input)
|
||||
|
||||
from torchvision import models
|
||||
class Vgg19(torch.nn.Module):
    """Pretrained VGG19 feature extractor split into five slices, returning
    the activations after each slice (used by VGGLoss)."""

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        # Slice boundaries partition vgg19.features layers [0, 30).
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        # Freeze all weights unless gradients were explicitly requested.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the activations after each of the five slices."""
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
|
||||
|
||||
# Define the MaskVAE
class VAE(nn.Module):
    """Convolutional variational autoencoder over mask images.

    Seven stride-2 encoder convolutions reduce an ``nc``-channel input to a
    4x4 bottleneck (so the input is expected to be 512x512 — TODO confirm
    against callers), from which two linear heads produce the mean and
    log-variance of the latent Gaussian. The decoder mirrors this with seven
    nearest-neighbour-upsample + replication-pad + conv stages.

    Args:
        nc: number of image channels.
        ngf: decoder channel-width multiplier.
        ndf: encoder channel-width multiplier.
        latent_variable_size: dimensionality of the latent code.
    """

    def __init__(self, nc, ngf, ndf, latent_variable_size):
        super(VAE, self).__init__()
        #self.cuda = True
        self.nc = nc
        self.ngf = ngf
        self.ndf = ndf
        self.latent_variable_size = latent_variable_size

        # encoder: each stage halves the spatial size and doubles the channels
        self.e1 = nn.Conv2d(nc, ndf, 4, 2, 1)
        self.bn1 = nn.BatchNorm2d(ndf)

        self.e2 = nn.Conv2d(ndf, ndf*2, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(ndf*2)

        self.e3 = nn.Conv2d(ndf*2, ndf*4, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(ndf*4)

        self.e4 = nn.Conv2d(ndf*4, ndf*8, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(ndf*8)

        self.e5 = nn.Conv2d(ndf*8, ndf*16, 4, 2, 1)
        self.bn5 = nn.BatchNorm2d(ndf*16)

        self.e6 = nn.Conv2d(ndf*16, ndf*32, 4, 2, 1)
        self.bn6 = nn.BatchNorm2d(ndf*32)

        self.e7 = nn.Conv2d(ndf*32, ndf*64, 4, 2, 1)
        self.bn7 = nn.BatchNorm2d(ndf*64)

        # two heads on the flattened 4x4 bottleneck: mean and log-variance
        self.fc1 = nn.Linear(ndf*64*4*4, latent_variable_size)
        self.fc2 = nn.Linear(ndf*64*4*4, latent_variable_size)

        # decoder: linear -> 4x4 feature map, then 7 upsample+pad+conv stages
        self.d1 = nn.Linear(latent_variable_size, ngf*64*4*4)

        self.up1 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd1 = nn.ReplicationPad2d(1)
        self.d2 = nn.Conv2d(ngf*64, ngf*32, 3, 1)
        self.bn8 = nn.BatchNorm2d(ngf*32, 1.e-3)

        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd2 = nn.ReplicationPad2d(1)
        self.d3 = nn.Conv2d(ngf*32, ngf*16, 3, 1)
        self.bn9 = nn.BatchNorm2d(ngf*16, 1.e-3)

        self.up3 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd3 = nn.ReplicationPad2d(1)
        self.d4 = nn.Conv2d(ngf*16, ngf*8, 3, 1)
        self.bn10 = nn.BatchNorm2d(ngf*8, 1.e-3)

        self.up4 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd4 = nn.ReplicationPad2d(1)
        self.d5 = nn.Conv2d(ngf*8, ngf*4, 3, 1)
        self.bn11 = nn.BatchNorm2d(ngf*4, 1.e-3)

        self.up5 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd5 = nn.ReplicationPad2d(1)
        self.d6 = nn.Conv2d(ngf*4, ngf*2, 3, 1)
        self.bn12 = nn.BatchNorm2d(ngf*2, 1.e-3)

        self.up6 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd6 = nn.ReplicationPad2d(1)
        self.d7 = nn.Conv2d(ngf*2, ngf, 3, 1)
        self.bn13 = nn.BatchNorm2d(ngf, 1.e-3)

        self.up7 = nn.UpsamplingNearest2d(scale_factor=2)
        self.pd7 = nn.ReplicationPad2d(1)
        self.d8 = nn.Conv2d(ngf, nc, 3, 1)

        self.leakyrelu = nn.LeakyReLU(0.2)
        self.relu = nn.ReLU()
        #self.sigmoid = nn.Sigmoid()
        self.maxpool = nn.MaxPool2d((2, 2), (2, 2))

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior q(z|x)."""
        h1 = self.leakyrelu(self.bn1(self.e1(x)))
        h2 = self.leakyrelu(self.bn2(self.e2(h1)))
        h3 = self.leakyrelu(self.bn3(self.e3(h2)))
        h4 = self.leakyrelu(self.bn4(self.e4(h3)))
        h5 = self.leakyrelu(self.bn5(self.e5(h4)))
        h6 = self.leakyrelu(self.bn6(self.e6(h5)))
        h7 = self.leakyrelu(self.bn7(self.e7(h6)))
        h7 = h7.view(-1, self.ndf*64*4*4)
        return self.fc1(h7), self.fc2(h7)

    def reparametrize(self, mu, logvar):
        """Sample z ~ N(mu, exp(logvar)) via the reparameterization trick."""
        std = logvar.mul(0.5).exp_()
        # Bug fix: the previous code called torch.cuda.FloatTensor
        # unconditionally (the guarded version was left commented out),
        # which crashed on CPU-only machines. Draw the noise on the device
        # family that is actually available, mirroring the availability
        # checks used elsewhere in this project.
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Map a latent code z back to image space (no output nonlinearity)."""
        h1 = self.relu(self.d1(z))
        h1 = h1.view(-1, self.ngf*64, 4, 4)
        h2 = self.leakyrelu(self.bn8(self.d2(self.pd1(self.up1(h1)))))
        h3 = self.leakyrelu(self.bn9(self.d3(self.pd2(self.up2(h2)))))
        h4 = self.leakyrelu(self.bn10(self.d4(self.pd3(self.up3(h3)))))
        h5 = self.leakyrelu(self.bn11(self.d5(self.pd4(self.up4(h4)))))
        h6 = self.leakyrelu(self.bn12(self.d6(self.pd5(self.up5(h5)))))
        h7 = self.leakyrelu(self.bn13(self.d7(self.pd6(self.up6(h6)))))
        return self.d8(self.pd7(self.up7(h7)))

    def get_latent_var(self, x):
        """Return (z, mu, std) for input x; std = exp(0.5 * logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return z, mu, logvar.mul(0.5).exp_()

    def forward(self, x):
        """Return (reconstruction, x, mu, logvar) for the VAE loss."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        res = self.decode(z)

        return res, x, mu, logvar
|
||||
|
||||
# style encode part
class StyleEncoder(nn.Module):
    """Style encoder with SFT conditioning.

    forward() takes a 3-tuple ``(content, cond1, cond2)``: the head convs run
    on ``content``, each SFT layer modulates the features with one of the
    conditioning maps, and the tail global-average-pools down to a
    ``style_dim``-channel 1x1 style code.
    """

    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        super(StyleEncoder, self).__init__()
        # Head: 7x7 stem plus two stride-2 downsampling convs.
        head = [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        for _ in range(2):
            head.append(ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type))
            dim *= 2
        # Middle: the remaining downsampling stages at constant width.
        middle = [ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)
                  for _ in range(n_downsample - 2)]
        # Tail: global average pooling followed by a 1x1 projection.
        tail = [nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(dim, style_dim, 1, 1, 0)]

        self.model = nn.Sequential(*head)
        self.model_middle = nn.Sequential(*middle)
        self.model_last = nn.Sequential(*tail)

        self.output_dim = dim

        self.sft1 = SFTLayer()
        self.sft2 = SFTLayer()

    def forward(self, x):
        content, cond1, cond2 = x[0], x[1], x[2]
        fea = self.model(content)
        fea = self.sft1((fea, cond1))
        fea = self.model_middle(fea)
        fea = self.sft2((fea, cond2))
        return self.model_last(fea)
|
||||
|
||||
# label encode part
class LabelEncoder(nn.Module):
    """Label-map encoder.

    forward() returns a pair ``(fea, deep)``: ``fea`` is the output of the
    three-conv trunk (its last conv has no activation), and ``deep`` is the
    result of pushing ``fea`` through a ReLU plus the remaining downsampling
    stages.
    """

    def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
        super(LabelEncoder, self).__init__()
        # Trunk: 7x7 stem and two stride-2 convs; the last conv is linear so
        # the intermediate feature can be consumed un-activated.
        trunk = [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type),
                 ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
        dim *= 2
        trunk.append(ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation='none', pad_type=pad_type))
        dim *= 2

        # Tail: activate the trunk output, then finish downsampling; the
        # final conv is again linear.
        tail = [nn.ReLU()]
        for _ in range(n_downsample - 3):
            tail.append(ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type))
        tail.append(ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation='none', pad_type=pad_type))

        self.model = nn.Sequential(*trunk)
        self.model_last = nn.Sequential(*tail)
        self.output_dim = dim

    def forward(self, x):
        fea = self.model(x)
        return fea, self.model_last(fea)
|
||||
|
||||
# Define the basic block
class ConvBlock(nn.Module):
    """Configurable pad -> conv -> (norm) -> (activation) block.

    ``norm='sn'`` wraps the convolution in SpectralNorm and applies no
    separate normalization layer. Unsupported option strings fail with an
    AssertionError, matching the rest of this file.
    """

    def __init__(self, input_dim, output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(ConvBlock, self).__init__()
        self.use_bias = True

        # initialize padding via a dispatch table
        pad_layers = {
            'reflect': nn.ReflectionPad2d,
            'replicate': nn.ReplicationPad2d,
            'zero': nn.ZeroPad2d,
        }
        assert pad_type in pad_layers, "Unsupported padding type: {}".format(pad_type)
        self.pad = pad_layers[pad_type](padding)

        # initialize normalization (kept as if/elif: LayerNorm and
        # AdaptiveInstanceNorm2d are file-local names resolved lazily)
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'in':
            #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm in ('none', 'sn'):
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # initialize activation via a dispatch table of factories
        act_factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': nn.PReLU,
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': nn.Tanh,
            'none': lambda: None,
        }
        assert activation in act_factories, "Unsupported activation: {}".format(activation)
        self.activation = act_factories[activation]()

        # initialize convolution (optionally spectrally normalized)
        conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
        self.conv = SpectralNorm(conv) if norm == 'sn' else conv

    def forward(self, x):
        out = self.conv(self.pad(x))
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
|
||||
|
||||
class LinearBlock(nn.Module):
    """Fully connected layer with optional normalization and activation.

    ``norm='sn'`` wraps the linear layer in SpectralNorm and applies no
    separate normalization layer.
    """

    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
        super(LinearBlock, self).__init__()
        use_bias = True

        # fully connected layer (optionally spectrally normalized)
        linear = nn.Linear(input_dim, output_dim, bias=use_bias)
        self.fc = SpectralNorm(linear) if norm == 'sn' else linear

        # normalization (kept as if/elif: LayerNorm is a file-local name
        # resolved lazily)
        norm_dim = output_dim
        if norm == 'bn':
            self.norm = nn.BatchNorm1d(norm_dim)
        elif norm == 'in':
            self.norm = nn.InstanceNorm1d(norm_dim)
        elif norm == 'ln':
            self.norm = LayerNorm(norm_dim)
        elif norm in ('none', 'sn'):
            self.norm = None
        else:
            assert 0, "Unsupported normalization: {}".format(norm)

        # activation via a dispatch table of factories
        act_factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': nn.PReLU,
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': nn.Tanh,
            'none': lambda: None,
        }
        assert activation in act_factories, "Unsupported activation: {}".format(activation)
        self.activation = act_factories[activation]()

    def forward(self, x):
        out = self.fc(x)
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
|
||||
|
||||
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: returns ``x + branch(x)`` where the branch is two
    3x3 ConvBlocks (the second without activation)."""

    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, norm_type, padding_type, use_dropout)

    def build_conv_block(self, dim, norm_type, padding_type, use_dropout):
        """Build the two-convolution residual branch.

        Note: ``use_dropout`` is accepted for interface compatibility but is
        not used by this implementation.
        """
        layers = [
            ConvBlock(dim, dim, 3, 1, 1, norm=norm_type, activation='relu', pad_type=padding_type),
            ConvBlock(dim, dim, 3, 1, 1, norm=norm_type, activation='none', pad_type=padding_type),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)
|
||||
|
||||
class SFTLayer(nn.Module):
    """Spatial feature transform layer.

    forward() takes ``(features, condition)`` (both 64-channel maps) and
    returns ``features * scale + shift``, where scale and shift are each
    produced from the condition by two 1x1 convs with a leaky ReLU between.
    """

    def __init__(self):
        super(SFTLayer, self).__init__()
        self.SFT_scale_conv1 = nn.Conv2d(64, 64, 1)
        self.SFT_scale_conv2 = nn.Conv2d(64, 64, 1)
        self.SFT_shift_conv1 = nn.Conv2d(64, 64, 1)
        self.SFT_shift_conv2 = nn.Conv2d(64, 64, 1)

    def forward(self, x):
        features, condition = x
        hidden_scale = F.leaky_relu(self.SFT_scale_conv1(condition), 0.1, inplace=True)
        hidden_shift = F.leaky_relu(self.SFT_shift_conv1(condition), 0.1, inplace=True)
        scale = self.SFT_scale_conv2(hidden_scale)
        shift = self.SFT_shift_conv2(hidden_shift)
        return features * scale + shift
|
||||
|
||||
class ConvBlock_SFT(nn.Module):
    """SFT-modulated conv block returning ``(x[0] + fea, x[1])``.

    Input is a pair ``(features, condition)``; the condition is passed
    through unchanged so these blocks can be chained.
    """

    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        # Bug fix: the original called super(ResnetBlock_SFT, self), a name
        # that does not exist in this module, so instantiating this class
        # raised NameError. Use this class's own name.
        super(ConvBlock_SFT, self).__init__()
        self.sft1 = SFTLayer()
        self.conv1 = ConvBlock(dim, dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type)

    def forward(self, x):
        fea = self.sft1((x[0], x[1]))
        fea = F.relu(self.conv1(fea), inplace=True)
        # NOTE(review): conv1 uses stride 2, so `fea` is half the spatial size
        # of x[0]; this residual add only broadcasts for 1x1 inputs — confirm
        # intended usage with callers.
        return (x[0] + fea, x[1])
|
||||
|
||||
class ConvBlock_SFT_last(nn.Module):
    """Final SFT-modulated conv block: returns only ``x[0] + fea`` (the
    condition map is consumed, not forwarded)."""

    def __init__(self, dim, norm_type, padding_type, use_dropout=False):
        # Bug fix: the original called super(ResnetBlock_SFT_last, self), a
        # name that does not exist in this module, so instantiating this
        # class raised NameError. Use this class's own name.
        super(ConvBlock_SFT_last, self).__init__()
        self.sft1 = SFTLayer()
        self.conv1 = ConvBlock(dim, dim, 4, 2, 1, norm=norm_type, activation='none', pad_type=padding_type)

    def forward(self, x):
        fea = self.sft1((x[0], x[1]))
        fea = F.relu(self.conv1(fea), inplace=True)
        # NOTE(review): conv1 uses stride 2, so `fea` is half the spatial size
        # of x[0]; this residual add only broadcasts for 1x1 inputs — confirm
        # intended usage with callers.
        return x[0] + fea
|
||||
|
||||
# Definition of normalization layer
class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalization whose affine weight/bias are assigned
    externally by the caller before each forward (AdaIN)."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(AdaptiveInstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # weight and bias are dynamically assigned by the caller
        self.weight = None
        self.bias = None
        # just dummy buffers so F.batch_norm has running stats to write into
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        batch, channels = x.size(0), x.size(1)
        mean_buf = self.running_mean.repeat(batch)
        var_buf = self.running_var.repeat(batch)

        # Collapse (B, C) into the channel dimension so batch_norm normalizes
        # each instance independently — i.e. instance normalization.
        flat = x.contiguous().view(1, batch * channels, *x.size()[2:])

        normalized = F.batch_norm(
            flat, mean_buf, var_buf, self.weight, self.bias,
            True, self.momentum, self.eps)

        return normalized.view(batch, channels, *x.size()[2:])

    def __repr__(self):
        return self.__class__.__name__ + '(' + str(self.num_features) + ')'
|
||||
|
||||
class LayerNorm(nn.Module):
    """Layer normalization over all non-batch dimensions, with optional
    learned per-channel affine (gamma initialized uniform, beta zero)."""

    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps

        if self.affine:
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        stat_shape = [-1] + [1] * (x.dim() - 1)
        if x.size(0) == 1:
            # Single-sample fast path (noted by the original author as much
            # faster on pytorch 0.4 than the batched view below).
            mean = x.view(-1).mean().view(*stat_shape)
            std = x.view(-1).std().view(*stat_shape)
        else:
            flattened = x.view(x.size(0), -1)
            mean = flattened.mean(1).view(*stat_shape)
            std = flattened.std(1).view(*stat_shape)

        x = (x - mean) / (std + self.eps)

        if self.affine:
            affine_shape = [1, -1] + [1] * (x.dim() - 2)
            x = x * self.gamma.view(*affine_shape) + self.beta.view(*affine_shape)
        return x
|
||||
|
||||
def l2normalize(v, eps=1e-12):
    """Scale tensor `v` to unit L2 norm; `eps` guards against division by zero."""
    denominator = v.norm() + eps
    return v / denominator
|
||||
|
||||
class SpectralNorm(nn.Module):
    """
    Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
    and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan

    Wraps `module` and, on every forward pass, re-assigns
    ``module.<name> = <name>_bar / sigma`` where sigma is the largest singular
    value of the weight, estimated by power iteration on the persistent
    vectors ``<name>_u`` / ``<name>_v``.
    """
    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name  # name of the parameter to normalize (default 'weight')
        self.power_iterations = power_iterations  # iterations per forward pass
        # Replace the wrapped parameter with (_u, _v, _bar) only once;
        # _made_params() is False on the first construction.
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Run power iteration in-place on u/v and set the normalized weight."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        # Flatten the weight to 2-D (out_features x rest) for the iteration.
        height = w.data.shape[0]
        for _ in range(self.power_iterations):
            # Updates go through .data so they are not tracked by autograd.
            v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
            u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))

        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
        # sigma approximates the spectral norm u^T W v; computed on tensors
        # (not .data) so the division below stays in the autograd graph.
        sigma = u.dot(w.view(height, -1).mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if the (_u, _v, _bar) surrogate parameters already exist."""
        try:
            u = getattr(self.module, self.name + "_u")
            v = getattr(self.module, self.name + "_v")
            w = getattr(self.module, self.name + "_bar")
            return True
        except AttributeError:
            return False


    def _make_params(self):
        """Replace ``module.<name>`` with trainable ``_bar`` plus frozen u/v vectors."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        # u/v are random unit vectors; requires_grad=False because they are
        # only updated manually by the power iteration.
        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = nn.Parameter(w.data)

        # Remove the original parameter so the normalized tensor assigned in
        # _update_u_v() can live under that attribute name instead.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)


    def forward(self, *args):
        # Refresh the spectral estimate, then delegate to the wrapped module.
        self._update_u_v()
        return self.module.forward(*args)
|
Binary file not shown.
@ -0,0 +1,326 @@
|
||||
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
from torch.autograd import Variable
|
||||
from util.image_pool import ImagePool
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
|
||||
def generate_discrete_label(inputs, label_nc):
    """Convert per-pixel class scores into a one-hot label map.

    Args:
        inputs: tensor of shape (B, label_nc, H, W) with per-class scores.
        label_nc: number of label classes (size of the one-hot axis).

    Returns:
        A float tensor of shape (B, label_nc, H, W) that is 1.0 at each
        pixel's argmax class and 0.0 elsewhere (CUDA tensor when available).
    """
    pred_batch = []
    size = inputs.size()
    for input in inputs:
        # argmax over the class axis, one sample at a time
        input = input.view(1, label_nc, size[2], size[3])
        pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
        pred_batch.append(pred)

    pred_batch = np.array(pred_batch)
    pred_batch = torch.from_numpy(pred_batch)
    label_map = []
    for p in pred_batch:
        # Generalized: reshape to the input's own spatial size instead of the
        # previous hard-coded 1 x 512 x 512, which crashed on any other
        # resolution even though the true size was already available.
        p = p.view(1, size[2], size[3])
        label_map.append(p)
    label_map = torch.stack(label_map, 0)
    size = label_map.size()
    oneHot_size = (size[0], label_nc, size[2], size[3])
    # scatter_ writes 1.0 along the class axis at each pixel's predicted index
    if torch.cuda.is_available():
        input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
    else:
        input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
        input_label = input_label.scatter_(1, label_map.data.long(), 1.0)

    return input_label
|
||||
|
||||
class Pix2PixHDModel(BaseModel):
|
||||
def name(self):
|
||||
return 'Pix2PixHDModel'
|
||||
|
||||
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
|
||||
flags = (True, use_gan_feat_loss, use_vgg_loss, True, use_gan_feat_loss, use_vgg_loss, True, True, True, True)
|
||||
def loss_filter(g_gan, g_gan_feat, g_vgg, gb_gan, gb_gan_feat, gb_vgg, d_real, d_fake, d_blend):
|
||||
return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,gb_gan,gb_gan_feat,gb_vgg,d_real,d_fake,d_blend),flags) if f]
|
||||
return loss_filter
|
||||
|
||||
def initialize(self, opt):
|
||||
BaseModel.initialize(self, opt)
|
||||
if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
|
||||
torch.backends.cudnn.benchmark = True
|
||||
self.isTrain = opt.isTrain
|
||||
input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc
|
||||
|
||||
##### define networks
|
||||
# Generator network
|
||||
netG_input_nc = input_nc
|
||||
# Main Generator
|
||||
self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
|
||||
opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
|
||||
opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
|
||||
|
||||
# Discriminator network
|
||||
if self.isTrain:
|
||||
use_sigmoid = opt.no_lsgan
|
||||
netD_input_nc = input_nc + opt.output_nc
|
||||
netB_input_nc = opt.output_nc * 2
|
||||
self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
|
||||
opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)
|
||||
self.netB = networks.define_B(netB_input_nc, opt.output_nc, 32, 3, 3, opt.norm, gpu_ids=self.gpu_ids)
|
||||
|
||||
if self.opt.verbose:
|
||||
print('---------- Networks initialized -------------')
|
||||
|
||||
# load networks
|
||||
if not self.isTrain or opt.continue_train or opt.load_pretrain:
|
||||
pretrained_path = '' if not self.isTrain else opt.load_pretrain
|
||||
print (pretrained_path)
|
||||
self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
|
||||
if self.isTrain:
|
||||
self.load_network(self.netB, 'B', opt.which_epoch, pretrained_path)
|
||||
self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
|
||||
|
||||
# set loss functions and optimizers
|
||||
if self.isTrain:
|
||||
if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
|
||||
raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
|
||||
self.fake_pool = ImagePool(opt.pool_size)
|
||||
self.old_lr = opt.lr
|
||||
|
||||
# define loss functions
|
||||
self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)
|
||||
|
||||
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
|
||||
self.criterionFeat = torch.nn.L1Loss()
|
||||
if not opt.no_vgg_loss:
|
||||
self.criterionVGG = networks.VGGLoss(self.gpu_ids)
|
||||
|
||||
# Names so we can breakout loss
|
||||
self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','GB_GAN','GB_GAN_Feat','GB_VGG','D_real','D_fake','D_blend')
|
||||
# initialize optimizers
|
||||
# optimizer G
|
||||
if opt.niter_fix_global > 0:
|
||||
import sys
|
||||
if sys.version_info >= (3,0):
|
||||
finetune_list = set()
|
||||
else:
|
||||
from sets import Set
|
||||
finetune_list = Set()
|
||||
|
||||
params_dict = dict(self.netG.named_parameters())
|
||||
params = []
|
||||
for key, value in params_dict.items():
|
||||
if key.startswith('model' + str(opt.n_local_enhancers)):
|
||||
params += [value]
|
||||
finetune_list.add(key.split('.')[0])
|
||||
print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
|
||||
print('The layers that are finetuned are ', sorted(finetune_list))
|
||||
else:
|
||||
params = list(self.netG.parameters())
|
||||
|
||||
self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
|
||||
# optimizer D
|
||||
params = list(self.netD.parameters())
|
||||
self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
|
||||
# optimizer G + B
|
||||
params = list(self.netG.parameters()) + list(self.netB.parameters())
|
||||
self.optimizer_GB = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
|
||||
def encode_input(self, inter_label_map_1, label_map, inter_label_map_2, real_image, label_map_ref, real_image_ref, infer=False):
|
||||
|
||||
if self.opt.label_nc == 0:
|
||||
if torch.cuda.is_available():
|
||||
input_label = label_map.data.cuda()
|
||||
inter_label_1 = inter_label_map_1.data.cuda()
|
||||
inter_label_2 = inter_label_map_2.data.cuda()
|
||||
input_label_ref = label_map_ref.data.cuda()
|
||||
else:
|
||||
input_label = label_map.data
|
||||
inter_label_1 = inter_label_map_1.data
|
||||
inter_label_2 = inter_label_map_2.data
|
||||
input_label_ref = label_map_ref.data
|
||||
|
||||
else:
|
||||
# create one-hot vector for label map
|
||||
size = label_map.size()
|
||||
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
|
||||
if torch.cuda.is_available():
|
||||
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
|
||||
inter_label_1 = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
inter_label_1 = inter_label_1.scatter_(1, inter_label_map_1.data.long().cuda(), 1.0)
|
||||
inter_label_2 = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
inter_label_2 = inter_label_2.scatter_(1, inter_label_map_2.data.long().cuda(), 1.0)
|
||||
input_label_ref = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
input_label_ref = input_label_ref.scatter_(1, label_map_ref.data.long().cuda(), 1.0)
|
||||
else:
|
||||
input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
input_label = input_label.scatter_(1, label_map.data.long(), 1.0)
|
||||
inter_label_1 = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
inter_label_1 = inter_label_1.scatter_(1, inter_label_map_1.data.long(), 1.0)
|
||||
inter_label_2 = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
inter_label_2 = inter_label_2.scatter_(1, inter_label_map_2.data.long(), 1.0)
|
||||
input_label_ref = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
input_label_ref = input_label_ref.scatter_(1, label_map_ref.data.long(), 1.0)
|
||||
|
||||
if self.opt.data_type == 16:
|
||||
input_label = input_label.half()
|
||||
inter_label_1 = inter_label_1.half()
|
||||
inter_label_2 = inter_label_2.half()
|
||||
input_label_ref = input_label_ref.half()
|
||||
|
||||
input_label = Variable(input_label, volatile=infer)
|
||||
inter_label_1 = Variable(inter_label_1, volatile=infer)
|
||||
inter_label_2 = Variable(inter_label_2, volatile=infer)
|
||||
input_label_ref = Variable(input_label_ref, volatile=infer)
|
||||
if torch.cuda.is_available():
|
||||
real_image = Variable(real_image.data.cuda())
|
||||
real_image_ref = Variable(real_image_ref.data.cuda())
|
||||
else:
|
||||
real_image = Variable(real_image.data)
|
||||
real_image_ref = Variable(real_image_ref.data)
|
||||
|
||||
return inter_label_1, input_label, inter_label_2, real_image, input_label_ref, real_image_ref
|
||||
|
||||
def encode_input_test(self, label_map, label_map_ref, real_image_ref, infer=False):
|
||||
|
||||
if self.opt.label_nc == 0:
|
||||
if torch.cuda.is_available():
|
||||
input_label = label_map.data.cuda()
|
||||
input_label_ref = label_map_ref.data.cuda()
|
||||
else:
|
||||
input_label = label_map.data
|
||||
input_label_ref = label_map_ref.data
|
||||
|
||||
else:
|
||||
# create one-hot vector for label map
|
||||
size = label_map.size()
|
||||
oneHot_size = (size[0], self.opt.label_nc, size[2], size[3])
|
||||
if torch.cuda.is_available():
|
||||
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
|
||||
input_label_ref = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
input_label_ref = input_label_ref.scatter_(1, label_map_ref.data.long().cuda(), 1.0)
|
||||
real_image_ref = Variable(real_image_ref.data.cuda())
|
||||
|
||||
else:
|
||||
input_label = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
input_label = input_label.scatter_(1, label_map.data.long(), 1.0)
|
||||
input_label_ref = torch.FloatTensor(torch.Size(oneHot_size)).zero_()
|
||||
input_label_ref = input_label_ref.scatter_(1, label_map_ref.data.long(), 1.0)
|
||||
real_image_ref = Variable(real_image_ref.data)
|
||||
|
||||
|
||||
if self.opt.data_type == 16:
|
||||
input_label = input_label.half()
|
||||
input_label_ref = input_label_ref.half()
|
||||
|
||||
input_label = Variable(input_label, volatile=infer)
|
||||
input_label_ref = Variable(input_label_ref, volatile=infer)
|
||||
|
||||
return input_label, input_label_ref, real_image_ref
|
||||
|
||||
def discriminate(self, input_label, test_image, use_pool=False):
|
||||
input_concat = torch.cat((input_label, test_image.detach()), dim=1)
|
||||
if use_pool:
|
||||
fake_query = self.fake_pool.query(input_concat)
|
||||
return self.netD.forward(fake_query)
|
||||
else:
|
||||
return self.netD.forward(input_concat)
|
||||
|
||||
def forward(self, inter_label_1, label, inter_label_2, image, label_ref, image_ref, infer=False):
|
||||
|
||||
# Encode Inputs
|
||||
inter_label_1, input_label, inter_label_2, real_image, input_label_ref, real_image_ref = self.encode_input(inter_label_1, label, inter_label_2, image, label_ref, image_ref)
|
||||
|
||||
fake_inter_1 = self.netG.forward(inter_label_1, input_label, real_image)
|
||||
fake_image = self.netG.forward(input_label, input_label, real_image)
|
||||
fake_inter_2 = self.netG.forward(inter_label_2, input_label, real_image)
|
||||
|
||||
blend_image, alpha = self.netB.forward(fake_inter_1, fake_inter_2)
|
||||
|
||||
# Fake Detection and Loss
|
||||
pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
|
||||
loss_D_fake = self.criterionGAN(pred_fake_pool, False)
|
||||
pred_blend_pool = self.discriminate(input_label, blend_image, use_pool=True)
|
||||
loss_D_blend = self.criterionGAN(pred_blend_pool, False)
|
||||
|
||||
# Real Detection and Loss
|
||||
pred_real = self.discriminate(input_label, real_image)
|
||||
loss_D_real = self.criterionGAN(pred_real, True)
|
||||
|
||||
# GAN loss (Fake Passability Loss)
|
||||
pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
|
||||
loss_G_GAN = self.criterionGAN(pred_fake, True)
|
||||
pred_blend = self.netD.forward(torch.cat((input_label, blend_image), dim=1))
|
||||
loss_GB_GAN = self.criterionGAN(pred_blend, True)
|
||||
|
||||
# GAN feature matching loss
|
||||
loss_G_GAN_Feat = 0
|
||||
loss_GB_GAN_Feat = 0
|
||||
if not self.opt.no_ganFeat_loss:
|
||||
feat_weights = 4.0 / (self.opt.n_layers_D + 1)
|
||||
D_weights = 1.0 / self.opt.num_D
|
||||
for i in range(self.opt.num_D):
|
||||
for j in range(len(pred_fake[i])-1):
|
||||
loss_G_GAN_Feat += D_weights * feat_weights * \
|
||||
self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
|
||||
loss_GB_GAN_Feat += D_weights * feat_weights * \
|
||||
self.criterionFeat(pred_blend[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat
|
||||
|
||||
# VGG feature matching loss
|
||||
loss_G_VGG = 0
|
||||
loss_GB_VGG = 0
|
||||
if not self.opt.no_vgg_loss:
|
||||
loss_G_VGG += self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat
|
||||
loss_GB_VGG += self.criterionVGG(blend_image, real_image) * self.opt.lambda_feat
|
||||
|
||||
# Only return the fake_B image if necessary to save BW
|
||||
return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_GB_GAN, loss_GB_GAN_Feat, loss_GB_VGG, loss_D_real, loss_D_fake, loss_D_blend ), None if not infer else fake_inter_1, fake_image, fake_inter_2, blend_image, alpha, real_image, inter_label_1, input_label, inter_label_2 ]
|
||||
|
||||
def inference(self, label, label_ref, image_ref):
|
||||
|
||||
# Encode Inputs
|
||||
image_ref = Variable(image_ref)
|
||||
input_label, input_label_ref, real_image_ref = self.encode_input_test(Variable(label), Variable(label_ref), image_ref, infer=True)
|
||||
|
||||
if torch.__version__.startswith('0.4'):
|
||||
with torch.no_grad():
|
||||
fake_image = self.netG.forward(input_label, input_label_ref, real_image_ref)
|
||||
else:
|
||||
fake_image = self.netG.forward(input_label, input_label_ref, real_image_ref)
|
||||
return fake_image
|
||||
|
||||
def save(self, which_epoch):
|
||||
self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
|
||||
self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
|
||||
self.save_network(self.netB, 'B', which_epoch, self.gpu_ids)
|
||||
|
||||
def update_fixed_params(self):
|
||||
# after fixing the global generator for a number of iterations, also start finetuning it
|
||||
params = list(self.netG.parameters())
|
||||
if self.gen_features:
|
||||
params += list(self.netE.parameters())
|
||||
self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
|
||||
if self.opt.verbose:
|
||||
print('------------ Now also finetuning global generator -----------')
|
||||
|
||||
def update_learning_rate(self):
    """Linearly decay the learning rate of both optimizers by one step."""
    decay_step = self.opt.lr / self.opt.niter_decay
    new_lr = self.old_lr - decay_step
    # Apply the decayed rate to every parameter group, D first then G
    # (same order as the original implementation).
    for optimizer in (self.optimizer_D, self.optimizer_G):
        for group in optimizer.param_groups:
            group['lr'] = new_lr
    if self.opt.verbose:
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
    self.old_lr = new_lr
|
||||
|
||||
class InferenceModel(Pix2PixHDModel):
    """Thin wrapper exposing Pix2PixHDModel.inference through forward()."""

    def forward(self, inp):
        # Bug fix: the original called self.inference(label) with a single
        # argument, but inference() requires (label, label_ref, image_ref),
        # so every invocation raised a TypeError. Unpack the input triple
        # (the standard pix2pixHD pattern of passing all inputs as one tuple).
        label, label_ref, image_ref = inp
        return self.inference(label, label_ref, image_ref)
|
||||
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,89 @@
|
||||
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
import argparse
|
||||
import os
|
||||
from util import util
|
||||
import torch
|
||||
|
||||
class BaseOptions():
    """Command-line options shared by training and testing.

    Subclasses (TrainOptions / TestOptions) register phase-specific flags in
    their own initialize() and must set ``self.isTrain`` before parse().
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        # Set to True once initialize() has registered the arguments.
        self.initialized = False

    def initialize(self):
        """Register all shared command-line arguments."""
        # experiment specifics
        self.parser.add_argument('--name', type=str, default='label2face_512p', help='name of the experiment. It decides where to store samples and models')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use')
        self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
        self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
        self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
        self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')

        # input/output sizes
        self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
        self.parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size')
        self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
        self.parser.add_argument('--label_nc', type=int, default=19, help='# of input label channels')
        self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
        self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')

        # for setting inputs
        self.parser.add_argument('--dataroot', type=str, default='../Data_preprocessing/')
        self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
        self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation')
        self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
        self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')

        # for displays
        self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
        self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')

        # for generator
        self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
        self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
        self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
        self.parser.add_argument('--n_blocks_global', type=int, default=4, help='number of residual blocks in the global generator network')
        self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
        self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
        self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs that we only train the outmost local enhancer')

        self.initialized = True

    def parse(self, save=True):
        """Parse argv into ``self.opt``, echo the options, and optionally
        write them to ``<checkpoints_dir>/<name>/opt.txt``.

        Returns:
            The populated argparse Namespace (also stored as ``self.opt``).
        """
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        self.opt.isTrain = self.isTrain   # train or test

        # '--gpu_ids' arrives as a comma-separated string; keep only ids >= 0
        # (so '-1' selects CPU by leaving the list empty).
        str_ids = self.opt.gpu_ids.split(',')
        self.opt.gpu_ids = []
        for str_id in str_ids:
            gpu_id = int(str_id)
            if gpu_id >= 0:
                self.opt.gpu_ids.append(gpu_id)

        # set gpu ids
        # if len(self.opt.gpu_ids) > 0:
        #     torch.cuda.set_device(self.opt.gpu_ids[0])

        args = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        # save to the disk
        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        util.mkdirs(expr_dir)
        # Robustness fix: '--continue_train' is only registered by
        # TrainOptions, so reading it unconditionally raised AttributeError
        # for test-time option sets whenever save=True. Default it to False
        # when absent; training behavior is unchanged.
        if save and not getattr(self.opt, 'continue_train', False):
            file_name = os.path.join(expr_dir, 'opt.txt')
            with open(file_name, 'wt') as opt_file:
                opt_file.write('------------ Options -------------\n')
                for k, v in sorted(args.items()):
                    opt_file.write('%s: %s\n' % (str(k), str(v)))
                opt_file.write('-------------- End ----------------\n')
        return self.opt
|
Binary file not shown.
@ -0,0 +1,19 @@
|
||||
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
from .base_options import BaseOptions
|
||||
|
||||
class TestOptions(BaseOptions):
    """Command-line options used only at test/inference time.

    Extends BaseOptions with output locations, checkpoint selection and
    ONNX/TensorRT export flags, and marks the run as non-training.
    """

    def initialize(self):
        # Register the shared options first, then the test-only ones.
        BaseOptions.initialize(self)
        self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--how_many', type=int, default=1000, help='how many test images to run')
        self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
        self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
        self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
        self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
        self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
        # Consumed by BaseOptions.parse() to set opt.isTrain.
        self.isTrain = False
|
Binary file not shown.
@ -0,0 +1,36 @@
|
||||
### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
|
||||
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
||||
from .base_options import BaseOptions
|
||||
|
||||
class TrainOptions(BaseOptions):
    """Command-line options used only at training time.

    Extends BaseOptions with display/checkpoint frequencies, optimizer
    hyper-parameters and discriminator/loss settings, and marks the run
    as training.
    """

    def initialize(self):
        # Register the shared options first, then the training-only ones.
        BaseOptions.initialize(self)
        # for displays
        self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
        self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
        self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
        self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')

        # for training
        self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        self.parser.add_argument('--load_pretrain', type=str, default='./checkpoints/label2face_512p', help='load the pretrained model from the specified location')
        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
        self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        self.parser.add_argument('--lr', type=float, default=0.00005, help='initial learning rate for adam')

        # for discriminators
        self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
        self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
        self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
        self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
        self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
        self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
        self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
        self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')

        # Consumed by BaseOptions.parse() to set opt.isTrain.
        self.isTrain = True
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,31 @@
|
||||
import random
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
class ImagePool():
    """Buffer of previously generated images (the CycleGAN/pix2pix history pool).

    While the pool is filling, queries are answered with the incoming images.
    Once full, each incoming image is either returned as-is or, with
    probability 0.5, swapped with a randomly chosen stored image (which is
    returned instead). A pool_size of 0 disables buffering entirely.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            # Number of images currently stored and their storage.
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch mixing the incoming images with buffered ones."""
        if self.pool_size == 0:
            # Buffering disabled: hand the batch straight back.
            return images
        out = []
        for img in images.data:
            img = torch.unsqueeze(img, 0)
            if self.num_imgs < self.pool_size:
                # Pool not full yet: store the image and return it unchanged.
                self.num_imgs = self.num_imgs + 1
                self.images.append(img)
                out.append(img)
            elif random.uniform(0, 1) > 0.5:
                # Swap: emit a random stored image, keep the new one instead.
                idx = random.randint(0, self.pool_size - 1)
                stored = self.images[idx].clone()
                self.images[idx] = img
                out.append(stored)
            else:
                out.append(img)
        return Variable(torch.cat(out, 0))
|
Binary file not shown.
@ -0,0 +1,107 @@
|
||||
from __future__ import print_function
|
||||
|
||||
# Removed a stray debug print ("?") left over from development: it polluted
# stdout every time this util module was imported.
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
# import numpy as np
|
||||
import os
|
||||
|
||||
# Converts a Tensor into a Numpy array
|
||||
# |imtype|: the desired type of the converted numpy array
|
||||
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
    """Convert a tensor (or a list of tensors) in [-1, 1] into a float
    numpy array in [0, 1].

    NOTE(review): `imtype` and `normalize` are accepted for interface
    compatibility but are not used by this implementation — the result is
    always a float array, not `imtype`; confirm against callers.
    """
    if isinstance(image_tensor, list):
        # Recurse element-wise over a list of tensors.
        return [tensor2im(t, imtype, normalize) for t in image_tensor]
    arr = image_tensor.cpu().float().numpy()
    arr = np.clip((arr + 1) / 2.0, 0, 1)
    # Keep only the first slice of the last axis when it does not look like
    # 3 channels (single entry, or more than 3 entries).
    if arr.shape[2] == 1 or arr.shape[2] > 3:
        arr = arr[:, :, 0]
    return arr
|
||||
|
||||
# Converts a one-hot tensor into a colorful label map
|
||||
def tensor2label(label_tensor, n_label, imtype=np.uint8):
    """Colorize a (possibly one-hot) label tensor into an RGB numpy array in [0, 1].

    With n_label == 0 the tensor is treated as a plain image and routed
    through tensor2im instead.
    """
    if n_label == 0:
        return tensor2im(label_tensor, imtype)
    labels = label_tensor.cpu().float()
    if labels.size()[0] > 1:
        # One-hot input: collapse channels to a single map of argmax indices.
        labels = labels.max(0, keepdim=True)[1]
    colored = Colorize(n_label)(labels)
    return colored.numpy() / 255.0
|
||||
|
||||
def save_image(image_numpy, image_path):
    """Write a numpy image array to disk at the given path via PIL."""
    Image.fromarray(image_numpy).save(image_path)
|
||||
|
||||
def mkdirs(paths):
    """Create one directory, or each directory in a list of paths."""
    if isinstance(paths, list) and not isinstance(paths, str):
        for entry in paths:
            mkdir(entry)
    else:
        mkdir(paths)


def mkdir(path):
    """Create `path` (including parents) unless it already exists."""
    if not os.path.exists(path):
        os.makedirs(path)
|
||||
|
||||
###############################################################################
|
||||
# Code from
|
||||
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
|
||||
# Modified so it complies with the Citscape label map colors
|
||||
###############################################################################
|
||||
def uint82bin(n, count=8):
    """Return the binary string of integer n; count is the number of bits."""
    bits = []
    for shift in range(count - 1, -1, -1):
        bits.append('1' if (n >> shift) & 1 else '0')
    return ''.join(bits)
|
||||
|
||||
def labelcolormap(N):
    """Build an (N, 3) uint8 color map for label visualisation.

    N == 35 returns the fixed Cityscapes palette; any other N generates a
    color per label by spreading the label id's bits across the RGB channels
    (the standard PASCAL VOC colormap construction).
    """
    if N == 35:  # cityscape
        return np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
                         (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
                         (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
                         (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
                         ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
                        dtype=np.uint8)
    cmap = np.zeros((N, 3), dtype=np.uint8)
    for label in range(N):
        r = g = b = 0
        value = label
        for j in range(7):
            # Fold successive bit-triplets of the label id into R/G/B,
            # placing each triplet at a progressively lower bit position.
            r ^= ((value >> 0) & 1) << (7 - j)
            g ^= ((value >> 1) & 1) << (7 - j)
            b ^= ((value >> 2) & 1) << (7 - j)
            value >>= 3
        cmap[label] = (r, g, b)
    return cmap
|
||||
|
||||
class Colorize(object):
    """Map a single-channel label image to a 3xHxW RGB ByteTensor using labelcolormap."""

    def __init__(self, n=35):
        palette = labelcolormap(n)
        self.cmap = torch.from_numpy(palette[:n])

    def __call__(self, gray_image):
        size = gray_image.size()
        color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
        # Paint every pixel belonging to `label` with that label's RGB color.
        for label in range(len(self.cmap)):
            mask = (label == gray_image[0]).cpu()
            for channel in range(3):
                color_image[channel][mask] = self.cmap[label][channel]
        return color_image
|
Binary file not shown.
@ -0,0 +1,99 @@
|
||||
import torch
|
||||
import copy
|
||||
|
||||
|
||||
class GANFactory:
    """Registry mapping gan_id strings to trainer Factory objects."""

    # Class-level registry shared by all callers.
    factories = {}

    def __init__(self):
        pass

    def add_factory(gan_id, model_factory):
        # Bug fix: the original wrote `GANFactory.factories.put[gan_id] = ...`,
        # which raises AttributeError because dict has no `put` attribute
        # (a Java idiom). Plain item assignment registers the factory.
        GANFactory.factories[gan_id] = model_factory

    add_factory = staticmethod(add_factory)

    # A Template Method:

    def create_model(gan_id, net_d=None, criterion=None):
        """Look up gan_id (lazily registering its Factory) and build a trainer."""
        if gan_id not in GANFactory.factories:
            # SECURITY NOTE: eval() on gan_id executes arbitrary code if the
            # id ever comes from untrusted input; kept for compatibility with
            # existing configs, but do not feed it user-controlled strings.
            GANFactory.factories[gan_id] = \
                eval(gan_id + '.Factory()')
        return GANFactory.factories[gan_id].create(net_d, criterion)

    create_model = staticmethod(create_model)
|
||||
|
||||
|
||||
class GANTrainer(object):
    """Abstract base holding a discriminator and its adversarial criterion."""

    def __init__(self, net_d, criterion):
        self.net_d = net_d
        self.criterion = criterion

    def loss_d(self, pred, gt):
        """Discriminator loss; implemented by subclasses."""
        pass

    def loss_g(self, pred, gt):
        """Generator (adversarial) loss; implemented by subclasses."""
        pass

    def get_params(self):
        """Parameters the discriminator optimizer should update."""
        pass
|
||||
|
||||
|
||||
class NoGAN(GANTrainer):
    """Trainer stub used when adversarial training is disabled."""

    def __init__(self, net_d, criterion):
        GANTrainer.__init__(self, net_d, criterion)

    def loss_d(self, pred, gt):
        # No discriminator: report a zero loss (list, matching callers).
        return [0]

    def loss_g(self, pred, gt):
        # No adversarial term for the generator either.
        return 0

    def get_params(self):
        # A dummy parameter so the discriminator optimizer has something to hold.
        return [torch.nn.Parameter(torch.Tensor(1))]

    class Factory:
        @staticmethod
        def create(net_d, criterion):
            return NoGAN(net_d, criterion)
|
||||
|
||||
|
||||
class SingleGAN(GANTrainer):
    """Adversarial trainer driving a single CUDA discriminator."""

    def __init__(self, net_d, criterion):
        GANTrainer.__init__(self, net_d, criterion)
        self.net_d = self.net_d.cuda()

    def loss_d(self, pred, gt):
        return self.criterion(self.net_d, pred, gt)

    def loss_g(self, pred, gt):
        return self.criterion.get_g_loss(self.net_d, pred, gt)

    def get_params(self):
        return self.net_d.parameters()

    class Factory:
        @staticmethod
        def create(net_d, criterion):
            return SingleGAN(net_d, criterion)
|
||||
|
||||
|
||||
class DoubleGAN(GANTrainer):
    """Adversarial trainer averaging a patch discriminator and a full-image one."""

    def __init__(self, net_d, criterion):
        GANTrainer.__init__(self, net_d, criterion)
        self.patch_d = net_d['patch'].cuda()
        self.full_d = net_d['full'].cuda()
        # Independent criterion state for the full-image discriminator.
        self.full_criterion = copy.deepcopy(criterion)

    def loss_d(self, pred, gt):
        patch_loss = self.criterion(self.patch_d, pred, gt)
        full_loss = self.full_criterion(self.full_d, pred, gt)
        return (patch_loss + full_loss) / 2

    def loss_g(self, pred, gt):
        patch_loss = self.criterion.get_g_loss(self.patch_d, pred, gt)
        full_loss = self.full_criterion.get_g_loss(self.full_d, pred, gt)
        return (patch_loss + full_loss) / 2

    def get_params(self):
        return list(self.patch_d.parameters()) + list(self.full_d.parameters())

    class Factory:
        @staticmethod
        def create(net_d, criterion):
            return DoubleGAN(net_d, criterion)
|
||||
|
@ -0,0 +1,93 @@
|
||||
from typing import List
|
||||
|
||||
import albumentations as albu
|
||||
|
||||
|
||||
def get_transforms(size, scope='geometric', crop='random'):
    """Build a paired-image augmentation pipeline.

    Returns process(a, b) that applies the chosen augmentation scope
    ('strong' | 'weak' | 'geometric'), a size x size crop ('random' or
    'center') and padding, using identical spatial parameters for both
    images of the pair.
    """
    strong = albu.Compose([
        albu.HorizontalFlip(),
        albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4),
        albu.ElasticTransform(),
        albu.OpticalDistortion(),
        albu.OneOf([
            albu.CLAHE(clip_limit=2),
            albu.IAASharpen(),
            albu.IAAEmboss(),
            albu.RandomBrightnessContrast(),
            albu.RandomGamma(),
        ], p=0.5),
        albu.OneOf([
            albu.RGBShift(),
            albu.HueSaturationValue(),
        ], p=0.5),
    ])
    weak = albu.Compose([
        albu.HorizontalFlip(),
    ])
    geometric = albu.OneOf([
        albu.HorizontalFlip(always_apply=True),
        albu.ShiftScaleRotate(always_apply=True),
        albu.Transpose(always_apply=True),
        albu.OpticalDistortion(always_apply=True),
        albu.ElasticTransform(always_apply=True),
    ])
    aug_fn = {'strong': strong, 'weak': weak, 'geometric': geometric}[scope]

    crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
               'center': albu.CenterCrop(size, size, always_apply=True)}[crop]
    pad = albu.PadIfNeeded(size, size)

    # 'target' is treated exactly like 'image' so both get the same transform.
    pipeline = albu.Compose([aug_fn, crop_fn, pad], additional_targets={'target': 'image'})

    def process(a, b):
        r = pipeline(image=a, target=b)
        return r['image'], r['target']

    return process
|
||||
|
||||
|
||||
def get_normalize():
    """Return process(a, b) normalizing both images with mean/std 0.5 per channel."""
    pipeline = albu.Compose(
        [albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])],
        additional_targets={'target': 'image'})

    def process(a, b):
        res = pipeline(image=a, target=b)
        return res['image'], res['target']

    return process
|
||||
|
||||
|
||||
def _resolve_aug_fn(name):
    """Map a config name to the corresponding albumentations transform class.

    Raises KeyError for unknown names (same as the original lookup).
    """
    return {
        'cutout': albu.Cutout,
        'rgb_shift': albu.RGBShift,
        'hsv_shift': albu.HueSaturationValue,
        'motion_blur': albu.MotionBlur,
        'median_blur': albu.MedianBlur,
        'snow': albu.RandomSnow,
        'shadow': albu.RandomShadow,
        'fog': albu.RandomFog,
        'brightness_contrast': albu.RandomBrightnessContrast,
        'gamma': albu.RandomGamma,
        'sun_flare': albu.RandomSunFlare,
        'sharpen': albu.IAASharpen,
        'jpeg': albu.JpegCompression,
        'gray': albu.ToGray,
        # ToDo: pixelize
        # ToDo: partial gray
    }[name]
|
||||
|
||||
|
||||
def get_corrupt_function(config):
    """Build process(x) applying one randomly chosen corruption to an image.

    Each entry of `config` is a dict with a 'name' key, an optional 'prob'
    (default 0.5) and extra keyword arguments for the transform class.
    """
    augs = []
    for aug_params in config:
        # Bug fix: the original .pop()'ed 'name'/'prob' directly out of the
        # caller's dicts, mutating the shared config so it could only be
        # used once. Work on a shallow copy instead.
        params = dict(aug_params)
        cls = _resolve_aug_fn(params.pop('name'))
        prob = params.pop('prob', .5)
        augs.append(cls(p=prob, **params))

    augs = albu.OneOf(augs)

    def process(x):
        return augs(image=x)['image']

    return process
|
Binary file not shown.
@ -0,0 +1,68 @@
|
||||
---
|
||||
project: deblur_gan
|
||||
experiment_desc: fpn
|
||||
|
||||
train:
|
||||
files_a: &FILES_A /datasets/my_dataset/**/*.jpg
|
||||
files_b: *FILES_A
|
||||
size: &SIZE 256
|
||||
crop: random
|
||||
preload: &PRELOAD false
|
||||
preload_size: &PRELOAD_SIZE 0
|
||||
bounds: [0, .9]
|
||||
scope: geometric
|
||||
corrupt: &CORRUPT
|
||||
- name: cutout
|
||||
prob: 0.5
|
||||
num_holes: 3
|
||||
max_h_size: 25
|
||||
max_w_size: 25
|
||||
- name: jpeg
|
||||
quality_lower: 70
|
||||
quality_upper: 90
|
||||
- name: motion_blur
|
||||
- name: median_blur
|
||||
- name: gamma
|
||||
- name: rgb_shift
|
||||
- name: hsv_shift
|
||||
- name: sharpen
|
||||
|
||||
val:
|
||||
files_a: *FILES_A
|
||||
files_b: *FILES_A
|
||||
size: *SIZE
|
||||
scope: geometric
|
||||
crop: center
|
||||
preload: *PRELOAD
|
||||
preload_size: *PRELOAD_SIZE
|
||||
bounds: [.9, 1]
|
||||
corrupt: *CORRUPT
|
||||
|
||||
phase: train
|
||||
warmup_num: 3
|
||||
model:
|
||||
g_name: fpn_inception
|
||||
blocks: 9
|
||||
d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scale
|
||||
d_layers: 3
|
||||
content_loss: perceptual
|
||||
adv_lambda: 0.001
|
||||
disc_loss: wgan-gp
|
||||
learn_residual: True
|
||||
norm_layer: instance
|
||||
dropout: True
|
||||
|
||||
num_epochs: 200
|
||||
train_batches_per_epoch: 1000
|
||||
val_batches_per_epoch: 100
|
||||
batch_size: 1
|
||||
image_size: [256, 256]
|
||||
|
||||
optimizer:
|
||||
name: adam
|
||||
lr: 0.0001
|
||||
scheduler:
|
||||
name: linear
|
||||
start_epoch: 50
|
||||
min_lr: 0.0000001
|
||||
|
@ -0,0 +1,142 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from glob import glob
|
||||
from hashlib import sha1
|
||||
from typing import Callable, Iterable, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from glog import logger
|
||||
from joblib import Parallel, cpu_count, delayed
|
||||
from skimage.io import imread
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import aug
|
||||
|
||||
|
||||
def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
    """Deterministically keep the samples whose hash bucket falls inside `bounds`.

    `bounds` are fractions of `n_buckets` (e.g. (0, .9) keeps ~90% of the
    data), so the same salt/bounds always yield the same train/val split.
    """
    data = list(data)
    buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)

    lower_bound, upper_bound = (bound * n_buckets for bound in bounds)
    msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
    if salt:
        msg += f'; salt is {salt}'
    if verbose:
        logger.info(msg)
    kept = [sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound]
    return np.array(kept)
|
||||
|
||||
|
||||
def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
    """Hash a pair of file paths (basenames only) with an optional salt."""
    first, second = x
    joined = os.path.basename(first) + os.path.basename(second)
    return sha1(f'{joined}_{salt}'.encode()).hexdigest()
|
||||
|
||||
|
||||
def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
    """Assign each sample to a bucket in [0, n_buckets) from its salted hex hash."""
    salted_hash = partial(hash_fn, salt=salt)
    return np.array([int(salted_hash(sample), 16) % n_buckets for sample in data])
|
||||
|
||||
|
||||
def _read_img(x: str):
    """Load the image at path `x`: OpenCV first, scikit-image as fallback."""
    img = cv2.imread(x)
    if img is not None:
        return img
    # cv2 returns None on unreadable/unsupported files instead of raising.
    logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
    return imread(x)
|
||||
|
||||
|
||||
class PairedDataset(Dataset):
    """Dataset of paired images (input, target) for image restoration.

    Each item is a dict ``{'a': input, 'b': target}`` of CHW float arrays.
    The same spatial transform is applied to both images; an optional
    ``corrupt_fn`` additionally degrades only the input image.
    """

    def __init__(self,
                 files_a: Tuple[str],
                 files_b: Tuple[str],
                 transform_fn: Callable,
                 normalize_fn: Callable,
                 corrupt_fn: Optional[Callable] = None,
                 preload: bool = True,
                 preload_size: Optional[int] = 0,
                 verbose=True):

        assert len(files_a) == len(files_b)

        self.preload = preload
        self.data_a = files_a
        self.data_b = files_b
        self.verbose = verbose
        self.corrupt_fn = corrupt_fn
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn
        logger.info(f'Dataset has been created with {len(self.data_a)} samples')

        if preload:
            preload_fn = partial(self._bulk_preload, preload_size=preload_size)
            if files_a == files_b:
                # Same file list on both sides: load once and share the arrays.
                self.data_a = self.data_b = preload_fn(self.data_a)
            else:
                self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
            self.preload = True

    def _bulk_preload(self, data: Iterable[str], preload_size: int):
        # Load all images with a thread pool and an optional progress bar.
        jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
        jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
        return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)

    @staticmethod
    def _preload(x: str, preload_size: int):
        # Read one image and, if preload_size is set, upscale so that the
        # smaller side is at least preload_size pixels.
        img = _read_img(x)
        if preload_size:
            h, w, *_ = img.shape
            h_scale = preload_size / h
            w_scale = preload_size / w
            scale = max(h_scale, w_scale)
            img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
            assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
        return img

    def _preprocess(self, img, res):
        # Normalize both images, then convert HWC -> CHW.
        def transpose(x):
            return np.transpose(x, (2, 0, 1))

        return map(transpose, self.normalize_fn(img, res))

    def __len__(self):
        return len(self.data_a)

    def __getitem__(self, idx):
        a, b = self.data_a[idx], self.data_b[idx]
        if not self.preload:
            # Lazy mode: entries are still paths at this point.
            a, b = map(_read_img, (a, b))
        a, b = self.transform_fn(a, b)
        if self.corrupt_fn is not None:
            # Degrade only the input image; the target stays clean.
            a = self.corrupt_fn(a)
        a, b = self._preprocess(a, b)
        return {'a': a, 'b': b}

    @staticmethod
    def from_config(config):
        """Build a PairedDataset from a config dict (see the train/val YAML)."""
        config = deepcopy(config)
        files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
        transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
        normalize_fn = aug.get_normalize()
        corrupt_fn = aug.get_corrupt_function(config['corrupt'])

        hash_fn = hash_from_paths
        # ToDo: add more hash functions
        verbose = config.get('verbose', True)
        # Deterministic train/val split by hashing file-name pairs into buckets.
        data = subsample(data=zip(files_a, files_b),
                         bounds=config.get('bounds', (0, 1)),
                         hash_fn=hash_fn,
                         verbose=verbose)

        files_a, files_b = map(list, zip(*data))

        return PairedDataset(files_a=files_a,
                             files_b=files_b,
                             preload=config['preload'],
                             preload_size=config['preload_size'],
                             corrupt_fn=corrupt_fn,
                             normalize_fn=normalize_fn,
                             transform_fn=transform_fn,
                             verbose=verbose)
|
@ -0,0 +1,56 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
WINDOW_SIZE = 100
|
||||
|
||||
|
||||
class MetricCounter:
    """Accumulates training/validation metrics and sample images, writes them
    to TensorBoard, and tracks the best mean PSNR seen so far."""

    def __init__(self, exp_name):
        self.writer = SummaryWriter(exp_name)
        logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG)
        # Per-metric lists of scalar values accumulated since the last clear().
        self.metrics = defaultdict(list)
        # Per-tag lists of images to be dumped to TensorBoard.
        self.images = defaultdict(list)
        self.best_metric = 0

    def add_image(self, x: np.ndarray, tag: str):
        self.images[tag].append(x)

    def clear(self):
        # Reset all accumulators (typically called at epoch boundaries).
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)

    def add_losses(self, l_G, l_content, l_D=0):
        # Adversarial loss is recovered as total generator loss minus content loss.
        for name, value in zip(('G_loss', 'G_loss_content', 'G_loss_adv', 'D_loss'),
                               (l_G, l_content, l_G - l_content, l_D)):
            self.metrics[name].append(value)

    def add_metrics(self, psnr, ssim):
        for name, value in zip(('PSNR', 'SSIM'),
                               (psnr, ssim)):
            self.metrics[name].append(value)

    def loss_message(self):
        # Report the mean over the last WINDOW_SIZE recorded values per metric.
        metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k in ('G_loss', 'PSNR', 'SSIM'))
        return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics))

    def write_to_tensorboard(self, epoch_num, validation=False):
        scalar_prefix = 'Validation' if validation else 'Train'
        for tag in ('G_loss', 'D_loss', 'G_loss_adv', 'G_loss_content', 'SSIM', 'PSNR'):
            self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num)
        for tag in self.images:
            imgs = self.images[tag]
            if imgs:
                imgs = np.array(imgs)
                # Reverse the channel axis (presumably BGR -> RGB from OpenCV
                # — TODO confirm against callers) and scale to [0, 1] floats.
                self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC',
                                       global_step=epoch_num)
                self.images[tag] = []

    def update_best_model(self):
        """Return True (and remember the new best) if the mean PSNR improved."""
        cur_metric = np.mean(self.metrics['PSNR'])
        if self.best_metric < cur_metric:
            self.best_metric = cur_metric
            return True
        return False
Binary file not shown.
@ -0,0 +1,135 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torchvision.models import resnet50, densenet121, densenet201
|
||||
|
||||
|
||||
class FPNSegHead(nn.Module):
    """Two 3x3 conv + ReLU layers mapping num_in -> num_mid -> num_out channels.

    Both convolutions are bias-free and padding-preserving, so the spatial
    size of the input is unchanged.
    """

    def __init__(self, num_in, num_mid, num_out):
        super().__init__()
        self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
        self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        hidden = nn.functional.relu(self.block0(x), inplace=True)
        return nn.functional.relu(self.block1(hidden), inplace=True)
|
||||
|
||||
|
||||
class FPNDense(nn.Module):
    """FPN-based generator over a DenseNet backbone.

    Fuses four pyramid levels through per-level heads, refines and upsamples
    back to input resolution, and projects to `output_ch` channels via tanh.
    """

    def __init__(self, output_ch=3, num_filters=128, num_filters_fpn=256, pretrained=True):
        super().__init__()

        # Feature Pyramid Network (FPN) with four feature maps of resolutions
        # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
        self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained)

        # The segmentation heads on top of the FPN, one per pyramid level.
        self.head1 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
        self.head2 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
        self.head3 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
        self.head4 = FPNSegHead(num_filters_fpn, num_filters, num_filters)

        # Fusion of the four upsampled head outputs back to `num_filters` channels.
        self.smooth = nn.Sequential(
            nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_filters),
            nn.ReLU(),
        )

        self.smooth2 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_filters // 2),
            nn.ReLU(),
        )

        # Final projection to the output image channels.
        self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)

    def forward(self, x):
        map0, map1, map2, map3, map4 = self.fpn(x)

        # Bring every pyramid level to map1's (finest pyramid) resolution.
        map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
        map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
        map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
        map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")

        smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
        smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
        smoothed = self.smooth2(smoothed + map0)
        smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")

        final = self.final(smoothed)

        # BUG FIX: the original ended with `nn.Tanh(final)` — that constructs an
        # nn.Tanh module (passing `final` as a constructor argument) without ever
        # applying it, and the method then fell off the end returning None.
        # Apply tanh functionally and return the result, matching the sibling
        # FPN generators in this project.
        return torch.tanh(final)
|
||||
|
||||
|
||||
class FPN(nn.Module):

    def __init__(self, num_filters=256, pretrained=True):
        """Creates an `FPN` instance for feature extraction.
        Args:
            num_filters: the number of filters in each output pyramid level
            pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()

        # DenseNet-121 feature trunk, sliced into stem / dense blocks / transitions.
        self.features = densenet121(pretrained=pretrained).features

        # Stem kept separate from its pooling so the pre-pool activation can feed
        # the finest lateral connection (lateral0 below expects 64 channels).
        self.enc0 = nn.Sequential(self.features.conv0,
                                  self.features.norm0,
                                  self.features.relu0)
        self.pool0 = self.features.pool0
        self.enc1 = self.features.denseblock1  # 256
        self.enc2 = self.features.denseblock2  # 512
        self.enc3 = self.features.denseblock3  # 1024
        self.enc4 = self.features.denseblock4  # 2048
        self.norm = self.features.norm5  # 2048

        # Transitions between dense blocks (downsampling path).
        self.tr1 = self.features.transition1  # 256
        self.tr2 = self.features.transition2  # 512
        self.tr3 = self.features.transition3  # 1024

        # 1x1 lateral projections onto the common pyramid width.
        # NOTE(review): lateral4 expects 1024 input channels while the comments
        # above say 2048; for torchvision's densenet121 the final block emits
        # 1024 channels, so the code (not the inherited comments) appears
        # correct — confirm against the torchvision version in use.
        self.lateral4 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False)
        self.lateral3 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False)
        self.lateral2 = nn.Conv2d(512, num_filters, kernel_size=1, bias=False)
        self.lateral1 = nn.Conv2d(256, num_filters, kernel_size=1, bias=False)
        self.lateral0 = nn.Conv2d(64, num_filters // 2, kernel_size=1, bias=False)

    def forward(self, x):
        # Bottom-up pathway, from ResNet
        enc0 = self.enc0(x)

        pooled = self.pool0(enc0)

        enc1 = self.enc1(pooled)  # 256
        tr1 = self.tr1(enc1)

        enc2 = self.enc2(tr1)  # 512
        tr2 = self.tr2(enc2)

        enc3 = self.enc3(tr2)  # 1024
        tr3 = self.tr3(enc3)

        enc4 = self.enc4(tr3)  # 2048
        enc4 = self.norm(enc4)

        # Lateral connections: project each stage onto the shared pyramid width.

        lateral4 = self.lateral4(enc4)
        lateral3 = self.lateral3(enc3)
        lateral2 = self.lateral2(enc2)
        lateral1 = self.lateral1(enc1)
        lateral0 = self.lateral0(enc0)

        # Top-down pathway: upsample the coarser map and add the lateral at each step.

        map4 = lateral4
        map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")
        map2 = lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest")
        map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")

        # Finest-to-coarsest: (pre-pool stem features, then the four pyramid maps).
        return lateral0, map1, map2, map3, map4
|
Binary file not shown.
@ -0,0 +1,167 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
# from pretrainedmodels import inceptionresnetv2
|
||||
# from torchsummary import summary
|
||||
import torch.nn.functional as F
|
||||
|
||||
class FPNHead(nn.Module):
    """Per-pyramid-level head: two bias-free 3x3 convolutions with ReLU."""

    def __init__(self, num_in, num_mid, num_out):
        super(FPNHead, self).__init__()
        self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
        self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        mid = nn.functional.relu(self.block0(x), inplace=True)
        return nn.functional.relu(self.block1(mid), inplace=True)
|
||||
|
||||
class ConvBlock(nn.Module):
    """3x3 convolution followed by a caller-supplied norm layer and in-place ReLU."""

    def __init__(self, num_in, num_out, norm_layer):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
            norm_layer(num_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)
|
||||
|
||||
|
||||
class FPNInception(nn.Module):
    """DeblurGAN-v2 generator: Inception-ResNet-v2 FPN backbone plus fusion heads.

    Predicts a residual that is added to the input image and clamped to [-1, 1].
    """

    def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256):
        super(FPNInception,self).__init__()

        # Feature Pyramid Network (FPN) with four feature maps of resolutions
        # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
        self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)

        # The segmentation heads on top of the FPN, one per pyramid level.

        self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)

        # Fuses the four upsampled head outputs back to `num_filters` channels.
        self.smooth = nn.Sequential(
            nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
            norm_layer(num_filters),
            nn.ReLU(),
        )

        self.smooth2 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
            norm_layer(num_filters // 2),
            nn.ReLU(),
        )

        # Final projection to the output image channels.
        self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)

    def unfreeze(self):
        # Re-enable gradients on the (initially frozen) backbone.
        self.fpn.unfreeze()

    def forward(self, x):
        map0, map1, map2, map3, map4 = self.fpn(x)

        # Bring every pyramid level to map1's (finest pyramid) resolution.
        map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
        map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
        map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
        map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")

        smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
        smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
        smoothed = self.smooth2(smoothed + map0)
        smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")

        final = self.final(smoothed)
        # Residual formulation: predict a correction to the input image.
        res = torch.tanh(final) + x

        # Keep the output in the model's working range [-1, 1].
        return torch.clamp(res, min = -1,max = 1)
|
||||
|
||||
|
||||
class FPN(nn.Module):

    def __init__(self, norm_layer, num_filters=256):
        """Creates an `FPN` instance for feature extraction.
        Args:
            num_filters: the number of filters in each output pyramid level
            pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super(FPN,self).__init__()
        # NOTE(review): `inceptionresnetv2` is referenced here but its import at
        # the top of this file is commented out — constructing this class raises
        # NameError unless that import is restored; confirm before use.
        self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
        # self.inception = torch.load('inceptionresnetv2-520b38e4.pth')
        # Backbone stages; channel counts below follow the lateral conv
        # definitions further down (32 / 64 / 192 / 1088 / 2080).
        self.enc0 = self.inception.conv2d_1a
        self.enc1 = nn.Sequential(
            self.inception.conv2d_2a,
            self.inception.conv2d_2b,
            self.inception.maxpool_3a,
        )  # 64
        self.enc2 = nn.Sequential(
            self.inception.conv2d_3b,
            self.inception.conv2d_4a,
            self.inception.maxpool_5a,
        )  # 192
        self.enc3 = nn.Sequential(
            self.inception.mixed_5b,
            self.inception.repeat,
            self.inception.mixed_6a,
        )  # 1088
        self.enc4 = nn.Sequential(
            self.inception.repeat_1,
            self.inception.mixed_7a,
        )  # 2080
        # Top-down refinement blocks (conv + norm + ReLU per merge step).
        self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        # Reflect padding used to re-align spatial sizes of selected laterals
        # (presumably compensating for the backbone's size reductions — confirm).
        self.pad = nn.ReflectionPad2d(1)
        self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
        self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
        self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
        self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
        self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)

        # Backbone starts frozen; see unfreeze().
        for param in self.inception.parameters():
            param.requires_grad = False

    def unfreeze(self):
        # Re-enable backbone training (called after warm-up by the generator).
        for param in self.inception.parameters():
            param.requires_grad = True

    def forward(self, x):

        # Bottom-up pathway, from ResNet
        enc0 = self.enc0(x)

        enc1 = self.enc1(enc0)  # 64

        enc2 = self.enc2(enc1)  # 192

        enc3 = self.enc3(enc2)  # 1088

        enc4 = self.enc4(enc3)  # 2080

        # Lateral connections (reflect-padded where spatial sizes need aligning)

        lateral4 = self.pad(self.lateral4(enc4))
        lateral3 = self.pad(self.lateral3(enc3))
        lateral2 = self.lateral2(enc2)
        lateral1 = self.pad(self.lateral1(enc1))
        lateral0 = self.lateral0(enc0)

        # Top-down pathway
        # (left, right, top, bottom) reflect padding amounts for F.pad below.
        pad = (1, 2, 1, 2)
        pad1 = (0, 1, 0, 1)
        map4 = lateral4
        map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest"))
        map2 = self.td2(F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest"))
        map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest"))
        return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4
|
Binary file not shown.
@ -0,0 +1,160 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pretrainedmodels import inceptionresnetv2
|
||||
from torchsummary import summary
|
||||
import torch.nn.functional as F
|
||||
|
||||
class FPNHead(nn.Module):
    """Two stacked bias-free 3x3 convolutions; each followed by in-place ReLU."""

    def __init__(self, num_in, num_mid, num_out):
        super().__init__()
        self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
        self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        y = self.block0(x)
        y = nn.functional.relu(y, inplace=True)
        y = self.block1(y)
        return nn.functional.relu(y, inplace=True)
|
||||
|
||||
class ConvBlock(nn.Module):
    """Conv(3x3) -> norm_layer -> ReLU, packaged as a single sequential block."""

    def __init__(self, num_in, num_out, norm_layer):
        super().__init__()
        layers = [
            nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
            norm_layer(num_out),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
|
||||
|
||||
|
||||
class FPNInceptionSimple(nn.Module):
    """Simplified DeblurGAN-v2 generator over an Inception-ResNet-v2 FPN.

    Same head/fusion layout as FPNInception; the backbone FPN variant below
    omits the per-merge refinement blocks. Output is a residual added to the
    input and clamped to [-1, 1].
    """

    def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256):
        super().__init__()

        # Feature Pyramid Network (FPN) with four feature maps of resolutions
        # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
        self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)

        # The segmentation heads on top of the FPN, one per pyramid level.

        self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)

        # Fuses the four upsampled head outputs back to `num_filters` channels.
        self.smooth = nn.Sequential(
            nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
            norm_layer(num_filters),
            nn.ReLU(),
        )

        self.smooth2 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
            norm_layer(num_filters // 2),
            nn.ReLU(),
        )

        # Final projection to the output image channels.
        self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)

    def unfreeze(self):
        # Re-enable gradients on the (initially frozen) backbone.
        self.fpn.unfreeze()

    def forward(self, x):

        map0, map1, map2, map3, map4 = self.fpn(x)

        # Bring every pyramid level to map1's (finest pyramid) resolution.
        map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
        map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
        map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
        map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")

        smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
        smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
        smoothed = self.smooth2(smoothed + map0)
        smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")

        final = self.final(smoothed)
        # Residual formulation: predict a correction to the input image.
        res = torch.tanh(final) + x

        # Keep the output in the model's working range [-1, 1].
        return torch.clamp(res, min = -1,max = 1)
|
||||
|
||||
|
||||
class FPN(nn.Module):

    def __init__(self, norm_layer, num_filters=256):
        """Creates an `FPN` instance for feature extraction.
        Args:
            num_filters: the number of filters in each output pyramid level
            pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()
        # Inception-ResNet-v2 backbone from `pretrainedmodels` (imported at file top).
        self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')

        # Backbone stages; channel counts below follow the lateral conv
        # definitions further down (32 / 64 / 192 / 1088 / 2080).
        self.enc0 = self.inception.conv2d_1a
        self.enc1 = nn.Sequential(
            self.inception.conv2d_2a,
            self.inception.conv2d_2b,
            self.inception.maxpool_3a,
        )  # 64
        self.enc2 = nn.Sequential(
            self.inception.conv2d_3b,
            self.inception.conv2d_4a,
            self.inception.maxpool_5a,
        )  # 192
        self.enc3 = nn.Sequential(
            self.inception.mixed_5b,
            self.inception.repeat,
            self.inception.mixed_6a,
        )  # 1088
        self.enc4 = nn.Sequential(
            self.inception.repeat_1,
            self.inception.mixed_7a,
        )  # 2080

        # Reflect padding used to re-align spatial sizes of selected laterals.
        self.pad = nn.ReflectionPad2d(1)
        self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
        self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
        self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
        self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
        self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)

        # Backbone starts frozen; see unfreeze().
        for param in self.inception.parameters():
            param.requires_grad = False

    def unfreeze(self):
        # Re-enable backbone training (called after warm-up by the generator).
        for param in self.inception.parameters():
            param.requires_grad = True

    def forward(self, x):

        # Bottom-up pathway, from ResNet
        enc0 = self.enc0(x)

        enc1 = self.enc1(enc0)  # 64

        enc2 = self.enc2(enc1)  # 192

        enc3 = self.enc3(enc2)  # 1088

        enc4 = self.enc4(enc3)  # 2080

        # Lateral connections (reflect-padded where spatial sizes need aligning)

        lateral4 = self.pad(self.lateral4(enc4))
        lateral3 = self.pad(self.lateral3(enc3))
        lateral2 = self.lateral2(enc2)
        lateral1 = self.pad(self.lateral1(enc1))
        lateral0 = self.lateral0(enc0)

        # Top-down pathway: plain additive merges (no refinement convs here).
        # (left, right, top, bottom) reflect padding amounts for F.pad below.
        pad = (1, 2, 1, 2)
        pad1 = (0, 1, 0, 1)
        map4 = lateral4
        map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")
        map2 = F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest")
        map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")
        return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4
|
Binary file not shown.
@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mobilenet_v2 import MobileNetV2
|
||||
|
||||
class FPNHead(nn.Module):
    """Feature-pyramid head: two bias-free 3x3 convolutions, ReLU after each."""

    def __init__(self, num_in, num_mid, num_out):
        super().__init__()
        self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
        self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = nn.functional.relu(self.block0(x), inplace=True)
        out = nn.functional.relu(self.block1(out), inplace=True)
        return out
|
||||
|
||||
|
||||
class FPNMobileNet(nn.Module):
    """Lightweight DeblurGAN-v2 generator: MobileNetV2 FPN backbone plus fusion heads.

    Output is a residual added to the input image and clamped to [-1, 1].
    """

    def __init__(self, norm_layer, output_ch=3, num_filters=64, num_filters_fpn=128, pretrained=True):
        super().__init__()

        # Feature Pyramid Network (FPN) with four feature maps of resolutions
        # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.

        self.fpn = FPN(num_filters=num_filters_fpn, norm_layer = norm_layer, pretrained=pretrained)

        # The segmentation heads on top of the FPN, one per pyramid level.

        self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)

        # Fuses the four upsampled head outputs back to `num_filters` channels.
        self.smooth = nn.Sequential(
            nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
            norm_layer(num_filters),
            nn.ReLU(),
        )

        self.smooth2 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
            norm_layer(num_filters // 2),
            nn.ReLU(),
        )

        # Final projection to the output image channels.
        self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)

    def unfreeze(self):
        # Re-enable gradients on the (initially frozen) backbone.
        self.fpn.unfreeze()

    def forward(self, x):

        map0, map1, map2, map3, map4 = self.fpn(x)

        # Bring every pyramid level to map1's (finest pyramid) resolution.
        map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
        map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
        map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
        map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")

        smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
        smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
        smoothed = self.smooth2(smoothed + map0)
        smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")

        final = self.final(smoothed)
        # Residual formulation: predict a correction to the input image.
        res = torch.tanh(final) + x

        # Keep the output in the model's working range [-1, 1].
        return torch.clamp(res, min=-1, max=1)
|
||||
|
||||
|
||||
class FPN(nn.Module):

    def __init__(self, norm_layer, num_filters=128, pretrained=True):
        """Creates an `FPN` instance for feature extraction.
        Args:
            num_filters: the number of filters in each output pyramid level
            pretrained: use ImageNet pre-trained backbone feature extractor
        """

        super().__init__()
        net = MobileNetV2(n_class=1000)

        if pretrained:
            #Load weights into the project directory
            # NOTE(review): relative path — resolves against the process CWD, not
            # this file's directory; confirm where the checkpoint is expected.
            state_dict = torch.load('mobilenetv2.pth.tar') # add map_location='cpu' if no gpu
            net.load_state_dict(state_dict)
        self.features = net.features

        # Slice the MobileNetV2 trunk into five stages for the pyramid.
        self.enc0 = nn.Sequential(*self.features[0:2])
        self.enc1 = nn.Sequential(*self.features[2:4])
        self.enc2 = nn.Sequential(*self.features[4:7])
        self.enc3 = nn.Sequential(*self.features[7:11])
        self.enc4 = nn.Sequential(*self.features[11:16])

        # Top-down refinement blocks (conv + norm + ReLU per merge step).
        self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))

        # 1x1 lateral projections; input channels (160/64/32/24/16) are the
        # MobileNetV2 stage widths for the slices above.
        self.lateral4 = nn.Conv2d(160, num_filters, kernel_size=1, bias=False)
        self.lateral3 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
        self.lateral2 = nn.Conv2d(32, num_filters, kernel_size=1, bias=False)
        self.lateral1 = nn.Conv2d(24, num_filters, kernel_size=1, bias=False)
        self.lateral0 = nn.Conv2d(16, num_filters // 2, kernel_size=1, bias=False)

        # Backbone starts frozen; see unfreeze().
        for param in self.features.parameters():
            param.requires_grad = False

    def unfreeze(self):
        # Re-enable backbone training (called after warm-up by the generator).
        for param in self.features.parameters():
            param.requires_grad = True


    def forward(self, x):

        # Bottom-up pathway, from ResNet
        enc0 = self.enc0(x)

        enc1 = self.enc1(enc0)  # 24

        enc2 = self.enc2(enc1)  # 32

        enc3 = self.enc3(enc2)  # 64

        enc4 = self.enc4(enc3)  # 160

        # Lateral connections: project each stage onto the shared pyramid width.

        lateral4 = self.lateral4(enc4)
        lateral3 = self.lateral3(enc3)
        lateral2 = self.lateral2(enc2)
        lateral1 = self.lateral1(enc1)
        lateral0 = self.lateral0(enc0)

        # Top-down pathway: upsample, add lateral, refine with a td block.
        map4 = lateral4
        map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest"))
        map2 = self.td2(lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest"))
        map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest"))
        return lateral0, map1, map2, map3, map4
|
||||
|
Binary file not shown.
@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from models.mobilenet_v2 import MobileNetV2
|
||||
|
||||
class FPNHead(nn.Module):
    """Feature-pyramid head: a pair of bias-free 3x3 convs, each ReLU-activated."""

    def __init__(self, num_in, num_mid, num_out):
        super().__init__()
        self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
        self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        for block in (self.block0, self.block1):
            x = nn.functional.relu(block(x), inplace=True)
        return x
|
||||
|
||||
|
||||
class FPNMobileNet(nn.Module):
    """Lightweight DeblurGAN-v2 generator: MobileNetV2 FPN backbone plus fusion heads.

    Output is a residual added to the input image and clamped to [-1, 1].
    """

    def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=128, pretrained=True):
        super().__init__()

        # Feature Pyramid Network (FPN) with four feature maps of resolutions
        # 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
        # NOTE(review): unlike the sibling variant, no `norm_layer` is forwarded
        # to FPN here — confirm FPN's __init__ can be constructed without one.

        self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained)

        # The segmentation heads on top of the FPN, one per pyramid level.

        self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
        self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)

        # Fuses the four upsampled head outputs back to `num_filters` channels.
        self.smooth = nn.Sequential(
            nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
            norm_layer(num_filters),
            nn.ReLU(),
        )

        self.smooth2 = nn.Sequential(
            nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
            norm_layer(num_filters // 2),
            nn.ReLU(),
        )

        # Final projection to the output image channels.
        self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)

    def unfreeze(self):
        # Re-enable gradients on the (initially frozen) backbone.
        self.fpn.unfreeze()

    def forward(self, x):

        map0, map1, map2, map3, map4 = self.fpn(x)

        # Bring every pyramid level to map1's (finest pyramid) resolution.
        map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
        map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
        map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
        map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")

        smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
        smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
        smoothed = self.smooth2(smoothed + map0)
        smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")

        final = self.final(smoothed)
        # Residual formulation: predict a correction to the input image.
        res = torch.tanh(final) + x

        # Keep the output in the model's working range [-1, 1].
        return torch.clamp(res, min=-1, max=1)
|
||||
|
||||
|
||||
class FPN(nn.Module):

    def __init__(self, num_filters=128, pretrained=True, norm_layer=nn.BatchNorm2d):
        """Creates an `FPN` instance for feature extraction.
        Args:
            num_filters: the number of filters in each output pyramid level
            pretrained: use ImageNet pre-trained backbone feature extractor
            norm_layer: normalization layer factory for the top-down blocks.
                BUG FIX: the original body referenced `norm_layer` in td1/td2/td3
                without it being a parameter or global, so constructing this class
                always raised NameError. It is now an optional parameter
                defaulting to nn.BatchNorm2d, keeping the existing call
                `FPN(num_filters=..., pretrained=...)` working unchanged.
        """

        super().__init__()
        net = MobileNetV2(n_class=1000)

        if pretrained:
            # Load weights into the project directory
            state_dict = torch.load('mobilenetv2.pth.tar')  # add map_location='cpu' if no gpu
            net.load_state_dict(state_dict)
        self.features = net.features

        # Slice the MobileNetV2 trunk into five stages for the pyramid.
        self.enc0 = nn.Sequential(*self.features[0:2])
        self.enc1 = nn.Sequential(*self.features[2:4])
        self.enc2 = nn.Sequential(*self.features[4:7])
        self.enc3 = nn.Sequential(*self.features[7:11])
        self.enc4 = nn.Sequential(*self.features[11:16])

        # Top-down refinement blocks (conv + norm + ReLU per merge step).
        self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))
        self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
                                 norm_layer(num_filters),
                                 nn.ReLU(inplace=True))

        # 1x1 lateral projections; input channels (160/64/32/24/16) are the
        # MobileNetV2 stage widths for the slices above.
        self.lateral4 = nn.Conv2d(160, num_filters, kernel_size=1, bias=False)
        self.lateral3 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
        self.lateral2 = nn.Conv2d(32, num_filters, kernel_size=1, bias=False)
        self.lateral1 = nn.Conv2d(24, num_filters, kernel_size=1, bias=False)
        self.lateral0 = nn.Conv2d(16, num_filters // 2, kernel_size=1, bias=False)

        # Backbone starts frozen; see unfreeze().
        for param in self.features.parameters():
            param.requires_grad = False

    def unfreeze(self):
        # Re-enable backbone training (called after warm-up by the generator).
        for param in self.features.parameters():
            param.requires_grad = True

    def forward(self, x):
        # Bottom-up pathway through the MobileNetV2 stages.
        enc0 = self.enc0(x)
        enc1 = self.enc1(enc0)  # 24
        enc2 = self.enc2(enc1)  # 32
        enc3 = self.enc3(enc2)  # 64
        enc4 = self.enc4(enc3)  # 160

        # Lateral connections: project each stage onto the shared pyramid width.
        lateral4 = self.lateral4(enc4)
        lateral3 = self.lateral3(enc3)
        lateral2 = self.lateral2(enc2)
        lateral1 = self.lateral1(enc1)
        lateral0 = self.lateral0(enc0)

        # Top-down pathway: upsample, add lateral, refine with a td block.
        map4 = lateral4
        map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest"))
        map2 = self.td2(lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest"))
        map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest"))
        return lateral0, map1, map2, map3, map4
|
||||
|
@ -0,0 +1,300 @@
|
||||
import torch
|
||||
import torch.autograd as autograd
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
import torchvision.transforms as transforms
|
||||
from torch.autograd import Variable
|
||||
|
||||
from util.image_pool import ImagePool
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Functions
|
||||
###############################################################################
|
||||
|
||||
class ContentLoss():
    """Thin wrapper applying a configurable criterion to (fake, real) image pairs.

    Uses the module's two-phase construction convention: build, then initialize()
    with the criterion before use.
    """

    def initialize(self, loss):
        # The criterion is stored rather than passed per call so this object
        # is interchangeable with the other loss wrappers in this module.
        self.criterion = loss

    def get_loss(self, fakeIm, realIm):
        return self.criterion(fakeIm, realIm)

    def __call__(self, fakeIm, realIm):
        # Delegate through get_loss so subclasses can override a single hook.
        return self.get_loss(fakeIm, realIm)
|
||||
|
||||
|
||||
class PerceptualLoss():
    # VGG19-feature-based perceptual loss combined with a pixel-space MSE term.

    def contentFunc(self):
        # Build a truncated VGG19 feature extractor, keeping layers up to index 14
        # (the conv3_3 activation). Downloads ImageNet weights on first use.
        conv_3_3_layer = 14
        cnn = models.vgg19(pretrained=True).features
        cnn = cnn.cuda()  # NOTE(review): hard-codes GPU; fails on CPU-only machines.
        model = nn.Sequential()
        model = model.cuda()
        model = model.eval()
        for i, layer in enumerate(list(cnn)):
            model.add_module(str(i), layer)
            if i == conv_3_3_layer:
                break
        return model

    def initialize(self, loss):
        with torch.no_grad():
            self.criterion = loss
            # Note: rebinds the method name `contentFunc` to the built module.
            self.contentFunc = self.contentFunc()
            # ImageNet normalization, applied after mapping [-1, 1] -> [0, 1] below.
            self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def get_loss(self, fakeIm, realIm):
        # Inputs are expected in [-1, 1]; rescale to [0, 1] for VGG.
        fakeIm = (fakeIm + 1) / 2.0
        realIm = (realIm + 1) / 2.0
        # In-place normalization of the first batch element only —
        # NOTE(review): appears to assume batch size 1; confirm for larger batches.
        fakeIm[0, :, :, :] = self.transform(fakeIm[0, :, :, :])
        realIm[0, :, :, :] = self.transform(realIm[0, :, :, :])
        f_fake = self.contentFunc.forward(fakeIm)
        f_real = self.contentFunc.forward(realIm)
        # Target features are detached so gradients flow only through the fake path.
        f_real_no_grad = f_real.detach()
        loss = self.criterion(f_fake, f_real_no_grad)
        # Weighted sum: 0.006 * perceptual term + 0.5 * pixel MSE.
        return 0.006 * torch.mean(loss) + 0.5 * nn.MSELoss()(fakeIm, realIm)

    def __call__(self, fakeIm, realIm):
        return self.get_loss(fakeIm, realIm)
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
    # GAN objective that hides the real/fake target-tensor bookkeeping from callers.
    # use_l1 selects L1 targets; otherwise BCE-with-logits (vanilla GAN).
    def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Cached constant label tensors, rebuilt only when the input size changes.
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_l1:
            self.loss = nn.L1Loss()
        else:
            self.loss = nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input, target_is_real):
        # Return a constant tensor shaped like `input`, filled with the
        # real or fake label, reusing the cached one when sizes match.
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        # NOTE(review): hard-codes GPU placement; fails on CPU-only runs.
        return target_tensor.cuda()

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)
|
||||
|
||||
|
||||
class DiscLoss(nn.Module):
    # Standard (non-relativistic) GAN discriminator loss with BCE-with-logits targets.
    def name(self):
        return 'DiscLoss'

    def __init__(self):
        super(DiscLoss, self).__init__()

        # use_l1=False -> BCEWithLogitsLoss inside GANLoss (vanilla GAN objective).
        self.criterionGAN = GANLoss(use_l1=False)
        # NOTE(review): this pool is created but never queried in this class —
        # confirm whether it is intentionally unused.
        self.fake_AB_pool = ImagePool(50)

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        pred_fake = net.forward(fakeB)
        return self.criterionGAN(pred_fake, 1)

    def get_loss(self, net, fakeB, realB):
        # Fake
        # stop backprop to the generator by detaching fake_B
        # Generated Image Disc Output should be close to zero
        self.pred_fake = net.forward(fakeB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, 0)

        # Real
        self.pred_real = net.forward(realB)
        self.loss_D_real = self.criterionGAN(self.pred_real, 1)

        # Combined loss: average of the real and fake halves.
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        return self.loss_D

    def __call__(self, net, fakeB, realB):
        return self.get_loss(net, fakeB, realB)
|
||||
|
||||
|
||||
class RelativisticDiscLoss(nn.Module):
    """Relativistic average GAN (RaGAN) discriminator loss.

    Each prediction is compared against the mean discriminator output of the
    opposite class, taken from pools of previous predictions
    (``real_pool`` / ``fake_pool``), instead of against an absolute label.
    """

    def name(self):
        return 'RelativisticDiscLoss'

    def __init__(self):
        super(RelativisticDiscLoss, self).__init__()

        # BCE-with-logits criterion; 0/1 targets are materialized by GANLoss.
        self.criterionGAN = GANLoss(use_l1=False)
        self.fake_pool = ImagePool(50)  # create image buffer to store previously generated images
        self.real_pool = ImagePool(50)

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        self.pred_fake = net.forward(fakeB)

        # Real
        self.pred_real = net.forward(realB)
        # Generator term: push D(real) below the pooled average fake score
        # (target 0) and D(fake) above the pooled average real score (target 1).
        errG = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 0) +
                self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 1)) / 2
        return errG

    def get_loss(self, net, fakeB, realB):
        # Fake
        # stop backprop to the generator by detaching fake_B
        # Generated Image Disc Output should be close to zero
        self.fake_B = fakeB.detach()
        self.real_B = realB
        self.pred_fake = net.forward(fakeB.detach())
        self.fake_pool.add(self.pred_fake)

        # Real
        self.pred_real = net.forward(realB)
        self.real_pool.add(self.pred_real)

        # Combined loss: discriminator targets mirror those of get_g_loss.
        self.loss_D = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 1) +
                       self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 0)) / 2
        return self.loss_D

    def __call__(self, net, fakeB, realB):
        return self.get_loss(net, fakeB, realB)
|
||||
|
||||
|
||||
class RelativisticDiscLossLS(nn.Module):
    """Least-squares variant of the relativistic average GAN loss (RaLSGAN).

    Same prediction-pooling scheme as RelativisticDiscLoss, but with explicit
    least-squares targets of +1 / -1 instead of a BCE criterion.
    NOTE(review): ``self.criterionGAN`` is constructed but never used in this
    class; the losses below are computed directly with ``torch.mean``.
    """

    def name(self):
        return 'RelativisticDiscLossLS'

    def __init__(self):
        super(RelativisticDiscLossLS, self).__init__()

        self.criterionGAN = GANLoss(use_l1=True)
        self.fake_pool = ImagePool(50)  # create image buffer to store previously generated images
        self.real_pool = ImagePool(50)

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        self.pred_fake = net.forward(fakeB)

        # Real
        self.pred_real = net.forward(realB)
        # Push D(real) one unit below the pooled average fake score and
        # D(fake) one unit above the pooled average real score.
        errG = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) + 1) ** 2) +
                torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) - 1) ** 2)) / 2
        return errG

    def get_loss(self, net, fakeB, realB):
        # Fake
        # stop backprop to the generator by detaching fake_B
        # Generated Image Disc Output should be close to zero
        self.fake_B = fakeB.detach()
        self.real_B = realB
        self.pred_fake = net.forward(fakeB.detach())
        self.fake_pool.add(self.pred_fake)

        # Real
        self.pred_real = net.forward(realB)
        self.real_pool.add(self.pred_real)

        # Combined loss: the mirror image of the generator targets.
        self.loss_D = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) - 1) ** 2) +
                       torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) + 1) ** 2)) / 2
        return self.loss_D

    def __call__(self, net, fakeB, realB):
        return self.get_loss(net, fakeB, realB)
|
||||
|
||||
|
||||
class DiscLossLS(DiscLoss):
    """Least-squares GAN discriminator loss.

    Identical real/fake scheme to :class:`DiscLoss`, but the criterion is an
    L1 loss over the constant targets instead of BCE-with-logits.
    """

    def name(self):
        return 'DiscLossLS'

    def __init__(self):
        super(DiscLossLS, self).__init__()
        self.criterionGAN = GANLoss(use_l1=True)

    def get_g_loss(self, net, fakeB, realB):
        # BUG FIX: DiscLoss.get_g_loss takes (self, net, fakeB, realB); the
        # original call dropped realB, which raised TypeError at runtime.
        return DiscLoss.get_g_loss(self, net, fakeB, realB)

    def get_loss(self, net, fakeB, realB):
        return DiscLoss.get_loss(self, net, fakeB, realB)
|
||||
|
||||
|
||||
class DiscLossWGANGP(DiscLossLS):
    """Wasserstein GAN critic loss with gradient penalty (WGAN-GP).

    NOTE(review): gradient-penalty tensors are moved to CUDA unconditionally,
    so this loss requires a GPU at call time.
    """

    def name(self):
        return 'DiscLossWGAN-GP'

    def __init__(self):
        super(DiscLossWGANGP, self).__init__()
        # Gradient-penalty weight.
        self.LAMBDA = 10

    def get_g_loss(self, net, fakeB, realB):
        # First, G(A) should fake the discriminator
        # The generator maximizes the critic score of generated samples.
        self.D_fake = net.forward(fakeB)
        return -self.D_fake.mean()

    def calc_gradient_penalty(self, netD, real_data, fake_data):
        """Penalize deviation of the critic's gradient norm from 1 at random
        interpolates between real and fake samples."""
        # A single interpolation coefficient, broadcast to the data shape
        # (expand of a (1, 1) tensor prepends the missing leading dims).
        alpha = torch.rand(1, 1)
        alpha = alpha.expand(real_data.size())
        alpha = alpha.cuda()

        interpolates = alpha * real_data + ((1 - alpha) * fake_data)

        interpolates = interpolates.cuda()
        interpolates = Variable(interpolates, requires_grad=True)

        disc_interpolates = netD.forward(interpolates)

        # dD(x_hat)/dx_hat, kept in the graph so the penalty is trainable.
        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]

        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
        return gradient_penalty

    def get_loss(self, net, fakeB, realB):
        # Critic estimate on detached fakes (no gradient to the generator).
        self.D_fake = net.forward(fakeB.detach())
        self.D_fake = self.D_fake.mean()

        # Real
        self.D_real = net.forward(realB)
        self.D_real = self.D_real.mean()
        # Combined loss: Wasserstein estimate plus the gradient penalty.
        self.loss_D = self.D_fake - self.D_real
        gradient_penalty = self.calc_gradient_penalty(net, realB.data, fakeB.data)
        return self.loss_D + gradient_penalty
|
||||
|
||||
|
||||
def get_loss(model):
    """Build the (content_loss, disc_loss) pair described by the config dict.

    ``model['content_loss']`` selects 'perceptual' or 'l1';
    ``model['disc_loss']`` selects one of the adversarial losses below.
    Raises ValueError for unrecognized names.
    """
    content_name = model['content_loss']
    if content_name == 'perceptual':
        content_loss = PerceptualLoss()
        content_loss.initialize(nn.MSELoss())
    elif content_name == 'l1':
        content_loss = ContentLoss()
        content_loss.initialize(nn.L1Loss())
    else:
        raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])

    # Lazy dispatch: classes are only instantiated after a successful lookup.
    disc_classes = {
        'wgan-gp': DiscLossWGANGP,
        'lsgan': DiscLossLS,
        'gan': DiscLoss,
        'ragan': RelativisticDiscLoss,
        'ragan-ls': RelativisticDiscLossLS,
    }
    disc_name = model['disc_loss']
    if disc_name not in disc_classes:
        raise ValueError("GAN Loss [%s] not recognized." % model['disc_loss'])
    disc_loss = disc_classes[disc_name]()
    return content_loss, disc_loss
|
@ -0,0 +1,126 @@
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
|
||||
def conv_bn(inp, oup, stride):
    """3x3 conv (padding 1, no bias) -> BatchNorm -> ReLU6."""
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def conv_1x1_bn(inp, oup):
    """Pointwise (1x1) conv without bias -> BatchNorm -> ReLU6."""
    layers = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    Optional pointwise expansion (skipped when ``expand_ratio == 1``),
    depthwise 3x3 conv carrying the stride, then a linear pointwise
    projection.  A residual connection is used when stride is 1 and the
    input/output channel counts match.
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: expand to hidden_dim channels first.
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        layers += [
            # dw: depthwise 3x3 conv (one group per channel) with the stride.
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear: project back without a nonlinearity.
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
    """MobileNetV2 image classifier.

    Stacks inverted-residual blocks according to the rows of
    ``interverted_residual_setting`` (t = expansion factor, c = output
    channels, n = number of repeats, s = stride of the first repeat).
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        # The stem plus the stride-2 stages downsample by 32x in total.
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        # The classifier head is widened only when width_mult > 1, never narrowed.
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    # Only the first block of each stage applies the stride.
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        # Global average pooling over the spatial dimensions (W, then H).
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Kaiming-style init for convs, unit/zero for BN, small normal for Linear."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
|
||||
|
Binary file not shown.
@ -0,0 +1,35 @@
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from skimage.measure import compare_ssim as SSIM
|
||||
|
||||
from util.metrics import PSNR
|
||||
|
||||
|
||||
class DeblurModel(nn.Module):
    """Bundles input unpacking, tensor-to-image conversion and evaluation
    metrics (PSNR / SSIM) for the deblurring task.

    NOTE(review): ``get_input`` moves tensors to CUDA unconditionally, so
    training requires a GPU.
    """

    def __init__(self):
        super(DeblurModel, self).__init__()

    def get_input(self, data):
        # 'a' holds the blurred input, 'b' the sharp target.
        blurred = data['a']
        sharp = data['b']
        return blurred.cuda(), sharp.cuda()

    def tensor2im(self, image_tensor, imtype=np.uint8):
        """First batch element, CHW float in [-1, 1] -> HWC array in [0, 255]."""
        arr = image_tensor[0].cpu().float().numpy()
        arr = (np.transpose(arr, (1, 2, 0)) + 1) / 2.0 * 255.0
        return arr.astype(imtype)

    def get_images_and_metrics(self, inp, output, target):
        """Return (psnr, ssim, side-by-side visualization) for one batch."""
        inp = self.tensor2im(inp)
        fake = self.tensor2im(output.data)
        real = self.tensor2im(target.data)
        psnr = PSNR(fake, real)
        ssim = SSIM(fake, real, multichannel=True)
        vis_img = np.hstack((inp, fake, real))
        return psnr, ssim, vis_img
|
||||
|
||||
|
||||
def get_model(model_config):
    """Factory entry point.

    ``model_config`` is currently unused; a freshly constructed
    :class:`DeblurModel` is always returned.
    """
    return DeblurModel()
|
Binary file not shown.
@ -0,0 +1,330 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
import functools
|
||||
from torch.autograd import Variable
|
||||
import numpy as np
|
||||
from fpn_mobilenet import FPNMobileNet
|
||||
from fpn_inception import FPNInception
|
||||
# from fpn_inception_simple import FPNInceptionSimple
|
||||
from unet_seresnext import UNetSEResNext
|
||||
from fpn_densenet import FPNDense
|
||||
###############################################################################
|
||||
# Functions
|
||||
###############################################################################
|
||||
|
||||
|
||||
def get_norm_layer(norm_type='instance'):
    """Map a norm name ('batch' | 'instance') to a layer constructor.

    Returns a ``functools.partial`` so callers can instantiate the layer
    with just the channel count.  Raises NotImplementedError otherwise.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
||||
|
||||
##############################################################################
|
||||
# Classes
|
||||
##############################################################################
|
||||
|
||||
|
||||
# Defines the generator that consists of Resnet blocks between a few
|
||||
# downsampling/upsampling operations.
|
||||
# Code and idea originally from Justin Johnson's architecture.
|
||||
# https://github.com/jcjohnson/fast-neural-style/
|
||||
class ResnetGenerator(nn.Module):
    """Johnson-style generator: reflection-padded 7x7 stem, two stride-2
    downsampling convs, ``n_blocks`` residual blocks, two transposed-conv
    upsamplings, and a 7x7 Tanh output head.

    With ``learn_residual=True`` the network predicts a residual that is
    added to the input and clamped back to [-1, 1].
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, use_parallel=True, learn_residual=True, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.use_parallel = use_parallel
        self.learn_residual = learn_residual
        # InstanceNorm has no affine shift here, so convs keep their own bias.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Two stride-2 convs: ngf -> 2*ngf -> 4*ngf channels.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Residual trunk at the lowest resolution.
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        # Mirror the downsampling with transposed convs back to ngf channels.
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        output = self.model(input)
        if self.learn_residual:
            # Residual prediction, clamped back into the Tanh value range.
            output = input + output
            output = torch.clamp(output, min=-1, max=1)
        return output
|
||||
|
||||
|
||||
# Define a resnet block
|
||||
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)`` with two 3x3 convs.

    ``padding_type`` is one of 'reflect', 'replicate' or 'zero'; anything
    else raises NotImplementedError at construction time.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return (extra_pad_layers, conv_padding) for the requested mode.

        Extracted because the original duplicated this if/elif chain twice.
        """
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble pad -> conv -> norm -> ReLU [-> dropout] -> pad -> conv -> norm."""
        pad_layers, p = self._padding(padding_type)
        conv_block = list(pad_layers)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        pad_layers, p = self._padding(padding_type)
        conv_block += pad_layers
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        # Identity shortcut around the two-conv body.
        out = x + self.conv_block(x)
        return out
|
||||
|
||||
|
||||
class DicsriminatorTail(nn.Module):
    """Final stage of the multi-scale discriminator: one stride-1 conv block
    followed by a 1-channel scoring conv.

    NOTE(review): the class name is misspelled ("Dicsriminator") but is part
    of the public interface, so it is kept as-is.
    """

    def __init__(self, nf_mult, n_layers, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
        super(DicsriminatorTail, self).__init__()
        self.use_parallel = use_parallel
        # InstanceNorm has no affine shift here, so convs keep their own bias.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))

        in_mult = nf_mult
        out_mult = min(2 ** n_layers, 8)
        layers = [
            nn.Conv2d(ndf * in_mult, ndf * out_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * out_mult),
            nn.LeakyReLU(0.2, True),
            # Patch-score head: one channel per spatial location.
            nn.Conv2d(ndf * out_mult, 1, kernel_size=kw, stride=1, padding=padw),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(nn.Module):
    """Discriminator scoring patches at three successively smaller scales.

    ``scale_one`` is a shared trunk; after each additional downsampling
    (``scale_two``, ``scale_three``) a ``DicsriminatorTail`` maps the feature
    map to a 1-channel score map.  ``forward`` returns the list of the three
    score maps.
    """

    def __init__(self, input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
        super(MultiScaleDiscriminator, self).__init__()
        self.use_parallel = use_parallel
        # InstanceNorm has no affine shift here, so convs keep their own bias.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = int(np.ceil((kw-1)/2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        # Two more stride-2 blocks: ndf -> 2*ndf -> 4*ndf channels.
        nf_mult = 1
        for n in range(1, 3):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        self.scale_one = nn.Sequential(*sequence)
        self.first_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=3)
        nf_mult_prev = 4
        nf_mult = 8

        # Downsample once more (4*ndf -> 8*ndf) before the second head.
        self.scale_two = nn.Sequential(
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True))
        nf_mult_prev = nf_mult
        self.second_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=4)
        # And once more (8*ndf -> 8*ndf) before the third head.
        self.scale_three = nn.Sequential(
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True))
        self.third_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=5)

    def forward(self, input):
        # Score after the shared trunk, then after each extra downsampling.
        x = self.scale_one(input)
        x_1 = self.first_tail(x)
        x = self.scale_two(x)
        x_2 = self.second_tail(x)
        x = self.scale_three(x)
        x = self.third_tail(x)
        return [x_1, x_2, x]
|
||||
|
||||
|
||||
# Defines the PatchGAN discriminator with the specified arguments.
|
||||
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator with a configurable number of conv layers.

    ``n_layers`` stride-2 conv blocks (channel width doubling up to 8*ndf),
    one stride-1 block, then a 1-channel patch-score conv.  An optional
    final Sigmoid is appended when ``use_sigmoid`` is True.
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_parallel=True):
        super(NLayerDiscriminator, self).__init__()
        self.use_parallel = use_parallel
        # InstanceNorm has no affine shift here, so convs keep their own bias.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        # Stride-2 pyramid: width multiplier doubles each layer, capped at 8.
        mult = 1
        for n in range(1, n_layers):
            prev_mult, mult = mult, min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * prev_mult, ndf * mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One stride-1 block, then the 1-channel patch-score head.
        prev_mult, mult = mult, min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * prev_mult, ndf * mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * mult, 1, kernel_size=kw, stride=1, padding=padw),
        ]

        if use_sigmoid:
            layers.append(nn.Sigmoid())

        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
|
||||
|
||||
|
||||
def get_fullD(model_config):
    """Full-image discriminator: a deep (5-layer) PatchGAN without sigmoid."""
    norm = get_norm_layer(norm_type=model_config['norm_layer'])
    return NLayerDiscriminator(n_layers=5, norm_layer=norm, use_sigmoid=False)
|
||||
|
||||
|
||||
def get_generator(model_config):
    """Build the generator named by ``model_config['g_name']``, wrapped in
    ``nn.DataParallel``.  Raises ValueError for unknown names.
    """
    generator_name = model_config['g_name']
    if generator_name == 'resnet':
        model_g = ResnetGenerator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                  use_dropout=model_config['dropout'],
                                  n_blocks=model_config['blocks'],
                                  learn_residual=model_config['learn_residual'])
    elif generator_name == 'fpn_mobilenet':
        model_g = FPNMobileNet(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
    elif generator_name == 'fpn_inception':
        # model_g = FPNInception(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
        # torch.save(model_g, 'mymodel.pth')
        # HACK: loads a pre-pickled model from the current working directory
        # instead of constructing FPNInception; this fails unless
        # 'mymodel.pth' exists where the process is started.
        model_g = torch.load('mymodel.pth')
    elif generator_name == 'fpn_inception_simple':
        # NOTE(review): the FPNInceptionSimple import is commented out at the
        # top of this file, so selecting this branch raises NameError.
        model_g = FPNInceptionSimple(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
    elif generator_name == 'fpn_dense':
        model_g = FPNDense()
    elif generator_name == 'unet_seresnext':
        model_g = UNetSEResNext(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                pretrained=model_config['pretrained'])
    else:
        raise ValueError("Generator Network [%s] not recognized." % generator_name)

    return nn.DataParallel(model_g)
|
||||
|
||||
def get_generator_new(weights_path):
    """Load a pickled generator from ``weights_path`` and wrap it in DataParallel.

    ``weights_path`` is the directory containing ``mymodel.pth``.  The
    original code concatenated the strings directly, which silently produced
    a wrong path when ``weights_path`` lacked a trailing separator;
    ``os.path.join`` handles both forms.
    """
    import os

    model_g = torch.load(os.path.join(weights_path, 'mymodel.pth'))

    return nn.DataParallel(model_g)
|
||||
|
||||
def get_discriminator(model_config):
    """Build the discriminator(s) selected by ``model_config['d_name']``.

    Returns None for 'no_gan', a DataParallel-wrapped module for
    'patch_gan' / 'multi_scale', and a {'patch': ..., 'full': ...} dict for
    'double_gan'.  Raises ValueError for unknown names.
    """
    name = model_config['d_name']
    if name == 'no_gan':
        # Training without an adversarial term.
        return None
    if name == 'patch_gan':
        patch = NLayerDiscriminator(n_layers=model_config['d_layers'],
                                    norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                    use_sigmoid=False)
        return nn.DataParallel(patch)
    if name == 'double_gan':
        patch = NLayerDiscriminator(n_layers=model_config['d_layers'],
                                    norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                    use_sigmoid=False)
        full = get_fullD(model_config)
        return {'patch': nn.DataParallel(patch),
                'full': nn.DataParallel(full)}
    if name == 'multi_scale':
        multi = MultiScaleDiscriminator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
        return nn.DataParallel(multi)
    raise ValueError("Discriminator Network [%s] not recognized." % name)
|
||||
|
||||
|
||||
def get_nets(model_config):
    """Build and return the (generator, discriminator) pair for ``model_config``."""
    return get_generator(model_config), get_discriminator(model_config)
|
Binary file not shown.
@ -0,0 +1,430 @@
|
||||
from __future__ import print_function, division, absolute_import
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.utils import model_zoo
|
||||
|
||||
# Public API of this module.
__all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
           'se_resnext50_32x4d', 'se_resnext101_32x4d']

# Download URLs and input-preprocessing metadata for the ImageNet-pretrained
# checkpoints, keyed by architecture name then dataset.
pretrained_settings = {
    'senet154': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
    'se_resnet50': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
    'se_resnet101': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
    'se_resnet152': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
    'se_resnext50_32x4d': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
    'se_resnext101_32x4d': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
            'input_space': 'RGB',
            'input_size': [3, 224, 224],
            'input_range': [0, 1],
            'mean': [0.485, 0.456, 0.406],
            'std': [0.229, 0.224, 0.225],
            'num_classes': 1000
        }
    },
}
|
||||
|
||||
|
||||
class SEModule(nn.Module):
    """Squeeze-and-Excitation block: channel-wise gating of the input.

    Global-average pool to 1x1, bottleneck through two 1x1 convs
    (``channels // reduction`` hidden units) with ReLU, sigmoid gate,
    then rescale the input channel-wise.
    """

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
                             padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
                             padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Squeeze to 1x1, excite through the bottleneck, then gate x.
        gate = self.avg_pool(x)
        gate = self.relu(self.fc1(gate))
        gate = self.sigmoid(self.fc2(gate))
        return x * gate
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """
    Base class for bottlenecks that implements `forward()` method.

    Subclasses must define conv1/bn1, conv2/bn2, conv3/bn3, relu, se_module
    and (optionally) downsample; this forward applies the three
    conv->norm(->relu) stages, SE gating, and the residual addition.
    """
    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        # Third stage has no ReLU before the residual sum.
        out = self.conv3(out)
        out = self.bn3(out)

        # Project the identity branch when shapes differ.
        if self.downsample is not None:
            residual = self.downsample(x)

        # Squeeze-and-Excitation gating is applied before the residual sum.
        out = self.se_module(out) + residual
        out = self.relu(out)

        return out
|
||||
|
||||
|
||||
class SEBottleneck(Bottleneck):
    """
    Bottleneck for SENet154.

    NOTE(review): the normalization layers here are InstanceNorm2d
    (affine=False) rather than the BatchNorm2d of the reference SENet —
    presumably an intentional adaptation for this repo; confirm before reuse.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1)
        self.bn1 = nn.InstanceNorm2d(planes * 2, affine=False)
        # Grouped 3x3 conv carries the stride (spatial downsampling).
        self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
                               stride=stride, padding=1, groups=groups)
        self.bn2 = nn.InstanceNorm2d(planes * 4, affine=False)
        self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1)
        self.bn3 = nn.InstanceNorm2d(planes * 4, affine=False)
        self.relu = nn.ReLU(inplace=True)
        # Channel attention applied by the base-class forward.
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride
|
||||
|
||||
|
||||
class SEResNetBottleneck(Bottleneck):
    """
    ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
    implementation and uses `stride=stride` in `conv1` and not in `conv2`
    (the latter is used in the torchvision implementation of ResNet).

    NOTE(review): norm layers are InstanceNorm2d (affine=False), not the
    usual BatchNorm2d — confirm this deviation is intended before reuse.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None):
        super(SEResNetBottleneck, self).__init__()
        # Stride lives on the 1x1 conv (Caffe style, see class docstring).
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1,
                               stride=stride)
        self.bn1 = nn.InstanceNorm2d(planes, affine=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
                               groups=groups)
        self.bn2 = nn.InstanceNorm2d(planes, affine=False)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1)
        self.bn3 = nn.InstanceNorm2d(planes * 4, affine=False)
        self.relu = nn.ReLU(inplace=True)
        # Channel attention applied by the base-class forward.
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride
|
||||
|
||||
|
||||
class SEResNeXtBottleneck(Bottleneck):
    """
    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.

    NOTE(review): norm layers are InstanceNorm2d (affine=False), not the
    usual BatchNorm2d — confirm this deviation is intended before reuse.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None, base_width=4):
        super(SEResNeXtBottleneck, self).__init__()
        # ResNeXt width rule: scale the bottleneck width with base_width and
        # multiply by the cardinality (number of groups).
        width = math.floor(planes * (base_width / 64)) * groups
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1,
                               stride=1)
        self.bn1 = nn.InstanceNorm2d(width, affine=False)
        # Grouped 3x3 conv carries the stride (spatial downsampling).
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, groups=groups)
        self.bn2 = nn.InstanceNorm2d(width, affine=False)
        self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1)
        self.bn3 = nn.InstanceNorm2d(planes * 4, affine=False)
        self.relu = nn.ReLU(inplace=True)
        # Channel attention applied by the base-class forward.
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride
|
||||
|
||||
|
||||
class SENet(nn.Module):
    """Squeeze-and-Excitation network backbone.

    NOTE(review): this variant uses InstanceNorm2d everywhere the reference
    implementation uses BatchNorm2d — presumably intentional for the
    image-to-image task; pretrained BatchNorm statistics will not apply.
    """

    def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
                 inplanes=128, input_3x3=True, downsample_kernel_size=3,
                 downsample_padding=1, num_classes=1000):
        """
        Parameters
        ----------
        block (nn.Module): Bottleneck class.
            - For SENet154: SEBottleneck
            - For SE-ResNet models: SEResNetBottleneck
            - For SE-ResNeXt models: SEResNeXtBottleneck
        layers (list of ints): Number of residual blocks for 4 layers of the
            network (layer1...layer4).
        groups (int): Number of groups for the 3x3 convolution in each
            bottleneck block.
            - For SENet154: 64
            - For SE-ResNet models: 1
            - For SE-ResNeXt models: 32
        reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
            - For all models: 16
        dropout_p (float or None): Drop probability for the Dropout layer.
            If `None` the Dropout layer is not used.
            - For SENet154: 0.2
            - For SE-ResNet models: None
            - For SE-ResNeXt models: None
        inplanes (int): Number of input channels for layer1.
            - For SENet154: 128
            - For SE-ResNet models: 64
            - For SE-ResNeXt models: 64
        input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
            a single 7x7 convolution in layer0.
            - For SENet154: True
            - For SE-ResNet models: False
            - For SE-ResNeXt models: False
        downsample_kernel_size (int): Kernel size for downsampling convolutions
            in layer2, layer3 and layer4.
            - For SENet154: 3
            - For SE-ResNet models: 1
            - For SE-ResNeXt models: 1
        downsample_padding (int): Padding for downsampling convolutions in
            layer2, layer3 and layer4.
            - For SENet154: 1
            - For SE-ResNet models: 0
            - For SE-ResNeXt models: 0
        num_classes (int): Number of outputs in `last_linear` layer.
            - For all models: 1000
        """
        super(SENet, self).__init__()
        self.inplanes = inplanes
        if input_3x3:
            # SENet-154 stem: three stacked 3x3 convolutions.
            layer0_modules = [
                ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1)),
                ('bn1', nn.InstanceNorm2d(64, affine=False)),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1)),
                ('bn2', nn.InstanceNorm2d(64, affine=False)),
                ('relu2', nn.ReLU(inplace=True)),
                ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1)),
                ('bn3', nn.InstanceNorm2d(inplanes, affine=False)),
                ('relu3', nn.ReLU(inplace=True)),
            ]
        else:
            # ResNet-style stem: a single 7x7 convolution.
            layer0_modules = [
                ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
                                    padding=3)),
                ('bn1', nn.InstanceNorm2d(inplanes, affine=False)),
                ('relu1', nn.ReLU(inplace=True)),
            ]
        # To preserve compatibility with Caffe weights `ceil_mode=True`
        # is used instead of `padding=1`.
        layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
                                                    ceil_mode=True)))
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        # layer1 never needs a big downsampling kernel: stride is 1 here.
        self.layer1 = self._make_layer(
            block,
            planes=64,
            blocks=layers[0],
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0
        )
        self.layer2 = self._make_layer(
            block,
            planes=128,
            blocks=layers[1],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer3 = self._make_layer(
            block,
            planes=256,
            blocks=layers[2],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        self.layer4 = self._make_layer(
            block,
            planes=512,
            blocks=layers[3],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding
        )
        # 7x7 average pool assumes a 224x224 input (7x7 final feature map) —
        # TODO confirm for other input sizes.
        self.avg_pool = nn.AvgPool2d(7, stride=1)
        self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
        self.last_linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                    downsample_kernel_size=1, downsample_padding=0):
        """Stack `blocks` residual blocks; only the first block may
        downsample spatially and/or project the channel count."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut to match spatial size and channels.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=downsample_kernel_size, stride=stride,
                          padding=downsample_padding),
                nn.InstanceNorm2d(planes * block.expansion, affine=False),
            )

        layers = []
        layers.append(block(self.inplanes, planes, groups, reduction, stride,
                            downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))

        return nn.Sequential(*layers)

    def features(self, x):
        """Run the convolutional trunk (stem + 4 residual stages)."""
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def logits(self, x):
        """Pool, (optionally) drop out, flatten and classify the features."""
        x = self.avg_pool(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.last_linear(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x
|
||||
|
||||
|
||||
def initialize_pretrained_model(model, num_classes, settings):
    """Load pretrained weights from `settings['url']` and copy the
    preprocessing metadata (input space/size/range, mean, std) onto `model`.

    Raises AssertionError when `num_classes` disagrees with the settings.
    """
    expected = settings['num_classes']
    assert num_classes == expected, \
        'num_classes should be {}, but is {}'.format(expected, num_classes)
    state = model_zoo.load_url(settings['url'])
    model.load_state_dict(state)
    for attr in ('input_space', 'input_size', 'input_range', 'mean', 'std'):
        setattr(model, attr, settings[attr])
|
||||
|
||||
|
||||
def senet154(num_classes=1000, pretrained='imagenet'):
    """Build SENet-154; optionally load pretrained weights and metadata."""
    model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
                  dropout_p=0.2, num_classes=num_classes)
    if pretrained is not None:
        initialize_pretrained_model(
            model, num_classes, pretrained_settings['senet154'][pretrained])
    return model
|
||||
|
||||
|
||||
def se_resnet50(num_classes=1000, pretrained='imagenet'):
    """Build SE-ResNet-50; optionally load pretrained weights and metadata."""
    model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        initialize_pretrained_model(
            model, num_classes, pretrained_settings['se_resnet50'][pretrained])
    return model
|
||||
|
||||
|
||||
def se_resnet101(num_classes=1000, pretrained='imagenet'):
    """Build SE-ResNet-101; optionally load pretrained weights and metadata."""
    model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        initialize_pretrained_model(
            model, num_classes, pretrained_settings['se_resnet101'][pretrained])
    return model
|
||||
|
||||
|
||||
def se_resnet152(num_classes=1000, pretrained='imagenet'):
    """Build SE-ResNet-152; optionally load pretrained weights and metadata."""
    model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        initialize_pretrained_model(
            model, num_classes, pretrained_settings['se_resnet152'][pretrained])
    return model
|
||||
|
||||
|
||||
def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'):
    """Build SE-ResNeXt-50 (32x4d).

    NOTE(review): unlike the sibling constructors, `pretrained` is accepted
    but never used here — presumably the weights are loaded externally by
    the caller; confirm before relying on ImageNet initialization.
    """
    return SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
                 dropout_p=None, inplanes=64, input_3x3=False,
                 downsample_kernel_size=1, downsample_padding=0,
                 num_classes=num_classes)
|
||||
|
||||
|
||||
def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
    """Build SE-ResNeXt-101 (32x4d); optionally load pretrained weights."""
    model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained is not None:
        initialize_pretrained_model(
            model, num_classes,
            pretrained_settings['se_resnext101_32x4d'][pretrained])
    return model
|
Binary file not shown.
@ -0,0 +1,153 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.parallel
|
||||
import torch.optim
|
||||
import torch.utils.data
|
||||
from torch.nn import Sequential
|
||||
from collections import OrderedDict
|
||||
import torchvision
|
||||
from torch.nn import functional as F
|
||||
from senet import se_resnext50_32x4d
|
||||
|
||||
|
||||
def conv3x3(in_, out):
    """3x3 convolution with padding 1 (preserves spatial size at stride 1)."""
    return nn.Conv2d(in_, out, kernel_size=3, padding=1)
|
||||
|
||||
|
||||
class ConvRelu(nn.Module):
    """3x3 same-padding convolution followed by an in-place ReLU."""

    def __init__(self, in_, out):
        super(ConvRelu, self).__init__()
        # Inlined equivalent of the module-level conv3x3 helper.
        self.conv = nn.Conv2d(in_, out, 3, padding=1)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.conv(x))
|
||||
|
||||
class UNetSEResNext(nn.Module):
    """U-Net with an SE-ResNeXt-50 (32x4d) encoder.

    The decoder upsamples back toward the input resolution; hypercolumn
    features from all five decoder stages are concatenated before the final
    1x1 classifier.

    Parameters:
        num_classes: channels of the output map.
        num_filters: base decoder width.
        pretrained: when True, request ImageNet weights for the encoder.
        is_deconv: learned transposed-conv upsampling vs bilinear.
    """

    def __init__(self, num_classes=3, num_filters=32,
                 pretrained=True, is_deconv=True):
        super().__init__()
        self.num_classes = num_classes
        pretrain = 'imagenet' if pretrained is True else None
        self.encoder = se_resnext50_32x4d(num_classes=1000, pretrained=pretrain)
        bottom_channel_nr = 2048  # channels out of encoder layer4

        # Encoder stages reused as U-Net contracting path.
        self.conv1 = self.encoder.layer0
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        # Center block keeps spatial size (is_deconv=False -> no upsampling).
        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

        # Each dec* consumes [previous decoder output, skip connection].
        self.dec5 = DecoderBlockV(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 2, is_deconv)
        self.dec4 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 2, num_filters * 8, num_filters * 2, is_deconv)
        self.dec3 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 2, num_filters * 4, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2, num_filters * 2, is_deconv)
        self.dec1 = DecoderBlockV(num_filters * 2, num_filters, num_filters * 2, is_deconv)
        # Five decoder outputs of num_filters*2 channels each -> num_filters*10.
        self.dec0 = ConvRelu(num_filters * 10, num_filters * 2)
        self.final = nn.Conv2d(num_filters * 2, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        center = self.center(conv5)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)

        # Hypercolumn: upsample every decoder stage to dec1's resolution.
        # F.interpolate replaces the deprecated F.upsample (same semantics).
        f = torch.cat((
            dec1,
            F.interpolate(dec2, scale_factor=2, mode='bilinear', align_corners=False),
            F.interpolate(dec3, scale_factor=4, mode='bilinear', align_corners=False),
            F.interpolate(dec4, scale_factor=8, mode='bilinear', align_corners=False),
            F.interpolate(dec5, scale_factor=16, mode='bilinear', align_corners=False),
        ), 1)

        dec0 = self.dec0(f)

        return self.final(dec0)
|
||||
|
||||
class DecoderBlockV(nn.Module):
    """Decoder stage that doubles the spatial resolution.

    With is_deconv=True: ConvRelu, then a learned 4x4/stride-2 transposed
    convolution followed by InstanceNorm and ReLU. Otherwise: bilinear
    2x upsample followed by two ConvRelu blocks.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super(DecoderBlockV, self).__init__()
        self.in_channels = in_channels
        if is_deconv:
            stages = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels,
                                   kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels, affine=False),
                nn.ReLU(inplace=True),
            ]
        else:
            stages = [
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return self.block(x)
|
||||
|
||||
|
||||
|
||||
class DecoderCenter(nn.Module):
    """Bottleneck 'center' block of the U-Net decoder.

    With is_deconv=True it also performs a learned 2x upsampling; otherwise
    it is two ConvRelu blocks at constant resolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super(DecoderCenter, self).__init__()
        self.in_channels = in_channels

        if is_deconv:
            """
            Paramaters for Deconvolution were chosen to avoid artifacts, following
            link https://distill.pub/2016/deconv-checkerboard/
            """

            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                # 4x4 kernel with stride 2 / padding 1: kernel size divisible
                # by stride, which avoids checkerboard artifacts (see link).
                nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
                                   padding=1),
                nn.InstanceNorm2d(out_channels, affine=False),
                nn.ReLU(inplace=True)
            )
        else:
            self.block = nn.Sequential(
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels)
            )

    def forward(self, x):
        return self.block(x)
|
Binary file not shown.
Binary file not shown.
@ -0,0 +1,108 @@
|
||||
import os
|
||||
from glob import glob
|
||||
# from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from fire import Fire
|
||||
from tqdm import tqdm
|
||||
|
||||
from aug import get_normalize
|
||||
from models.networks import get_generator
|
||||
|
||||
|
||||
class Predictor:
    """Runs a trained deblurring generator on single images.

    Reads the model architecture from config/config.yaml (unless model_name
    is given) and the weights from `weights_path`.
    """

    def __init__(self, weights_path, model_name=''):
        # safe_load avoids arbitrary Python object construction from the YAML
        # file; yaml.load without an explicit Loader is deprecated and unsafe.
        with open('config/config.yaml') as cfg:
            config = yaml.safe_load(cfg)
        model = get_generator(model_name or config['model'])
        model.load_state_dict(torch.load(weights_path, map_location=lambda storage, loc: storage)['model'])
        if torch.cuda.is_available():
            self.model = model.cuda()
        else:
            self.model = model
        self.model.train(True)
        # GAN inference should be in train mode to use actual stats in norm layers,
        # it's not a bug
        self.normalize_fn = get_normalize()

    @staticmethod
    def _array_to_batch(x):
        """HWC numpy image -> 1xCxHxW torch tensor (no copy of dtype)."""
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    def _preprocess(self, x, mask):
        """Normalize the image, build/binarize the mask, and zero-pad both so
        their sides are multiples of 32. Returns (batches, orig_h, orig_w)."""
        x, _ = self.normalize_fn(x, x)
        if mask is None:
            mask = np.ones_like(x, dtype=np.float32)
        else:
            mask = np.round(mask.astype('float32') / 255)

        h, w, _ = x.shape
        block_size = 32
        # NOTE(review): this always pads up to the NEXT multiple of 32, even
        # when h/w already are multiples — harmless because the output is
        # cropped back to (h, w) in __call__.
        min_height = (h // block_size + 1) * block_size
        min_width = (w // block_size + 1) * block_size

        pad_params = {'mode': 'constant',
                      'constant_values': 0,
                      'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
                      }
        x = np.pad(x, **pad_params)
        mask = np.pad(mask, **pad_params)

        return map(self._array_to_batch, (x, mask)), h, w

    @staticmethod
    def _postprocess(x):
        """First tensor of a batch -> uint8 HWC image; maps [-1,1] to [0,255]."""
        x, = x
        x = x.detach().cpu().float().numpy()
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        return x.astype('uint8')

    def __call__(self, img, mask, ignore_mask=True):
        """Deblur `img` (optionally masked) and return a uint8 HWC result
        cropped back to the original size."""
        (img, mask), h, w = self._preprocess(img, mask)
        with torch.no_grad():
            if torch.cuda.is_available():
                inputs = [img.cuda()]
            else:
                inputs = [img]
            if not ignore_mask:
                inputs += [mask]
            pred = self.model(*inputs)
        return self._postprocess(pred)[:h, :w, :]
|
||||
|
||||
def sorted_glob(pattern):
    """Glob `pattern` and return the matching paths in sorted order."""
    matches = glob(pattern)
    return sorted(matches)
|
||||
|
||||
def main(img_pattern,
         mask_pattern=None,
         weights_path='best_fpn.h5',
         out_dir='submit/',
         side_by_side=False):
    """Deblur every image matching `img_pattern` and write results to
    `out_dir`.

    mask_pattern: optional glob of masks paired (by sort order) with images.
    side_by_side: write input and prediction concatenated horizontally.
    """
    imgs = sorted_glob(img_pattern)
    masks = sorted_glob(mask_pattern) if mask_pattern is not None else [None for _ in imgs]
    pairs = zip(imgs, masks)
    names = sorted([os.path.basename(x) for x in glob(img_pattern)])
    predictor = Predictor(weights_path=weights_path)

    # Previously `os.makedirs(out_dir)` was commented out, so writing failed
    # whenever out_dir did not exist; exist_ok also keeps reruns working.
    os.makedirs(out_dir, exist_ok=True)
    for name, pair in tqdm(zip(names, pairs), total=len(names)):
        f_img, f_mask = pair
        img, mask = map(cv2.imread, (f_img, f_mask))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        pred = predictor(img, mask)
        if side_by_side:
            pred = np.hstack((img, pred))
        pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(out_dir, name),
                    pred)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Expose `main` as a command-line interface via python-fire.
    Fire(main)
|
@ -0,0 +1,67 @@
|
||||
from models.networks import get_generator_new
|
||||
from aug import get_normalize
|
||||
import torch
|
||||
import numpy as np
|
||||
# Baked-in training/inference configuration for the fpn_inception deblurring
# model — presumably a frozen copy of config/config.yaml; confirm they stay
# in sync if the YAML changes.
config={'project': 'deblur_gan', 'warmup_num': 3, 'optimizer': {'lr': 0.0001, 'name': 'adam'}, 'val': {'preload': False, 'bounds': [0.9, 1], 'crop': 'center', 'files_b': '/datasets/my_dataset/**/*.jpg', 'files_a': '/datasets/my_dataset/**/*.jpg', 'scope': 'geometric', 'corrupt': [{'num_holes': 3, 'max_w_size': 25, 'max_h_size': 25, 'name': 'cutout', 'prob': 0.5}, {'quality_lower': 70, 'name': 'jpeg', 'quality_upper': 90}, {'name': 'motion_blur'}, {'name': 'median_blur'}, {'name': 'gamma'}, {'name': 'rgb_shift'}, {'name': 'hsv_shift'}, {'name': 'sharpen'}], 'preload_size': 0, 'size': 256}, 'val_batches_per_epoch': 100, 'num_epochs': 200, 'batch_size': 1, 'experiment_desc': 'fpn', 'train_batches_per_epoch': 1000, 'train': {'preload': False, 'bounds': [0, 0.9], 'crop': 'random', 'files_b': '/datasets/my_dataset/**/*.jpg', 'files_a': '/datasets/my_dataset/**/*.jpg', 'preload_size': 0, 'corrupt': [{'num_holes': 3, 'max_w_size': 25, 'max_h_size': 25, 'name': 'cutout', 'prob': 0.5}, {'quality_lower': 70, 'name': 'jpeg', 'quality_upper': 90}, {'name': 'motion_blur'}, {'name': 'median_blur'}, {'name': 'gamma'}, {'name': 'rgb_shift'}, {'name': 'hsv_shift'}, {'name': 'sharpen'}], 'scope': 'geometric', 'size': 256}, 'scheduler': {'min_lr': 1e-07, 'name': 'linear', 'start_epoch': 50}, 'image_size': [256, 256], 'phase': 'train', 'model': {'d_name': 'double_gan', 'disc_loss': 'wgan-gp', 'blocks': 9, 'content_loss': 'perceptual', 'adv_lambda': 0.001, 'dropout': True, 'g_name': 'fpn_inception', 'd_layers': 3, 'learn_residual': True, 'norm_layer': 'instance'}}
|
||||
|
||||
|
||||
class Predictor:
    """Runs a trained deblurring generator on single images.

    Variant that derives the generator architecture from the weights file
    path (via get_generator_new) instead of reading config/config.yaml.
    """

    def __init__(self, weights_path, model_name=''):
        # NOTE(review): weights_path[0:-11] strips the last 11 characters of
        # the path — presumably a fixed weight-file suffix; confirm against
        # the naming scheme expected by get_generator_new.
        model = get_generator_new(weights_path[0:-11])
        model.load_state_dict(torch.load(weights_path, map_location=lambda storage, loc: storage)['model'])
        if torch.cuda.is_available():
            self.model = model.cuda()
        else:
            self.model = model
        self.model.train(True)
        # GAN inference should be in train mode to use actual stats in norm layers,
        # it's not a bug
        self.normalize_fn = get_normalize()

    @staticmethod
    def _array_to_batch(x):
        # HWC numpy image -> 1xCxHxW torch tensor.
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    def _preprocess(self, x, mask):
        """Normalize the image, build/binarize the mask, and zero-pad both so
        their sides are multiples of 32. Returns (batches, orig_h, orig_w)."""
        x, _ = self.normalize_fn(x, x)
        if mask is None:
            mask = np.ones_like(x, dtype=np.float32)
        else:
            mask = np.round(mask.astype('float32') / 255)

        h, w, _ = x.shape
        block_size = 32
        # Pads up to the next multiple of 32; the output is cropped back to
        # (h, w) in __call__.
        min_height = (h // block_size + 1) * block_size
        min_width = (w // block_size + 1) * block_size

        pad_params = {'mode': 'constant',
                      'constant_values': 0,
                      'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
                      }
        x = np.pad(x, **pad_params)
        mask = np.pad(mask, **pad_params)

        return map(self._array_to_batch, (x, mask)), h, w

    @staticmethod
    def _postprocess(x):
        # First tensor of the batch -> uint8 HWC image; maps [-1,1] to [0,255].
        x, = x
        x = x.detach().cpu().float().numpy()
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        return x.astype('uint8')

    def __call__(self, img, mask, ignore_mask=True):
        """Deblur `img` (optionally masked) and return a uint8 HWC result
        cropped back to the original size."""
        (img, mask), h, w = self._preprocess(img, mask)
        with torch.no_grad():
            if torch.cuda.is_available():
                inputs = [img.cuda()]
            else:
                inputs = [img]
            if not ignore_mask:
                inputs += [mask]
            pred = self.model(*inputs)
        return self._postprocess(pred)[:h, :w, :]
|
||||
|
Binary file not shown.
@ -0,0 +1,13 @@
|
||||
torch==1.0.1
|
||||
torchvision
|
||||
pretrainedmodels
|
||||
numpy
|
||||
opencv-python-headless
|
||||
joblib
|
||||
albumentations
|
||||
scikit-image
|
||||
tqdm
|
||||
glog
|
||||
tensorboardx
|
||||
fire
|
||||
# this file is not ready yet
|
@ -0,0 +1,59 @@
|
||||
import math
|
||||
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
|
||||
class WarmRestart(lr_scheduler.CosineAnnealingLR):
    """Cosine annealing with warm restarts (SGDR, arXiv:1608.03983).

    The LR follows a cosine schedule; when a cycle of length T_max completes
    the counter resets and the next cycle is T_mult times longer. Does not
    support scheduler.step(epoch) — always call with epoch=None.
    """

    def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1):
        """Set up SGDR.

        Parameters:
        ----------
        T_max : int
            Number of epochs in the first cycle.
        T_mult : int
            Factor by which T_max grows after each restart.
        eta_min : int
            Minimum learning rate. Default: 0.
        last_epoch : int
            The index of last epoch. Default: -1.
        """
        self.T_mult = T_mult
        super().__init__(optimizer, T_max, eta_min, last_epoch)

    def get_lr(self):
        # Restart: reset the epoch counter and stretch the next cycle.
        if self.last_epoch == self.T_max:
            self.last_epoch = 0
            self.T_max *= self.T_mult
        cos_factor = (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
        return [self.eta_min + (base - self.eta_min) * cos_factor
                for base in self.base_lrs]
|
||||
|
||||
|
||||
class LinearDecay(lr_scheduler._LRScheduler):
    """Holds the base LR until start_epoch, then decays it linearly toward
    min_lr over num_epochs epochs."""

    def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1):
        """Set up the linear decay schedule.

        Parameters:
        ----------
        num_epochs : int
            Length of the decay phase.
        start_epoch : int
            Epoch at which the decay begins.
        min_lr : float
            Target learning rate at the end of the decay.
        """
        self.num_epochs = num_epochs
        self.start_epoch = start_epoch
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Hold phase: no decay before start_epoch.
        if self.last_epoch < self.start_epoch:
            return self.base_lrs
        elapsed = self.last_epoch - self.start_epoch
        return [base - ((base - self.min_lr) / self.num_epochs) * elapsed
                for base in self.base_lrs]
|
@ -0,0 +1,3 @@
|
||||
#!/usr/bin/env bash

# Discover and run every unittest module under the current working directory.
python3 -m unittest discover $(pwd)
|
@ -0,0 +1,20 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from aug import get_transforms
|
||||
|
||||
|
||||
class AugTest(unittest.TestCase):
    """Checks that paired augmentations transform both images identically."""

    @staticmethod
    def make_images():
        """Return two independent copies of one random uint8 RGB image."""
        img = (np.random.rand(100, 100, 3) * 255).astype('uint8')
        return img.copy(), img.copy()

    def test_aug(self):
        # Every (scope, crop) pipeline must apply the same spatial ops to
        # both images of a pair.
        for scope in ('strong', 'weak'):
            for crop in ('random', 'center'):
                pipeline = get_transforms(80, scope=scope, crop=crop)
                left, right = self.make_images()
                left, right = pipeline(left, right)
                np.testing.assert_allclose(left, right)
|
@ -0,0 +1,76 @@
|
||||
import os
|
||||
import unittest
|
||||
from shutil import rmtree
|
||||
from tempfile import mkdtemp
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset import PairedDataset
|
||||
|
||||
|
||||
def make_img():
    """Return a random 100x100 RGB image with dtype uint8."""
    noise = np.random.rand(100, 100, 3)
    return (noise * 255).astype('uint8')
|
||||
|
||||
|
||||
class AugTest(unittest.TestCase):
    """Integration tests for PairedDataset over a temporary image tree."""

    tmp_dir = mkdtemp()
    raw = os.path.join(tmp_dir, 'raw')
    gt = os.path.join(tmp_dir, 'gt')

    def setUp(self):
        # Populate raw/ and gt/ with five random PNGs each.
        for d in (self.raw, self.gt):
            os.makedirs(d)

        for i in range(5):
            for d in (self.raw, self.gt):
                img = make_img()
                cv2.imwrite(os.path.join(d, f'{i}.png'), img)

    def tearDown(self):
        rmtree(self.tmp_dir)

    def dataset_gen(self, equal=True):
        """Yield PairedDataset instances over the full grid of bound/scope/
        crop/preload configurations; `equal` controls whether both sides read
        the same directory."""
        base_config = {'files_a': os.path.join(self.raw, '*.png'),
                       'files_b': os.path.join(self.raw if equal else self.gt, '*.png'),
                       'size': 32,
                       }
        for b in ([0, 1], [0, 0.9]):
            for scope in ('strong', 'weak'):
                for crop in ('random', 'center'):
                    for preload in (0, 1):
                        for preload_size in (0, 64):
                            config = base_config.copy()
                            config['bounds'] = b
                            config['scope'] = scope
                            config['crop'] = crop
                            config['preload'] = preload
                            config['preload_size'] = preload_size
                            config['verbose'] = False
                            yield PairedDataset.from_config(config)

    def _first_batch(self, dataset):
        """Draw one shuffled batch and return its 'a'/'b' numpy arrays.

        Extracted to remove the loader boilerplate duplicated across both
        test methods.
        """
        dataloader = DataLoader(dataset=dataset,
                                batch_size=2,
                                shuffle=True,
                                drop_last=True)
        batch = next(iter(dataloader))
        return tuple(batch.get(k).numpy() for k in ('a', 'b'))

    def test_equal_datasets(self):
        # Same source on both sides -> paired batches must match exactly.
        for dataset in self.dataset_gen(equal=True):
            a, b = self._first_batch(dataset)
            np.testing.assert_allclose(a, b)

    def test_datasets(self):
        # Different random sources -> paired batches must differ somewhere.
        for dataset in self.dataset_gen(equal=False):
            a, b = self._first_batch(dataset)
            assert not np.all(a == b), 'images should not be the same'
|
@ -0,0 +1,90 @@
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
import yaml
|
||||
import os
|
||||
from torchvision import models, transforms
|
||||
from torch.autograd import Variable
|
||||
import shutil
|
||||
import glob
|
||||
import tqdm
|
||||
from util.metrics import PSNR
|
||||
from albumentations import Compose, CenterCrop, PadIfNeeded
|
||||
from PIL import Image
|
||||
from ssim.ssimlib import SSIM
|
||||
from models.networks import get_generator
|
||||
|
||||
|
||||
def get_args(argv=None):
    """Parse command-line arguments for the evaluation script.

    Parameters
    ----------
    argv : list of str, optional
        Argument list to parse. Defaults to None, meaning sys.argv[1:]
        (fully backward compatible with the parameterless call).
    """
    parser = argparse.ArgumentParser('Test an image')
    parser.add_argument('--img_folder', required=True, help='GoPRO Folder')
    parser.add_argument('--weights_path', required=True, help='Weights path')

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
def prepare_dirs(path):
    """Ensure `path` exists as a fresh, empty directory (old contents are
    deleted)."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)
|
||||
|
||||
|
||||
def get_gt_image(path):
    """Map a blurry-frame path to its ground-truth sharp frame (RGB).

    Assumes a GoPro-style layout <base>/<blur-dir>/<seq>/<file> with the
    matching sharp frame at <base>/sharp/<seq>/<file>.
    """
    seq_dir, filename = os.path.split(path)
    parent, seq = os.path.split(seq_dir)
    base, _ = os.path.split(parent)
    sharp_path = os.path.join(base, 'sharp', seq, filename)
    return cv2.cvtColor(cv2.imread(sharp_path), cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
def test_image(model, image_path):
    """Run `model` on one blurry image and score it against its sharp pair.

    Returns (psnr, cw_ssim) comparing the model output to the ground-truth
    frame located via get_gt_image. Requires a CUDA device (no CPU fallback).
    """
    img_transforms = transforms.Compose([
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # NOTE(review): the 736x1280 pad and 720x1280 crop assume GoPro-sized
    # frames — confirm before running on other datasets.
    size_transform = Compose([
        PadIfNeeded(736, 1280)
    ])
    crop = CenterCrop(720, 1280)
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_s = size_transform(image=img)['image']
    # Scale to [0,1], HWC -> CHW, then normalize to [-1,1].
    img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
    img_tensor = img_transforms(img_tensor)
    with torch.no_grad():
        img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
        result_image = model(img_tensor)
    result_image = result_image[0].cpu().float().numpy()
    # Map the model output from [-1, 1] back to [0, 255].
    result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
    result_image = crop(image=result_image)['image']
    result_image = result_image.astype('uint8')
    gt_image = get_gt_image(image_path)
    _, filename = os.path.split(image_path)
    psnr = PSNR(result_image, gt_image)
    pilFake = Image.fromarray(result_image)
    pilReal = Image.fromarray(gt_image)
    ssim = SSIM(pilFake).cw_ssim_value(pilReal)
    return psnr, ssim
|
||||
|
||||
|
||||
def test(model, files):
    """Evaluate and print the average PSNR / CW-SSIM of `model` over `files`.

    Raises ValueError when `files` is empty (previously this crashed with an
    opaque ZeroDivisionError at the final division).
    """
    if not files:
        raise ValueError('no test images found')
    psnr = 0
    ssim = 0
    for file in tqdm.tqdm(files):
        cur_psnr, cur_ssim = test_image(model, file)
        psnr += cur_psnr
        ssim += cur_ssim
    print("PSNR = {}".format(psnr / len(files)))
    print("SSIM = {}".format(ssim / len(files)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    args = get_args()
    # NOTE(review): yaml.load without an explicit Loader is deprecated and
    # unsafe on untrusted input — consider yaml.safe_load; confirm the config
    # file needs no custom tags first.
    with open('config/config.yaml') as cfg:
        config = yaml.load(cfg)
    model = get_generator(config['model'])
    model.load_state_dict(torch.load(args.weights_path)['model'])
    model = model.cuda()
    # Recursively collect every blurry test frame in GoPro layout.
    filenames = sorted(glob.glob(args.img_folder + '/test' + '/blur/**/*.png', recursive=True))
    test(model, filenames)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue