# mirror of https://github.com/kritiksoman/GIMP-ML
import torch
import torch.nn as nn
from torch.nn import init
import functools
import os
from torch.autograd import Variable
import numpy as np
from fpn_mobilenet import FPNMobileNet
from fpn_inception import FPNInception
# from fpn_inception_simple import FPNInceptionSimple
from unet_seresnext import UNetSEResNext
from fpn_densenet import FPNDense

###############################################################################
# Functions
###############################################################################

def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
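

# Example usage (illustrative only): get_norm_layer returns a factory, not a
# layer instance; call it with a channel count to build the actual module.
#   norm = get_norm_layer('instance')  # functools.partial wrapping nn.InstanceNorm2d
#   layer = norm(64)                   # InstanceNorm2d over 64 channels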


##############################################################################
# Classes
##############################################################################


# Defines the generator, which consists of ResNet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's fast-neural-style:
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, n_blocks=6, use_parallel=True, learn_residual=True,
                 padding_type='reflect'):
        assert n_blocks >= 0
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.use_parallel = use_parallel
        self.learn_residual = learn_residual

        # InstanceNorm has no affine shift by default, so convolutions feeding
        # into it keep their bias terms.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        # Two stride-2 downsampling convolutions: ngf -> 2*ngf -> 4*ngf channels.
        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Residual blocks at the bottleneck resolution.
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias)]

        # Mirror-image upsampling back to the input resolution.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        output = self.model(input)
        if self.learn_residual:
            # Predict a residual on top of the input and clamp to the tanh range.
            output = input + output
            output = torch.clamp(output, min=-1, max=1)
        return output
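

# Illustrative usage (a sketch, not from the original pipeline): the generator
# maps a 3-channel image in [-1, 1] to a same-sized 3-channel image.
#   g = ResnetGenerator(norm_layer=get_norm_layer('instance'), n_blocks=6)
#   g(torch.randn(1, 3, 256, 256)).shape  ->  torch.Size([1, 3, 256, 256])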


# Define a ResNet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []

        # First conv: explicit reflection/replication padding, or built-in zero padding.
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        # Second conv: same padding scheme, no activation (the residual sum follows).
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)  # residual (skip) connection
        return out
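

# Illustrative shape check (hypothetical values): a ResnetBlock is shape-preserving,
# so blocks can be stacked freely at the bottleneck resolution.
#   blk = ResnetBlock(256, 'reflect', get_norm_layer('instance'), use_dropout=False, use_bias=True)
#   blk(torch.randn(1, 256, 64, 64)).shape  ->  torch.Size([1, 256, 64, 64])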


class DiscriminatorTail(nn.Module):
    def __init__(self, nf_mult, n_layers, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
        super(DiscriminatorTail, self).__init__()
        self.use_parallel = use_parallel
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence = [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # Final 1-channel convolution producing the patch score map.
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)


class MultiScaleDiscriminator(nn.Module):
    def __init__(self, input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
        super(MultiScaleDiscriminator, self).__init__()
        self.use_parallel = use_parallel
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        for n in range(1, 3):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # Shared trunk for the first scale, followed by per-scale heads.
        self.scale_one = nn.Sequential(*sequence)
        self.first_tail = DiscriminatorTail(nf_mult=nf_mult, n_layers=3)
        nf_mult_prev = 4
        nf_mult = 8

        self.scale_two = nn.Sequential(
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True))
        nf_mult_prev = nf_mult
        self.second_tail = DiscriminatorTail(nf_mult=nf_mult, n_layers=4)
        self.scale_three = nn.Sequential(
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True))
        self.third_tail = DiscriminatorTail(nf_mult=nf_mult, n_layers=5)

    def forward(self, input):
        x = self.scale_one(input)
        x_1 = self.first_tail(x)
        x = self.scale_two(x)
        x_2 = self.second_tail(x)
        x = self.scale_three(x)
        x = self.third_tail(x)
        return [x_1, x_2, x]
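

# Illustrative usage (a sketch): the three outputs are per-scale patch score maps,
# typically combined into a single adversarial loss, e.g.
#   d = MultiScaleDiscriminator(norm_layer=get_norm_layer('instance'))
#   scores = d(torch.randn(1, 3, 256, 256))            # [fine, coarser, coarsest]
#   loss = sum(s.mean() for s in scores) / len(scores)  # hypothetical aggregation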


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, use_parallel=True):
        super(NLayerDiscriminator, self).__init__()
        self.use_parallel = use_parallel
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = int(np.ceil((kw - 1) / 2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        # Stride-2 layers double the channel count (capped at 8*ndf) while halving resolution.
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # Final 1-channel convolution producing one real/fake score per patch.
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)
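

# Illustrative usage (a sketch): a PatchGAN outputs a spatial map of real/fake
# scores, one per receptive-field patch, rather than a single scalar per image.
#   d = NLayerDiscriminator(n_layers=3, norm_layer=get_norm_layer('instance'))
#   d(torch.randn(1, 3, 256, 256)).shape  # -> (1, 1, H', W') patch score map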


def get_fullD(model_config):
    model_d = NLayerDiscriminator(n_layers=5,
                                  norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                  use_sigmoid=False)
    return model_d


def get_generator(model_config):
    generator_name = model_config['g_name']
    if generator_name == 'resnet':
        model_g = ResnetGenerator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                  use_dropout=model_config['dropout'],
                                  n_blocks=model_config['blocks'],
                                  learn_residual=model_config['learn_residual'])
    elif generator_name == 'fpn_mobilenet':
        model_g = FPNMobileNet(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
    elif generator_name == 'fpn_inception':
        # Instead of constructing FPNInception from scratch, load a pre-serialized
        # instance from disk. The commented lines show how it was produced:
        # model_g = FPNInception(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
        # torch.save(model_g, 'mymodel.pth')
        model_g = torch.load('mymodel.pth')
    elif generator_name == 'fpn_inception_simple':
        # Imported lazily because the module-level import is commented out above.
        from fpn_inception_simple import FPNInceptionSimple
        model_g = FPNInceptionSimple(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
    elif generator_name == 'fpn_dense':
        model_g = FPNDense()
    elif generator_name == 'unet_seresnext':
        model_g = UNetSEResNext(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                pretrained=model_config['pretrained'])
    else:
        raise ValueError("Generator Network [%s] not recognized." % generator_name)

    return nn.DataParallel(model_g)
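

# Expected model_config keys, as read by the factories in this file:
#   'g_name'          one of 'resnet', 'fpn_mobilenet', 'fpn_inception',
#                     'fpn_inception_simple', 'fpn_dense', 'unet_seresnext'
#   'norm_layer'      'batch' or 'instance' (see get_norm_layer)
#   'dropout', 'blocks', 'learn_residual'   used by the 'resnet' generator
#   'pretrained'      used by 'unet_seresnext'
#   'd_name'          one of 'no_gan', 'patch_gan', 'double_gan', 'multi_scale'
#   'd_layers'        depth of the PatchGAN discriminator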


def get_generator_new(weights_path):
    # Load a pre-serialized generator from the given weights directory.
    model_g = torch.load(os.path.join(weights_path, 'mymodel.pth'))
    return nn.DataParallel(model_g)


def get_discriminator(model_config):
    discriminator_name = model_config['d_name']
    if discriminator_name == 'no_gan':
        model_d = None
    elif discriminator_name == 'patch_gan':
        model_d = NLayerDiscriminator(n_layers=model_config['d_layers'],
                                      norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                      use_sigmoid=False)
        model_d = nn.DataParallel(model_d)
    elif discriminator_name == 'double_gan':
        patch_gan = NLayerDiscriminator(n_layers=model_config['d_layers'],
                                        norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                        use_sigmoid=False)
        patch_gan = nn.DataParallel(patch_gan)
        full_gan = get_fullD(model_config)
        full_gan = nn.DataParallel(full_gan)
        model_d = {'patch': patch_gan,
                   'full': full_gan}
    elif discriminator_name == 'multi_scale':
        model_d = MultiScaleDiscriminator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
        model_d = nn.DataParallel(model_d)
    else:
        raise ValueError("Discriminator Network [%s] not recognized." % discriminator_name)

    return model_d


def get_nets(model_config):
    return get_generator(model_config), get_discriminator(model_config)
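

# Minimal smoke test (a sketch, not part of the original pipeline): builds the
# ResNet generator and PatchGAN discriminator from a sample config and pushes a
# dummy batch through both. The config keys are the ones the factories above
# actually read; the tensor sizes are arbitrary.
if __name__ == '__main__':
    sample_config = {
        'g_name': 'resnet',
        'norm_layer': 'instance',
        'dropout': False,
        'blocks': 6,
        'learn_residual': True,
        'd_name': 'patch_gan',
        'd_layers': 3,
    }
    net_g, net_d = get_nets(sample_config)
    dummy = torch.randn(1, 3, 256, 256)  # N x C x H x W
    with torch.no_grad():
        fake = net_g(dummy)    # same size as the input, values in [-1, 1]
        score = net_d(fake)    # PatchGAN patch score map
    print(fake.shape, score.shape)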