# You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# GIMP-ML/gimp-plugins/DeblurGANv2/models/networks.py
#
# 331 lines
# 13 KiB
# Python
#
# 5 years ago
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
import numpy as np
from fpn_mobilenet import FPNMobileNet
from fpn_inception import FPNInception
# from fpn_inception_simple import FPNInceptionSimple
from unet_seresnext import UNetSEResNext
from fpn_densenet import FPNDense
###############################################################################
# Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
    """Return a constructor for the requested 2-D normalization layer.

    Args:
        norm_type: ``'batch'`` or ``'instance'``.

    Returns:
        A ``functools.partial`` around ``nn.BatchNorm2d`` (``affine=True``) for
        ``'batch'``, or around ``nn.InstanceNorm2d`` (``affine=False``,
        ``track_running_stats=True``) for ``'instance'``.

    Raises:
        NotImplementedError: if ``norm_type`` is neither of the two names.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
##############################################################################
# Classes
##############################################################################
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    """Generator made of residual blocks between down- and up-sampling convolutions.

    Layer order: reflection pad + 7x7 convolution, two stride-2 3x3
    convolutions (doubling the channel count each time), ``n_blocks``
    ``ResnetBlock`` modules, two stride-2 transposed convolutions (halving the
    channel count each time), reflection pad + 7x7 convolution, ``tanh``.

    Args:
        input_nc: channel count expected by the first convolution.
        output_nc: channel count produced by the last convolution.
        ngf: number of filters of the first convolution.
        norm_layer: callable that builds a normalization layer from a channel
            count; either a layer class or a ``functools.partial`` around one.
            Convolutions get a bias only when it resolves to
            ``nn.InstanceNorm2d``.
        use_dropout: forwarded to every ``ResnetBlock``.
        n_blocks: number of residual blocks; must be ``>= 0``.
        use_parallel: stored on the instance only; not used by this class.
        learn_residual: when true, ``forward`` adds the input to the network
            output and clamps the sum to ``[-1, 1]``.
        padding_type: forwarded to every ``ResnetBlock``.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d,
                 use_dropout=False, n_blocks=6, use_parallel=True, learn_residual=True,
                 padding_type='reflect'):
        assert n_blocks >= 0
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.use_parallel = use_parallel
        self.learn_residual = learn_residual

        # isinstance (instead of an exact type comparison) also unwraps
        # subclasses of functools.partial.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            # mult is an even power of two here, so integer division is exact.
            out_channels = ngf * mult // 2
            model += [nn.ConvTranspose2d(ngf * mult, out_channels,
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(out_channels),
                      nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Run the network on ``input``.

        Returns ``self.model(input)``; when ``learn_residual`` is set the input
        is added to that result and the sum is clamped to ``[-1, 1]``.
        NOTE(review): the residual sum presumably relies on ``output_nc``
        matching the channel count of ``input`` -- confirm against callers.
        """
        output = self.model(input)
        if self.learn_residual:
            output = input + output
            output = torch.clamp(output, min=-1, max=1)
        return output
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: the input plus the output of two 3x3 convolutions.

    Args:
        dim: channel count of both convolutions (input and output).
        padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
        norm_layer: callable building a normalization layer from ``dim``.
        use_dropout: insert ``nn.Dropout(0.5)`` after the first activation.
        use_bias: whether the convolutions carry a bias term.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return ``(explicit padding layers, conv padding)`` for ``padding_type``."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            # Zero padding is done by the convolution itself.
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble pad/conv/norm/ReLU[/dropout] followed by pad/conv/norm.

        Raises:
            NotImplementedError: if ``padding_type`` is not recognised.
        """
        pad_layers, conv_padding = self._padding(padding_type)
        layers = pad_layers + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        pad_layers, conv_padding = self._padding(padding_type)
        layers += pad_layers + [
            nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias),
            norm_layer(dim),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + conv_block(x)``."""
        return x + self.conv_block(x)
class DicsriminatorTail(nn.Module):
    """Output head of ``MultiScaleDiscriminator``: two stride-1 4x4 convolutions.

    The first convolution maps ``ndf * nf_mult`` channels to
    ``ndf * min(2 ** n_layers, 8)`` channels and is followed by a normalization
    layer and ``LeakyReLU(0.2)``; the second reduces to a single channel.

    Args:
        nf_mult: channel multiplier (relative to ``ndf``) of the incoming features.
        n_layers: used only to derive the hidden multiplier ``min(2 ** n_layers, 8)``.
        ndf: base channel count.
        norm_layer: callable building a normalization layer from a channel
            count; a class or a ``functools.partial`` around one.
        use_parallel: stored on the instance only; not used by this class.
    """

    def __init__(self, nf_mult, n_layers, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
        super(DicsriminatorTail, self).__init__()
        self.use_parallel = use_parallel
        # isinstance (instead of an exact type comparison) also unwraps
        # subclasses of functools.partial.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        # NOTE: evaluates to 2 with true division (Python 3); under Python 2
        # integer division it would be 1.
        padw = int(np.ceil((kw-1)/2))

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence = [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Return ``self.model(input)``."""
        return self.model(input)


class MultiScaleDiscriminator(nn.Module):
    """Discriminator that emits one single-channel score map per feature scale.

    A shared trunk of stride-2 4x4 convolutions is evaluated in three stages
    (``scale_one``, ``scale_two``, ``scale_three``); after each stage a
    ``DicsriminatorTail`` turns the current features into a score map.

    Args:
        input_nc: channel count expected by the first convolution.
        ndf: base channel count of the trunk and of the tails.
        norm_layer: callable building a normalization layer from a channel
            count; used by the trunk.
        use_parallel: stored on the instance only; not used by this class.
    """

    def __init__(self, input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
        super(MultiScaleDiscriminator, self).__init__()
        self.use_parallel = use_parallel
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        # NOTE: evaluates to 2 with true division (Python 3); under Python 2
        # integer division it would be 1.
        padw = int(np.ceil((kw-1)/2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        for n in range(1, 3):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        self.scale_one = nn.Sequential(*sequence)
        # The tails must share the trunk's ``ndf``; with the tail default of 64
        # any other ``ndf`` produced a channel mismatch in ``forward``.
        # NOTE(review): ``norm_layer`` is intentionally still not forwarded, so
        # the tails keep their BatchNorm default and existing checkpoints keep
        # the same parameter names and shapes -- confirm whether that is wanted.
        self.first_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=3, ndf=ndf)

        nf_mult_prev = 4
        nf_mult = 8
        self.scale_two = nn.Sequential(
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True))

        nf_mult_prev = nf_mult
        self.second_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=4, ndf=ndf)

        self.scale_three = nn.Sequential(
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=2, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True))
        self.third_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=5, ndf=ndf)

    def forward(self, input):
        """Return ``[score_1, score_2, score_3]``, one tail output per trunk stage."""
        x = self.scale_one(input)
        x_1 = self.first_tail(x)
        x = self.scale_two(x)
        x_2 = self.second_tail(x)
        x = self.scale_three(x)
        x = self.third_tail(x)
        return [x_1, x_2, x]
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of 4x4 convolutions ending in one channel.

    Layer order: one stride-2 convolution + ``LeakyReLU(0.2)``; ``n_layers - 1``
    stride-2 convolution/norm/LeakyReLU groups whose channel multiplier doubles
    up to 8; one stride-1 convolution/norm/LeakyReLU group; a final stride-1
    convolution to a single channel; optionally ``Sigmoid``.

    Args:
        input_nc: channel count expected by the first convolution.
        ndf: base channel count.
        n_layers: controls the number of stride-2 groups and the final
            channel multiplier ``min(2 ** n_layers, 8)``.
        norm_layer: callable building a normalization layer from a channel
            count; a class or a ``functools.partial`` around one. Normalized
            convolutions get a bias only when it resolves to
            ``nn.InstanceNorm2d``.
        use_sigmoid: append ``nn.Sigmoid()`` after the last convolution.
        use_parallel: stored on the instance only; not used by this class.
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, use_parallel=True):
        super(NLayerDiscriminator, self).__init__()
        self.use_parallel = use_parallel
        # isinstance (instead of an exact type comparison) also unwraps
        # subclasses of functools.partial.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        # NOTE: evaluates to 2 with true division (Python 3); under Python 2
        # integer division it would be 1.
        padw = int(np.ceil((kw-1)/2))
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Return ``self.model(input)``."""
        return self.model(input)
def get_fullD(model_config):
    """Build a 5-layer ``NLayerDiscriminator`` without a sigmoid output.

    Args:
        model_config: mapping; only ``model_config['norm_layer']`` is read and
            passed to ``get_norm_layer``.

    Returns:
        The (unwrapped) ``NLayerDiscriminator`` instance.
    """
    norm_layer = get_norm_layer(norm_type=model_config['norm_layer'])
    return NLayerDiscriminator(n_layers=5, norm_layer=norm_layer, use_sigmoid=False)
def get_generator(model_config):
    """Build the generator named by ``model_config['g_name']``.

    Args:
        model_config: mapping. ``'g_name'`` selects the architecture; depending
            on the branch ``'norm_layer'``, ``'dropout'``, ``'blocks'``,
            ``'learn_residual'`` and ``'pretrained'`` are read as well.

    Returns:
        The generator wrapped in ``nn.DataParallel``.

    Raises:
        ValueError: if ``'g_name'`` is not one of the supported names.
        ImportError: for ``'fpn_inception_simple'`` when the
            ``fpn_inception_simple`` module cannot be imported.
    """
    generator_name = model_config['g_name']
    if generator_name == 'resnet':
        model_g = ResnetGenerator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                  use_dropout=model_config['dropout'],
                                  n_blocks=model_config['blocks'],
                                  learn_residual=model_config['learn_residual'])
    elif generator_name == 'fpn_mobilenet':
        model_g = FPNMobileNet(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
    elif generator_name == 'fpn_inception':
        # The model is read from a pickled file in the current working
        # directory instead of being constructed. It was presumably produced by
        #   torch.save(FPNInception(norm_layer=get_norm_layer(...)), 'mymodel.pth')
        # NOTE(review): torch.load unpickles arbitrary objects -- only use a
        # trusted 'mymodel.pth'. Recent PyTorch releases may also refuse to
        # load whole pickled modules by default; confirm against the targeted
        # PyTorch version.
        model_g = torch.load('mymodel.pth')
    elif generator_name == 'fpn_inception_simple':
        # The module-level import of FPNInceptionSimple is commented out, so
        # the name was undefined here and this branch always raised NameError.
        # Import it lazily: the branch works when the module is available and
        # fails with a clear ImportError when it is not.
        from fpn_inception_simple import FPNInceptionSimple
        model_g = FPNInceptionSimple(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
    elif generator_name == 'fpn_dense':
        model_g = FPNDense()
    elif generator_name == 'unet_seresnext':
        model_g = UNetSEResNext(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                pretrained=model_config['pretrained'])
    else:
        raise ValueError("Generator Network [%s] not recognized." % generator_name)

    return nn.DataParallel(model_g)
def get_generator_new(weights_path):
    """Load the pickled generator ``mymodel.pth`` located in ``weights_path``.

    Args:
        weights_path: directory containing ``mymodel.pth``. A trailing path
            separator is optional; an empty string means the current directory.

    Returns:
        The loaded object wrapped in ``nn.DataParallel``.
    """
    import os  # local import so this block does not depend on the file's import section

    # Plain string concatenation silently produced '<dir>mymodel.pth' when the
    # directory had no trailing separator; os.path.join handles both forms.
    model_path = os.path.join(weights_path, 'mymodel.pth')
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
    model_g = torch.load(model_path)
    return nn.DataParallel(model_g)
def get_discriminator(model_config):
    """Build the discriminator(s) named by ``model_config['d_name']``.

    Args:
        model_config: mapping. ``'d_name'`` selects the variant; ``'d_layers'``
            and ``'norm_layer'`` are read by the variants that need them.

    Returns:
        ``None`` for ``'no_gan'``; an ``nn.DataParallel``-wrapped module for
        ``'patch_gan'`` and ``'multi_scale'``; a dict with keys ``'patch'`` and
        ``'full'`` (both ``nn.DataParallel``-wrapped) for ``'double_gan'``.

    Raises:
        ValueError: if ``'d_name'`` is not one of the supported names.
    """
    discriminator_name = model_config['d_name']

    def build_patch_gan():
        # Shared by 'patch_gan' and 'double_gan'.
        patch = NLayerDiscriminator(n_layers=model_config['d_layers'],
                                    norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
                                    use_sigmoid=False)
        return nn.DataParallel(patch)

    if discriminator_name == 'no_gan':
        return None

    if discriminator_name == 'patch_gan':
        return build_patch_gan()

    if discriminator_name == 'double_gan':
        # Build the patch discriminator before the full one, as before.
        patch_gan = build_patch_gan()
        full_gan = nn.DataParallel(get_fullD(model_config))
        return {'patch': patch_gan,
                'full': full_gan}

    if discriminator_name == 'multi_scale':
        multi_scale = MultiScaleDiscriminator(
            norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
        return nn.DataParallel(multi_scale)

    raise ValueError("Discriminator Network [%s] not recognized." % discriminator_name)
def get_nets(model_config):
    """Return ``(generator, discriminator)`` built from ``model_config``.

    The generator comes from ``get_generator`` and is built first; the
    discriminator comes from ``get_discriminator``.
    """
    generator = get_generator(model_config)
    discriminator = get_discriminator(model_config)
    return generator, discriminator