From a3fcc0cfc405d1c1dc637a59d9b73d22ae8acfcc Mon Sep 17 00:00:00 2001 From: Kritik Soman Date: Sat, 17 Oct 2020 18:18:18 +0530 Subject: [PATCH] octUpdate --- gimp-plugins/EnlightenGAN/.idea/.gitignore | 3 + .../EnlightenGAN/.idea/EnlightenGAN.iml | 12 + .../inspectionProfiles/profiles_settings.xml | 6 + gimp-plugins/EnlightenGAN/.idea/misc.xml | 4 + gimp-plugins/EnlightenGAN/.idea/modules.xml | 8 + gimp-plugins/EnlightenGAN/.idea/vcs.xml | 6 + gimp-plugins/EnlightenGAN/License | 58 + .../EnlightenGAN/data/base_dataset.py | 50 + gimp-plugins/EnlightenGAN/lib/__init__.py | 0 gimp-plugins/EnlightenGAN/lib/nn/__init__.py | 2 + .../EnlightenGAN/lib/nn/modules/__init__.py | 12 + .../EnlightenGAN/lib/nn/modules/batchnorm.py | 329 +++++ .../EnlightenGAN/lib/nn/modules/comm.py | 131 ++ .../EnlightenGAN/lib/nn/modules/replicate.py | 94 ++ .../modules/tests/test_numeric_batchnorm.py | 56 + .../nn/modules/tests/test_sync_batchnorm.py | 111 ++ .../EnlightenGAN/lib/nn/modules/unittest.py | 29 + .../EnlightenGAN/lib/nn/parallel/__init__.py | 1 + .../lib/nn/parallel/data_parallel.py | 112 ++ gimp-plugins/EnlightenGAN/models/__init__.py | 0 .../EnlightenGAN/models/base_model.py | 61 + gimp-plugins/EnlightenGAN/models/models.py | 38 + gimp-plugins/EnlightenGAN/models/networks.py | 1181 +++++++++++++++++ .../EnlightenGAN/models/single_model.py | 496 +++++++ gimp-plugins/EnlightenGAN/util/__init__.py | 0 gimp-plugins/EnlightenGAN/util/image_pool.py | 32 + gimp-plugins/EnlightenGAN/util/util.py | 182 +++ gimp-plugins/enlighten.py | 107 ++ gimp-plugins/syncWeights.py | 11 + 29 files changed, 3132 insertions(+) create mode 100755 gimp-plugins/EnlightenGAN/.idea/.gitignore create mode 100755 gimp-plugins/EnlightenGAN/.idea/EnlightenGAN.iml create mode 100755 gimp-plugins/EnlightenGAN/.idea/inspectionProfiles/profiles_settings.xml create mode 100755 gimp-plugins/EnlightenGAN/.idea/misc.xml create mode 100755 gimp-plugins/EnlightenGAN/.idea/modules.xml create mode 100755 gimp-plugins/EnlightenGAN/.idea/vcs.xml create mode 100755 gimp-plugins/EnlightenGAN/License create mode 100755 gimp-plugins/EnlightenGAN/data/base_dataset.py create mode 100755 gimp-plugins/EnlightenGAN/lib/__init__.py create mode 100755 gimp-plugins/EnlightenGAN/lib/nn/__init__.py create mode 100755 gimp-plugins/EnlightenGAN/lib/nn/modules/__init__.py create mode 100755 gimp-plugins/EnlightenGAN/lib/nn/modules/batchnorm.py create mode 100755 gimp-plugins/EnlightenGAN/lib/nn/modules/comm.py create mode 100755 gimp-plugins/EnlightenGAN/lib/nn/modules/replicate.py create mode 100755 gimp-plugins/EnlightenGAN/lib/nn/modules/tests/test_numeric_batchnorm.py create mode 100755 gimp-plugins/EnlightenGAN/lib/nn/modules/tests/test_sync_batchnorm.py create mode 100755 gimp-plugins/EnlightenGAN/lib/nn/modules/unittest.py create mode 100755 gimp-plugins/EnlightenGAN/lib/nn/parallel/__init__.py create mode 100755 gimp-plugins/EnlightenGAN/lib/nn/parallel/data_parallel.py create mode 100755 gimp-plugins/EnlightenGAN/models/__init__.py create mode 100755 gimp-plugins/EnlightenGAN/models/base_model.py create mode 100755 gimp-plugins/EnlightenGAN/models/models.py create mode 100755 gimp-plugins/EnlightenGAN/models/networks.py create mode 100755 gimp-plugins/EnlightenGAN/models/single_model.py create mode 100755 gimp-plugins/EnlightenGAN/util/__init__.py create mode 100755 gimp-plugins/EnlightenGAN/util/image_pool.py create mode 100755 gimp-plugins/EnlightenGAN/util/util.py create mode 100755 gimp-plugins/enlighten.py diff --git 
a/gimp-plugins/EnlightenGAN/.idea/.gitignore b/gimp-plugins/EnlightenGAN/.idea/.gitignore new file mode 100755 index 0000000..26d3352 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/gimp-plugins/EnlightenGAN/.idea/EnlightenGAN.iml b/gimp-plugins/EnlightenGAN/.idea/EnlightenGAN.iml new file mode 100755 index 0000000..e7b47d9 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/.idea/EnlightenGAN.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/gimp-plugins/EnlightenGAN/.idea/inspectionProfiles/profiles_settings.xml b/gimp-plugins/EnlightenGAN/.idea/inspectionProfiles/profiles_settings.xml new file mode 100755 index 0000000..105ce2d --- /dev/null +++ b/gimp-plugins/EnlightenGAN/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/gimp-plugins/EnlightenGAN/.idea/misc.xml b/gimp-plugins/EnlightenGAN/.idea/misc.xml new file mode 100755 index 0000000..0aa72d0 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/gimp-plugins/EnlightenGAN/.idea/modules.xml b/gimp-plugins/EnlightenGAN/.idea/modules.xml new file mode 100755 index 0000000..345b285 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/gimp-plugins/EnlightenGAN/.idea/vcs.xml b/gimp-plugins/EnlightenGAN/.idea/vcs.xml new file mode 100755 index 0000000..b2bdec2 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/gimp-plugins/EnlightenGAN/License b/gimp-plugins/EnlightenGAN/License new file mode 100755 index 0000000..2920867 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/License @@ -0,0 +1,58 @@ +Copyright (c) 2019, Yifan Jiang and Zhangyang Wang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +--------------------------- LICENSE FOR EnlightenGAN -------------------------------- +BSD License + +For EnlightenGAN software +Copyright (c) 2019, Yifan Jiang and Zhangyang Wang +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +----------------------------- LICENSE FOR DCGAN -------------------------------- +BSD License + +For dcgan.torch software + +Copyright (c) 2015, Facebook, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/gimp-plugins/EnlightenGAN/data/base_dataset.py b/gimp-plugins/EnlightenGAN/data/base_dataset.py new file mode 100755 index 0000000..9d7acac --- /dev/null +++ b/gimp-plugins/EnlightenGAN/data/base_dataset.py @@ -0,0 +1,50 @@ +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +import random + +class BaseDataset(data.Dataset): + def __init__(self): + super(BaseDataset, self).__init__() + + def name(self): + return 'BaseDataset' + + def initialize(self, opt): + pass + +def get_transform(opt): + transform_list = [] + if opt.resize_or_crop == 'resize_and_crop': + zoom = 1 + 0.1*radom.randint(0,4) + osize = [int(400*zoom), int(600*zoom)] + transform_list.append(transforms.Scale(osize, Image.BICUBIC)) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'crop': + transform_list.append(transforms.RandomCrop(opt.fineSize)) + elif opt.resize_or_crop == 'scale_width': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.fineSize))) + elif opt.resize_or_crop == 'scale_width_and_crop': + transform_list.append(transforms.Lambda( + lambda img: __scale_width(img, opt.loadSize))) + transform_list.append(transforms.RandomCrop(opt.fineSize)) + # elif opt.resize_or_crop == 'no': + # osize = [384, 512] + # transform_list.append(transforms.Scale(osize, Image.BICUBIC)) + + if opt.isTrain and not opt.no_flip: + transform_list.append(transforms.RandomHorizontalFlip()) + + transform_list += [transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), + (0.5, 0.5, 0.5))] + return transforms.Compose(transform_list) + +def __scale_width(img, target_width): + ow, oh = img.size + if (ow == target_width): + return img + w = target_width + h = int(target_width * oh / ow) + return img.resize((w, h), Image.BICUBIC) diff --git a/gimp-plugins/EnlightenGAN/lib/__init__.py b/gimp-plugins/EnlightenGAN/lib/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/gimp-plugins/EnlightenGAN/lib/nn/__init__.py b/gimp-plugins/EnlightenGAN/lib/nn/__init__.py new file mode 100755 index 0000000..98a9637 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/lib/nn/__init__.py @@ -0,0 +1,2 @@ +from .modules import * +from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to diff --git a/gimp-plugins/EnlightenGAN/lib/nn/modules/__init__.py b/gimp-plugins/EnlightenGAN/lib/nn/modules/__init__.py new file mode 100755 index 0000000..bc8709d --- /dev/null +++ b/gimp-plugins/EnlightenGAN/lib/nn/modules/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# File : __init__.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d +from .replicate import DataParallelWithCallback, patch_replication_callback diff --git a/gimp-plugins/EnlightenGAN/lib/nn/modules/batchnorm.py b/gimp-plugins/EnlightenGAN/lib/nn/modules/batchnorm.py new file mode 100755 index 0000000..1831896 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/lib/nn/modules/batchnorm.py @@ -0,0 +1,329 @@ +# -*- coding: utf-8 -*- +# File : batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. 
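# A hedged usage sketch for get_transform() from base_dataset.py above, assuming that module's
# definitions are in scope. The option fields mirror the names the function reads
# (resize_or_crop, fineSize, isTrain, no_flip); the concrete values here are illustrative only.
from argparse import Namespace
from PIL import Image

opt = Namespace(resize_or_crop='scale_width', fineSize=320, isTrain=False, no_flip=True)
transform = get_transform(opt)
img = Image.new('RGB', (640, 480))
out = transform(img)   # tensor of shape (3, 240, 320), width scaled to 320, values normalized to [-1, 1]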
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import collections + +import torch +import torch.nn.functional as F + +from torch.nn.modules.batchnorm import _BatchNorm +from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast + +from .comm import SyncMaster + +__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] + + +def _sum_ft(tensor): + """sum over the first and last dimention""" + return tensor.sum(dim=0).sum(dim=-1) + + +def _unsqueeze_ft(tensor): + """add new dementions at the front and the tail""" + return tensor.unsqueeze(0).unsqueeze(-1) + + +_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) +_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) + + +class _SynchronizedBatchNorm(_BatchNorm): + def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True): + super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) + + self._sync_master = SyncMaster(self._data_parallel_master) + + self._is_parallel = False + self._parallel_id = None + self._slave_pipe = None + + # customed batch norm statistics + self._moving_average_fraction = 1. - momentum + self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features)) + self.register_buffer('_tmp_running_var', torch.ones(self.num_features)) + self.register_buffer('_running_iter', torch.ones(1)) + self._tmp_running_mean = self.running_mean.clone() * self._running_iter + self._tmp_running_var = self.running_var.clone() * self._running_iter + + def forward(self, input): + # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. + if not (self._is_parallel and self.training): + return F.batch_norm( + input, self.running_mean, self.running_var, self.weight, self.bias, + self.training, self.momentum, self.eps) + + # Resize the input to (B, C, -1). + input_shape = input.size() + input = input.view(input.size(0), self.num_features, -1) + + # Compute the sum and square-sum. + sum_size = input.size(0) * input.size(2) + input_sum = _sum_ft(input) + input_ssum = _sum_ft(input ** 2) + + # Reduce-and-broadcast the statistics. + if self._parallel_id == 0: + mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) + else: + mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) + + # Compute the output. + if self.affine: + # MJY:: Fuse the multiplication for speed. + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) + else: + output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) + + # Reshape it. + return output.view(input_shape) + + def __data_parallel_replicate__(self, ctx, copy_id): + self._is_parallel = True + self._parallel_id = copy_id + + # parallel_id == 0 means master device. 
+ if self._parallel_id == 0: + ctx.sync_master = self._sync_master + else: + self._slave_pipe = ctx.sync_master.register_slave(copy_id) + + def _data_parallel_master(self, intermediates): + """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) + + to_reduce = [i[1][:2] for i in intermediates] + to_reduce = [j for i in to_reduce for j in i] # flatten + target_gpus = [i[1].sum.get_device() for i in intermediates] + + sum_size = sum([i[1].sum_size for i in intermediates]) + sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) + + mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) + + broadcasted = Broadcast.apply(target_gpus, mean, inv_std) + + outputs = [] + for i, rec in enumerate(intermediates): + outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) + + return outputs + + def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0): + """return *dest* by `dest := dest*alpha + delta*beta + bias`""" + return dest * alpha + delta * beta + bias + + def _compute_mean_std(self, sum_, ssum, size): + """Compute the mean and standard-deviation with sum and square-sum. This method + also maintains the moving average on the master device.""" + assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' + mean = sum_ / size + sumvar = ssum - sum_ * mean + unbias_var = sumvar / (size - 1) + bias_var = sumvar / size + + self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction) + self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction) + self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction) + + self.running_mean = self._tmp_running_mean / self._running_iter + self.running_var = self._tmp_running_var / self._running_iter + + return mean, bias_var.clamp(self.eps) ** -0.5 + + +class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): + r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a + mini-batch. + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm1d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. 
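# A small numeric check of the sum / square-sum statistics used by _compute_mean_std above.
# This is an illustration only; it mirrors mean = sum / size and unbiased var = (ssum - sum * mean) / (size - 1).
import torch

x = torch.randn(6, 3)
sum_, ssum, size = x.sum(0), (x ** 2).sum(0), x.size(0)
mean = sum_ / size
unbias_var = (ssum - sum_ * mean) / (size - 1)
assert torch.allclose(mean, x.mean(0), atol=1e-6)
assert torch.allclose(unbias_var, x.var(0, unbiased=True), atol=1e-5)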
+ + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + + Args: + num_features: num_features from an expected input of size + `batch_size x num_features [x width]` + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. Default: ``True`` + + Shape: + - Input: :math:`(N, C)` or :math:`(N, C, L)` + - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm1d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 2 and input.dim() != 3: + raise ValueError('expected 2D or 3D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm1d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch + of 3d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm2d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape: + - Input: :math:`(N, C, H, W)` + - Output: :math:`(N, C, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm2d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 4: + raise ValueError('expected 4D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm2d, self)._check_input_dim(input) + + +class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): + r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch + of 4d inputs + + .. math:: + + y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta + + This module differs from the built-in PyTorch BatchNorm3d as the mean and + standard-deviation are reduced across all devices during training. + + For example, when one uses `nn.DataParallel` to wrap the network during + training, PyTorch's implementation normalize the tensor on each device using + the statistics only on that device, which accelerated the computation and + is also easy to implement, but the statistics might be inaccurate. + Instead, in this synchronized version, the statistics will be computed + over all training samples distributed on multiple devices. + + Note that, for one-GPU or CPU-only case, this module behaves exactly same + as the built-in PyTorch implementation. + + The mean and standard-deviation are calculated per-dimension over + the mini-batches and gamma and beta are learnable parameter vectors + of size C (where C is the input size). + + During training, this layer keeps a running estimate of its computed mean + and variance. The running sum is kept with a default momentum of 0.1. + + During evaluation, this running mean/variance is used for normalization. + + Because the BatchNorm is done over the `C` dimension, computing statistics + on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm + or Spatio-temporal BatchNorm + + Args: + num_features: num_features from an expected input of + size batch_size x num_features x depth x height x width + eps: a value added to the denominator for numerical stability. + Default: 1e-5 + momentum: the value used for the running_mean and running_var + computation. Default: 0.1 + affine: a boolean value that when set to ``True``, gives the layer learnable + affine parameters. 
Default: ``True`` + + Shape: + - Input: :math:`(N, C, D, H, W)` + - Output: :math:`(N, C, D, H, W)` (same shape as input) + + Examples: + >>> # With Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100) + >>> # Without Learnable Parameters + >>> m = SynchronizedBatchNorm3d(100, affine=False) + >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) + >>> output = m(input) + """ + + def _check_input_dim(self, input): + if input.dim() != 5: + raise ValueError('expected 5D input (got {}D input)' + .format(input.dim())) + super(SynchronizedBatchNorm3d, self)._check_input_dim(input) diff --git a/gimp-plugins/EnlightenGAN/lib/nn/modules/comm.py b/gimp-plugins/EnlightenGAN/lib/nn/modules/comm.py new file mode 100755 index 0000000..b64bf6b --- /dev/null +++ b/gimp-plugins/EnlightenGAN/lib/nn/modules/comm.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +# File : comm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import queue +import collections +import threading + +__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] + + +class FutureResult(object): + """A thread-safe future implementation. Used only as one-to-one pipe.""" + + def __init__(self): + self._result = None + self._lock = threading.Lock() + self._cond = threading.Condition(self._lock) + + def put(self, result): + with self._lock: + assert self._result is None, 'Previous result has\'t been fetched.' + self._result = result + self._cond.notify() + + def get(self): + with self._lock: + if self._result is None: + self._cond.wait() + + res = self._result + self._result = None + return res + + +_MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) +_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) + + +class SlavePipe(_SlavePipeBase): + """Pipe for master-slave communication.""" + + def run_slave(self, msg): + self.queue.put((self.identifier, msg)) + ret = self.result.get() + self.queue.put(True) + return ret + + +class SyncMaster(object): + """An abstract `SyncMaster` object. + + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should + call `register(id)` and obtain an `SlavePipe` to communicate with the master. + - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, + and passed to a registered callback. + - After receiving the messages, the master device should gather the information and determine to message passed + back to each slave devices. + """ + + def __init__(self, master_callback): + """ + + Args: + master_callback: a callback to be invoked after having collected messages from slave devices. + """ + self._master_callback = master_callback + self._queue = queue.Queue() + self._registry = collections.OrderedDict() + self._activated = False + + def register_slave(self, identifier): + """ + Register an slave device. + + Args: + identifier: an identifier, usually is the device id. + + Returns: a `SlavePipe` object which can be used to communicate with the master device. + + """ + if self._activated: + assert self._queue.empty(), 'Queue is not clean before next initialization.' 
+ self._activated = False + self._registry.clear() + future = FutureResult() + self._registry[identifier] = _MasterRegistry(future) + return SlavePipe(identifier, self._queue, future) + + def run_master(self, master_msg): + """ + Main entry for the master device in each forward pass. + The messages were first collected from each devices (including the master device), and then + an callback will be invoked to compute the message to be sent back to each devices + (including the master device). + + Args: + master_msg: the message that the master want to send to itself. This will be placed as the first + message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. + + Returns: the message to be sent back to the master device. + + """ + self._activated = True + + intermediates = [(0, master_msg)] + for i in range(self.nr_slaves): + intermediates.append(self._queue.get()) + + results = self._master_callback(intermediates) + assert results[0][0] == 0, 'The first result should belongs to the master.' + + for i, res in results: + if i == 0: + continue + self._registry[i].result.put(res) + + for i in range(self.nr_slaves): + assert self._queue.get() is True + + return results[0][1] + + @property + def nr_slaves(self): + return len(self._registry) diff --git a/gimp-plugins/EnlightenGAN/lib/nn/modules/replicate.py b/gimp-plugins/EnlightenGAN/lib/nn/modules/replicate.py new file mode 100755 index 0000000..b71c7b8 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/lib/nn/modules/replicate.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. 
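# A hedged, single-process sketch of the SyncMaster round trip described above. The callback is
# hypothetical; in real use run_master() is called by the master copy during forward() while the
# slave copies block inside SlavePipe.run_slave() on their own threads.
def _echo_callback(intermediates):
    # intermediates is a list of (copy_id, message); return one reply per sender
    return [(copy_id, msg) for copy_id, msg in intermediates]

master = SyncMaster(_echo_callback)
# With no slaves registered, the master simply receives its own message back:
assert master.run_master('ping') == 'ping'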
+ The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. Add the replication callback. + Useful when you have customized `DataParallel` implementation. + + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate diff --git a/gimp-plugins/EnlightenGAN/lib/nn/modules/tests/test_numeric_batchnorm.py b/gimp-plugins/EnlightenGAN/lib/nn/modules/tests/test_numeric_batchnorm.py new file mode 100755 index 0000000..8bd45a9 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/lib/nn/modules/tests/test_numeric_batchnorm.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# File : test_numeric_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. 
+ +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +class NumericTestCase(TorchTestCase): + def testNumericBatchNorm(self): + a = torch.rand(16, 10) + bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) + bn.train() + + a_var1 = Variable(a, requires_grad=True) + b_var1 = bn(a_var1) + loss1 = b_var1.sum() + loss1.backward() + + a_var2 = Variable(a, requires_grad=True) + a_mean2 = a_var2.mean(dim=0, keepdim=True) + a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) + # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) + b_var2 = (a_var2 - a_mean2) / a_std2 + loss2 = b_var2.sum() + loss2.backward() + + self.assertTensorClose(bn.running_mean, a.mean(dim=0)) + self.assertTensorClose(bn.running_var, handy_var(a)) + self.assertTensorClose(a_var1.data, a_var2.data) + self.assertTensorClose(b_var1.data, b_var2.data) + self.assertTensorClose(a_var1.grad, a_var2.grad) + + +if __name__ == '__main__': + unittest.main() diff --git a/gimp-plugins/EnlightenGAN/lib/nn/modules/tests/test_sync_batchnorm.py b/gimp-plugins/EnlightenGAN/lib/nn/modules/tests/test_sync_batchnorm.py new file mode 100755 index 0000000..45bb3c8 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/lib/nn/modules/tests/test_sync_batchnorm.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# File : test_sync_batchnorm.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. 
+ +import unittest + +import torch +import torch.nn as nn +from torch.autograd import Variable + +from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback +from sync_batchnorm.unittest import TorchTestCase + + +def handy_var(a, unbias=True): + n = a.size(0) + asum = a.sum(dim=0) + as_sum = (a ** 2).sum(dim=0) # a square sum + sumvar = as_sum - asum * asum / n + if unbias: + return sumvar / (n - 1) + else: + return sumvar / n + + +def _find_bn(module): + for m in module.modules(): + if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): + return m + + +class SyncTestCase(TorchTestCase): + def _syncParameters(self, bn1, bn2): + bn1.reset_parameters() + bn2.reset_parameters() + if bn1.affine and bn2.affine: + bn2.weight.data.copy_(bn1.weight.data) + bn2.bias.data.copy_(bn1.bias.data) + + def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): + """Check the forward and backward for the customized batch normalization.""" + bn1.train(mode=is_train) + bn2.train(mode=is_train) + + if cuda: + input = input.cuda() + + self._syncParameters(_find_bn(bn1), _find_bn(bn2)) + + input1 = Variable(input, requires_grad=True) + output1 = bn1(input1) + output1.sum().backward() + input2 = Variable(input, requires_grad=True) + output2 = bn2(input2) + output2.sum().backward() + + self.assertTensorClose(input1.data, input2.data) + self.assertTensorClose(output1.data, output2.data) + self.assertTensorClose(input1.grad, input2.grad) + self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) + self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) + + def testSyncBatchNormNormalTrain(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) + + def testSyncBatchNormNormalEval(self): + bn = nn.BatchNorm1d(10) + sync_bn = SynchronizedBatchNorm1d(10) + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) + + def testSyncBatchNormSyncTrain(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) + + def testSyncBatchNormSyncEval(self): + bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) + + def testSyncBatchNorm2DSyncTrain(self): + bn = nn.BatchNorm2d(10) + sync_bn = SynchronizedBatchNorm2d(10) + sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + bn.cuda() + sync_bn.cuda() + + self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) + + +if __name__ == '__main__': + unittest.main() diff --git a/gimp-plugins/EnlightenGAN/lib/nn/modules/unittest.py b/gimp-plugins/EnlightenGAN/lib/nn/modules/unittest.py new file mode 100755 index 0000000..0675c02 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/lib/nn/modules/unittest.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# File : unittest.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. 
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import unittest + +import numpy as np +from torch.autograd import Variable + + +def as_numpy(v): + if isinstance(v, Variable): + v = v.data + return v.cpu().numpy() + + +class TorchTestCase(unittest.TestCase): + def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): + npa, npb = as_numpy(a), as_numpy(b) + self.assertTrue( + np.allclose(npa, npb, atol=atol), + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + ) diff --git a/gimp-plugins/EnlightenGAN/lib/nn/parallel/__init__.py b/gimp-plugins/EnlightenGAN/lib/nn/parallel/__init__.py new file mode 100755 index 0000000..9b52f49 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/lib/nn/parallel/__init__.py @@ -0,0 +1 @@ +from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to diff --git a/gimp-plugins/EnlightenGAN/lib/nn/parallel/data_parallel.py b/gimp-plugins/EnlightenGAN/lib/nn/parallel/data_parallel.py new file mode 100755 index 0000000..376fc03 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/lib/nn/parallel/data_parallel.py @@ -0,0 +1,112 @@ +# -*- coding: utf8 -*- + +import torch.cuda as cuda +import torch.nn as nn +import torch +import collections +from torch.nn.parallel._functions import Gather + + +__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] + + +def async_copy_to(obj, dev, main_stream=None): + if torch.is_tensor(obj): + v = obj.cuda(dev, non_blocking=True) + if main_stream is not None: + v.data.record_stream(main_stream) + return v + elif isinstance(obj, collections.Mapping): + return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} + elif isinstance(obj, collections.Sequence): + return [async_copy_to(o, dev, main_stream) for o in obj] + else: + return obj + + +def dict_gather(outputs, target_device, dim=0): + """ + Gathers variables from different GPUs on a specified device + (-1 means the CPU), with dictionary support. 
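# Container-handling sketch for async_copy_to() above (illustrative only; a CUDA device is needed
# for the actual copy, so the call is guarded). Dicts and sequences are traversed recursively and
# tensors are moved with non_blocking=True, while other leaf values pass through unchanged.
import torch

sample = {'img': torch.zeros(1, 3, 8, 8), 'meta': [0, 1]}
if torch.cuda.is_available():
    on_gpu = async_copy_to(sample, 0)
    assert on_gpu['img'].is_cuda and on_gpu['meta'] == [0, 1]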
+ """ + def gather_map(outputs): + out = outputs[0] + if torch.is_tensor(out): + # MJY(20180330) HACK:: force nr_dims > 0 + if out.dim() == 0: + outputs = [o.unsqueeze(0) for o in outputs] + return Gather.apply(target_device, dim, *outputs) + elif out is None: + return None + elif isinstance(out, collections.Mapping): + return {k: gather_map([o[k] for o in outputs]) for k in out} + elif isinstance(out, collections.Sequence): + return type(out)(map(gather_map, zip(*outputs))) + return gather_map(outputs) + + +class DictGatherDataParallel(nn.DataParallel): + def gather(self, outputs, output_device): + return dict_gather(outputs, output_device, dim=self.dim) + + +class UserScatteredDataParallel(DictGatherDataParallel): + def scatter(self, inputs, kwargs, device_ids): + assert len(inputs) == 1 + inputs = inputs[0] + inputs = _async_copy_stream(inputs, device_ids) + inputs = [[i] for i in inputs] + assert len(kwargs) == 0 + kwargs = [{} for _ in range(len(inputs))] + + return inputs, kwargs + + +def user_scattered_collate(batch): + return batch + + +def _async_copy(inputs, device_ids): + nr_devs = len(device_ids) + assert type(inputs) in (tuple, list) + assert len(inputs) == nr_devs + + outputs = [] + for i, dev in zip(inputs, device_ids): + with cuda.device(dev): + outputs.append(async_copy_to(i, dev)) + + return tuple(outputs) + + +def _async_copy_stream(inputs, device_ids): + nr_devs = len(device_ids) + assert type(inputs) in (tuple, list) + assert len(inputs) == nr_devs + + outputs = [] + streams = [_get_stream(d) for d in device_ids] + for i, dev, stream in zip(inputs, device_ids, streams): + with cuda.device(dev): + main_stream = cuda.current_stream() + with cuda.stream(stream): + outputs.append(async_copy_to(i, dev, main_stream=main_stream)) + main_stream.wait_stream(stream) + + return outputs + + +"""Adapted from: torch/nn/parallel/_functions.py""" +# background streams used for copying +_streams = None + + +def _get_stream(device): + """Gets a background stream for copying between CPU and GPU""" + global _streams + if device == -1: + return None + if _streams is None: + _streams = [None] * cuda.device_count() + if _streams[device] is None: _streams[device] = cuda.Stream(device) + return _streams[device] diff --git a/gimp-plugins/EnlightenGAN/models/__init__.py b/gimp-plugins/EnlightenGAN/models/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/gimp-plugins/EnlightenGAN/models/base_model.py b/gimp-plugins/EnlightenGAN/models/base_model.py new file mode 100755 index 0000000..9729e17 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/models/base_model.py @@ -0,0 +1,61 @@ +import os +import torch + + +class BaseModel(): + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor if (torch.cuda.is_available() and not opt.cFlag)else torch.Tensor + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # used in test time, no backprop + def test(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, gpu_ids): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + 
save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if len(gpu_ids) and torch.cuda.is_available(): + network.cuda(device=gpu_ids[0]) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + if torch.cuda.is_available(): + ckpt = torch.load(save_path) + else: + ckpt = torch.load(save_path, map_location=torch.device("cpu")) + ckpt = {key.replace("module.", ""): value for key, value in ckpt.items()} + network.load_state_dict(ckpt) + + def update_learning_rate(): + pass diff --git a/gimp-plugins/EnlightenGAN/models/models.py b/gimp-plugins/EnlightenGAN/models/models.py new file mode 100755 index 0000000..8365f4b --- /dev/null +++ b/gimp-plugins/EnlightenGAN/models/models.py @@ -0,0 +1,38 @@ + +def create_model(opt): + model = None + print(opt.model) + if opt.model == 'cycle_gan': + assert(opt.dataset_mode == 'unaligned') + from .cycle_gan_model import CycleGANModel + model = CycleGANModel() + elif opt.model == 'pix2pix': + assert(opt.dataset_mode == 'pix2pix') + from .pix2pix_model import Pix2PixModel + model = Pix2PixModel() + elif opt.model == 'pair': + # assert(opt.dataset_mode == 'pair') + # from .pair_model import PairModel + from .Unet_L1 import PairModel + model = PairModel() + elif opt.model == 'single': + # assert(opt.dataset_mode == 'unaligned') + from .single_model import SingleModel + model = SingleModel() + elif opt.model == 'temp': + # assert(opt.dataset_mode == 'unaligned') + from .temp_model import TempModel + model = TempModel() + elif opt.model == 'UNIT': + assert(opt.dataset_mode == 'unaligned') + from .unit_model import UNITModel + model = UNITModel() + elif opt.model == 'test': + assert(opt.dataset_mode == 'single') + from .test_model import TestModel + model = TestModel() + else: + raise ValueError("Model [%s] not recognized." 
% opt.model) + model.initialize(opt) + print("model [%s] was created" % (model.name())) + return model diff --git a/gimp-plugins/EnlightenGAN/models/networks.py b/gimp-plugins/EnlightenGAN/models/networks.py new file mode 100755 index 0000000..7c88961 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/models/networks.py @@ -0,0 +1,1181 @@ +import torch +import os +import math +import torch.nn as nn +from torch.nn import init +import functools +from torch.autograd import Variable +import torch.nn.functional as F +import numpy as np +# from torch.utils.serialization import load_lua +from lib.nn import SynchronizedBatchNorm2d as SynBN2d +############################################################################### +# Functions +############################################################################### + +def pad_tensor(input): + + height_org, width_org = input.shape[2], input.shape[3] + divide = 16 + + if width_org % divide != 0 or height_org % divide != 0: + + width_res = width_org % divide + height_res = height_org % divide + if width_res != 0: + width_div = divide - width_res + pad_left = int(width_div / 2) + pad_right = int(width_div - pad_left) + else: + pad_left = 0 + pad_right = 0 + + if height_res != 0: + height_div = divide - height_res + pad_top = int(height_div / 2) + pad_bottom = int(height_div - pad_top) + else: + pad_top = 0 + pad_bottom = 0 + + padding = nn.ReflectionPad2d((pad_left, pad_right, pad_top, pad_bottom)) + input = padding(input) + else: + pad_left = 0 + pad_right = 0 + pad_top = 0 + pad_bottom = 0 + + height, width = input.data.shape[2], input.data.shape[3] + assert width % divide == 0, 'width cant divided by stride' + assert height % divide == 0, 'height cant divided by stride' + + return input, pad_left, pad_right, pad_top, pad_bottom + +def pad_tensor_back(input, pad_left, pad_right, pad_top, pad_bottom): + height, width = input.shape[2], input.shape[3] + return input[:,:, pad_top: height - pad_bottom, pad_left: width - pad_right] + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def get_norm_layer(norm_type='instance'): + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'synBN': + norm_layer = functools.partial(SynBN2d, affine=True) + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm) + return norm_layer + + +def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[], skip=False, opt=None): + netG = None + use_gpu = len(gpu_ids) > 0 + norm_layer = get_norm_layer(norm_type=norm) + + # if use_gpu: + # assert(torch.cuda.is_available()) + + if which_model_netG == 'resnet_9blocks': + netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids) + elif which_model_netG == 'resnet_6blocks': + netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids) + elif which_model_netG == 'unet_128': + netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids) + elif which_model_netG == 'unet_256': + netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, 
use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt) + elif which_model_netG == 'unet_512': + netG = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids, skip=skip, opt=opt) + elif which_model_netG == 'sid_unet': + netG = Unet(opt, skip) + elif which_model_netG == 'sid_unet_shuffle': + netG = Unet_pixelshuffle(opt, skip) + elif which_model_netG == 'sid_unet_resize': + netG = Unet_resize_conv(opt, skip) + elif which_model_netG == 'DnCNN': + netG = DnCNN(opt, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3) + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) + if torch.cuda.is_available() and not opt.cFlag: + netG.cuda(device=gpu_ids[0]) + # netG = torch.nn.DataParallel(netG, gpu_ids) + netG.apply(weights_init) + return netG + + +def define_D(input_nc, ndf, which_model_netD, + n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[], patch=False): + netD = None + use_gpu = len(gpu_ids) > 0 + norm_layer = get_norm_layer(norm_type=norm) + + if use_gpu: + assert(torch.cuda.is_available()) + if which_model_netD == 'basic': + netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'n_layers': + netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'no_norm': + netD = NoNormDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'no_norm_4': + netD = NoNormDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'no_patchgan': + netD = FCDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids, patch=patch) + else: + raise NotImplementedError('Discriminator model name [%s] is not recognized' % + which_model_netD) + if use_gpu: + netD.cuda(device=gpu_ids[0]) + netD = torch.nn.DataParallel(netD, gpu_ids) + netD.apply(weights_init) + return netD + + +def print_network(net): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print('Total number of parameters: %d' % num_params) + + +############################################################################## +# Classes +############################################################################## + + +# Defines the GAN loss which uses either LSGAN or the regular GAN. 
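# Quick illustration of print_network() above: it works for any nn.Module and reports the total
# parameter count. The tiny model below is a throwaway example, not part of EnlightenGAN.
import torch.nn as nn

tiny = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))
print_network(tiny)   # prints the module repr and "Total number of parameters: 443"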
+# When LSGAN is used, it is basically same as MSELoss, +# but it abstracts away the need to create the target label tensor +# that has the same size as the input +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_var = None + self.fake_label_var = None + self.Tensor = tensor + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + target_tensor = None + if target_is_real: + create_label = ((self.real_label_var is None) or + (self.real_label_var.numel() != input.numel())) + if create_label: + real_tensor = self.Tensor(input.size()).fill_(self.real_label) + self.real_label_var = Variable(real_tensor, requires_grad=False) + target_tensor = self.real_label_var + else: + create_label = ((self.fake_label_var is None) or + (self.fake_label_var.numel() != input.numel())) + if create_label: + fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) + self.fake_label_var = Variable(fake_tensor, requires_grad=False) + target_tensor = self.fake_label_var + return target_tensor + + def __call__(self, input, target_is_real): + target_tensor = self.get_target_tensor(input, target_is_real) + return self.loss(input, target_tensor) + + + +class DiscLossWGANGP(): + def __init__(self): + self.LAMBDA = 10 + + def name(self): + return 'DiscLossWGAN-GP' + + def initialize(self, opt, tensor): + # DiscLossLS.initialize(self, opt, tensor) + self.LAMBDA = 10 + + # def get_g_loss(self, net, realA, fakeB): + # # First, G(A) should fake the discriminator + # self.D_fake = net.forward(fakeB) + # return -self.D_fake.mean() + + def calc_gradient_penalty(self, netD, real_data, fake_data): + alpha = torch.rand(1, 1) + alpha = alpha.expand(real_data.size()) + alpha = alpha.cuda() + + interpolates = alpha * real_data + ((1 - alpha) * fake_data) + + interpolates = interpolates.cuda() + interpolates = Variable(interpolates, requires_grad=True) + + disc_interpolates = netD.forward(interpolates) + + gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size()).cuda(), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA + return gradient_penalty + +# Defines the generator that consists of Resnet blocks between a few +# downsampling/upsampling operations. +# Code and idea originally from Justin Johnson's architecture. 
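# Minimal GANLoss usage sketch (LSGAN branch, CPU tensors). The prediction tensor stands in for a
# PatchGAN discriminator output; shapes and values are illustrative only.
import torch

criterion = GANLoss(use_lsgan=True)
pred_fake = torch.rand(4, 1, 30, 30)       # hypothetical D(G(x)) scores
loss_g = criterion(pred_fake, True)        # generator: push fake predictions towards the real label (1.0)
loss_d_fake = criterion(pred_fake, False)  # discriminator: push fake predictions towards 0.0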
+# https://github.com/jcjohnson/fast-neural-style/ +class ResnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[], padding_type='reflect'): + assert(n_blocks >= 0) + super(ResnetGenerator, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.ngf = ngf + self.gpu_ids = gpu_ids + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): + mult = 2**i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, + stride=2, padding=1), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2**n_downsampling + for i in range(n_blocks): + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout)] + + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + model += [nn.Tanh()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): + return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + else: + return self.model(input) + + +# Define a resnet block +class ResnetBlock(nn.Module): + def __init__(self, dim, padding_type, norm_layer, use_dropout): + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout): + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim), + nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +# Defines the Unet generator. +# |num_downs|: number of downsamplings in UNet. 
For example, +# if |num_downs| == 7, image of size 128x128 will become of size 1x1 +# at the bottleneck +class UnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, num_downs, ngf=64, + norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[], skip=False, opt=None): + super(UnetGenerator, self).__init__() + self.gpu_ids = gpu_ids + self.opt = opt + # currently support only input_nc == output_nc + assert(input_nc == output_nc) + + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True, opt=opt) + for i in range(num_downs - 5): + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout, opt=opt) + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer, opt=opt) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer, opt=opt) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer, opt=opt) + unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer, opt=opt) + + if skip == True: + skipmodule = SkipModule(unet_block, opt) + self.model = skipmodule + else: + self.model = unet_block + + def forward(self, input): + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): + return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + else: + return self.model(input) + +class SkipModule(nn.Module): + def __init__(self, submodule, opt): + super(SkipModule, self).__init__() + self.submodule = submodule + self.opt = opt + + def forward(self, x): + latent = self.submodule(x) + return self.opt.skip*x + latent, latent + + + +# Defines the submodule with skip connection. 
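+# It is nested recursively by UnetGenerator, starting from the innermost bottleneck block.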
+# X -------------------identity---------------------- X +# |-- downsampling -- |submodule| -- upsampling --| +class UnetSkipConnectionBlock(nn.Module): + def __init__(self, outer_nc, inner_nc, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False, opt=None): + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + + downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, + stride=2, padding=1) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if opt.use_norm == 0: + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downrelu, downconv] + up = [uprelu, upconv] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downrelu, downconv] + up = [uprelu, upconv] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + else: + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: + return torch.cat([self.model(x), x], 1) + + +# Defines the PatchGAN discriminator with the specified arguments. 
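+# Each value in its output map classifies one receptive-field patch of the input as real or fake.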
+class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]): + super(NLayerDiscriminator, self).__init__() + self.gpu_ids = gpu_ids + + kw = 4 + padw = int(np.ceil((kw-1)/2)) + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=1, padding=padw), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] + + if use_sigmoid: + sequence += [nn.Sigmoid()] + + self.model = nn.Sequential(*sequence) + + def forward(self, input): + # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): + # return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + # else: + return self.model(input) + +class NoNormDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]): + super(NoNormDiscriminator, self).__init__() + self.gpu_ids = gpu_ids + + kw = 4 + padw = int(np.ceil((kw-1)/2)) + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=1, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] + + if use_sigmoid: + sequence += [nn.Sigmoid()] + + self.model = nn.Sequential(*sequence) + + def forward(self, input): + # if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): + # return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + # else: + return self.model(input) + +class FCDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[], patch=False): + super(FCDiscriminator, self).__init__() + self.gpu_ids = gpu_ids + self.use_sigmoid = use_sigmoid + kw = 4 + padw = int(np.ceil((kw-1)/2)) + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=1, padding=padw), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] + if patch: + self.linear = nn.Linear(7*7,1) + else: + self.linear = nn.Linear(13*13,1) + if use_sigmoid: + self.sigmoid = nn.Sigmoid() + + self.model = 
nn.Sequential(*sequence) + + def forward(self, input): + batchsize = input.size()[0] + output = self.model(input) + output = output.view(batchsize,-1) + # print(output.size()) + output = self.linear(output) + if self.use_sigmoid: + print("sigmoid") + output = self.sigmoid(output) + return output + + +class Unet_resize_conv(nn.Module): + def __init__(self, opt, skip): + super(Unet_resize_conv, self).__init__() + + self.opt = opt + self.skip = skip + p = 1 + # self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p) + if opt.self_attention: + self.conv1_1 = nn.Conv2d(4, 32, 3, padding=p) + # self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p) + self.downsample_1 = nn.MaxPool2d(2) + self.downsample_2 = nn.MaxPool2d(2) + self.downsample_3 = nn.MaxPool2d(2) + self.downsample_4 = nn.MaxPool2d(2) + else: + self.conv1_1 = nn.Conv2d(3, 32, 3, padding=p) + self.LReLU1_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn1_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32) + self.conv1_2 = nn.Conv2d(32, 32, 3, padding=p) + self.LReLU1_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn1_2 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32) + self.max_pool1 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2) + + self.conv2_1 = nn.Conv2d(32, 64, 3, padding=p) + self.LReLU2_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn2_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64) + self.conv2_2 = nn.Conv2d(64, 64, 3, padding=p) + self.LReLU2_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn2_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64) + self.max_pool2 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2) + + self.conv3_1 = nn.Conv2d(64, 128, 3, padding=p) + self.LReLU3_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn3_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128) + self.conv3_2 = nn.Conv2d(128, 128, 3, padding=p) + self.LReLU3_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn3_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128) + self.max_pool3 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2) + + self.conv4_1 = nn.Conv2d(128, 256, 3, padding=p) + self.LReLU4_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn4_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256) + self.conv4_2 = nn.Conv2d(256, 256, 3, padding=p) + self.LReLU4_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn4_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256) + self.max_pool4 = nn.AvgPool2d(2) if self.opt.use_avgpool == 1 else nn.MaxPool2d(2) + + self.conv5_1 = nn.Conv2d(256, 512, 3, padding=p) + self.LReLU5_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn5_1 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512) + self.conv5_2 = nn.Conv2d(512, 512, 3, padding=p) + self.LReLU5_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn5_2 = SynBN2d(512) if self.opt.syn_norm else nn.BatchNorm2d(512) + + # self.deconv5 = nn.ConvTranspose2d(512, 256, 2, stride=2) + self.deconv5 = nn.Conv2d(512, 256, 3, padding=p) + self.conv6_1 = nn.Conv2d(512, 256, 3, padding=p) + self.LReLU6_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn6_1 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256) + self.conv6_2 = nn.Conv2d(256, 256, 3, padding=p) + self.LReLU6_2 = 
nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn6_2 = SynBN2d(256) if self.opt.syn_norm else nn.BatchNorm2d(256) + + # self.deconv6 = nn.ConvTranspose2d(256, 128, 2, stride=2) + self.deconv6 = nn.Conv2d(256, 128, 3, padding=p) + self.conv7_1 = nn.Conv2d(256, 128, 3, padding=p) + self.LReLU7_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn7_1 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128) + self.conv7_2 = nn.Conv2d(128, 128, 3, padding=p) + self.LReLU7_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn7_2 = SynBN2d(128) if self.opt.syn_norm else nn.BatchNorm2d(128) + + # self.deconv7 = nn.ConvTranspose2d(128, 64, 2, stride=2) + self.deconv7 = nn.Conv2d(128, 64, 3, padding=p) + self.conv8_1 = nn.Conv2d(128, 64, 3, padding=p) + self.LReLU8_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn8_1 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64) + self.conv8_2 = nn.Conv2d(64, 64, 3, padding=p) + self.LReLU8_2 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn8_2 = SynBN2d(64) if self.opt.syn_norm else nn.BatchNorm2d(64) + + # self.deconv8 = nn.ConvTranspose2d(64, 32, 2, stride=2) + self.deconv8 = nn.Conv2d(64, 32, 3, padding=p) + self.conv9_1 = nn.Conv2d(64, 32, 3, padding=p) + self.LReLU9_1 = nn.LeakyReLU(0.2, inplace=True) + if self.opt.use_norm == 1: + self.bn9_1 = SynBN2d(32) if self.opt.syn_norm else nn.BatchNorm2d(32) + self.conv9_2 = nn.Conv2d(32, 32, 3, padding=p) + self.LReLU9_2 = nn.LeakyReLU(0.2, inplace=True) + + self.conv10 = nn.Conv2d(32, 3, 1) + if self.opt.tanh: + self.tanh = nn.Tanh() + + def depth_to_space(self, input, block_size): + block_size_sq = block_size*block_size + output = input.permute(0, 2, 3, 1) + (batch_size, d_height, d_width, d_depth) = output.size() + s_depth = int(d_depth / block_size_sq) + s_width = int(d_width * block_size) + s_height = int(d_height * block_size) + t_1 = output.resize(batch_size, d_height, d_width, block_size_sq, s_depth) + spl = t_1.split(block_size, 3) + stack = [t_t.resize(batch_size, d_height, s_width, s_depth) for t_t in spl] + output = torch.stack(stack,0).transpose(0,1).permute(0,2,1,3,4).resize(batch_size, s_height, s_width, s_depth) + output = output.permute(0, 3, 1, 2) + return output + + def forward(self, input, gray): + flag = 0 + if input.size()[3] > 2200: + avg = nn.AvgPool2d(2) + input = avg(input) + gray = avg(gray) + flag = 1 + # pass + input, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(input) + gray, pad_left, pad_right, pad_top, pad_bottom = pad_tensor(gray) + if self.opt.self_attention: + gray_2 = self.downsample_1(gray) + gray_3 = self.downsample_2(gray_2) + gray_4 = self.downsample_3(gray_3) + gray_5 = self.downsample_4(gray_4) + if self.opt.use_norm == 1: + if self.opt.self_attention: + x = self.bn1_1(self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1)))) + # x = self.bn1_1(self.LReLU1_1(self.conv1_1(input))) + else: + x = self.bn1_1(self.LReLU1_1(self.conv1_1(input))) + conv1 = self.bn1_2(self.LReLU1_2(self.conv1_2(x))) + x = self.max_pool1(conv1) + + x = self.bn2_1(self.LReLU2_1(self.conv2_1(x))) + conv2 = self.bn2_2(self.LReLU2_2(self.conv2_2(x))) + x = self.max_pool2(conv2) + + x = self.bn3_1(self.LReLU3_1(self.conv3_1(x))) + conv3 = self.bn3_2(self.LReLU3_2(self.conv3_2(x))) + x = self.max_pool3(conv3) + + x = self.bn4_1(self.LReLU4_1(self.conv4_1(x))) + conv4 = self.bn4_2(self.LReLU4_2(self.conv4_2(x))) + x = self.max_pool4(conv4) + + x = 
self.bn5_1(self.LReLU5_1(self.conv5_1(x))) + x = x*gray_5 if self.opt.self_attention else x + conv5 = self.bn5_2(self.LReLU5_2(self.conv5_2(x))) + + conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear') + conv4 = conv4*gray_4 if self.opt.self_attention else conv4 + up6 = torch.cat([self.deconv5(conv5), conv4], 1) + x = self.bn6_1(self.LReLU6_1(self.conv6_1(up6))) + conv6 = self.bn6_2(self.LReLU6_2(self.conv6_2(x))) + + conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear') + conv3 = conv3*gray_3 if self.opt.self_attention else conv3 + up7 = torch.cat([self.deconv6(conv6), conv3], 1) + x = self.bn7_1(self.LReLU7_1(self.conv7_1(up7))) + conv7 = self.bn7_2(self.LReLU7_2(self.conv7_2(x))) + + conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear') + conv2 = conv2*gray_2 if self.opt.self_attention else conv2 + up8 = torch.cat([self.deconv7(conv7), conv2], 1) + x = self.bn8_1(self.LReLU8_1(self.conv8_1(up8))) + conv8 = self.bn8_2(self.LReLU8_2(self.conv8_2(x))) + + conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear') + conv1 = conv1*gray if self.opt.self_attention else conv1 + up9 = torch.cat([self.deconv8(conv8), conv1], 1) + x = self.bn9_1(self.LReLU9_1(self.conv9_1(up9))) + conv9 = self.LReLU9_2(self.conv9_2(x)) + + latent = self.conv10(conv9) + + if self.opt.times_residual: + latent = latent*gray + + # output = self.depth_to_space(conv10, 2) + if self.opt.tanh: + latent = self.tanh(latent) + if self.skip: + if self.opt.linear_add: + if self.opt.latent_threshold: + latent = F.relu(latent) + elif self.opt.latent_norm: + latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent)) + input = (input - torch.min(input))/(torch.max(input) - torch.min(input)) + output = latent + input*self.opt.skip + output = output*2 - 1 + else: + if self.opt.latent_threshold: + latent = F.relu(latent) + elif self.opt.latent_norm: + latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent)) + output = latent + input*self.opt.skip + else: + output = latent + + if self.opt.linear: + output = output/torch.max(torch.abs(output)) + + + elif self.opt.use_norm == 0: + if self.opt.self_attention: + x = self.LReLU1_1(self.conv1_1(torch.cat((input, gray), 1))) + else: + x = self.LReLU1_1(self.conv1_1(input)) + conv1 = self.LReLU1_2(self.conv1_2(x)) + x = self.max_pool1(conv1) + + x = self.LReLU2_1(self.conv2_1(x)) + conv2 = self.LReLU2_2(self.conv2_2(x)) + x = self.max_pool2(conv2) + + x = self.LReLU3_1(self.conv3_1(x)) + conv3 = self.LReLU3_2(self.conv3_2(x)) + x = self.max_pool3(conv3) + + x = self.LReLU4_1(self.conv4_1(x)) + conv4 = self.LReLU4_2(self.conv4_2(x)) + x = self.max_pool4(conv4) + + x = self.LReLU5_1(self.conv5_1(x)) + x = x*gray_5 if self.opt.self_attention else x + conv5 = self.LReLU5_2(self.conv5_2(x)) + + conv5 = F.upsample(conv5, scale_factor=2, mode='bilinear') + conv4 = conv4*gray_4 if self.opt.self_attention else conv4 + up6 = torch.cat([self.deconv5(conv5), conv4], 1) + x = self.LReLU6_1(self.conv6_1(up6)) + conv6 = self.LReLU6_2(self.conv6_2(x)) + + conv6 = F.upsample(conv6, scale_factor=2, mode='bilinear') + conv3 = conv3*gray_3 if self.opt.self_attention else conv3 + up7 = torch.cat([self.deconv6(conv6), conv3], 1) + x = self.LReLU7_1(self.conv7_1(up7)) + conv7 = self.LReLU7_2(self.conv7_2(x)) + + conv7 = F.upsample(conv7, scale_factor=2, mode='bilinear') + conv2 = conv2*gray_2 if self.opt.self_attention else conv2 + up8 = torch.cat([self.deconv7(conv7), conv2], 1) + x = self.LReLU8_1(self.conv8_1(up8)) + conv8 = self.LReLU8_2(self.conv8_2(x)) + + 
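+ # final decoder stage: upsample conv8 back to the input resolution and fuse it with the (attention-weighted) conv1 features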
conv8 = F.upsample(conv8, scale_factor=2, mode='bilinear') + conv1 = conv1*gray if self.opt.self_attention else conv1 + up9 = torch.cat([self.deconv8(conv8), conv1], 1) + x = self.LReLU9_1(self.conv9_1(up9)) + conv9 = self.LReLU9_2(self.conv9_2(x)) + + latent = self.conv10(conv9) + + if self.opt.times_residual: + latent = latent*gray + + if self.opt.tanh: + latent = self.tanh(latent) + if self.skip: + if self.opt.linear_add: + if self.opt.latent_threshold: + latent = F.relu(latent) + elif self.opt.latent_norm: + latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent)) + input = (input - torch.min(input))/(torch.max(input) - torch.min(input)) + output = latent + input*self.opt.skip + output = output*2 - 1 + else: + if self.opt.latent_threshold: + latent = F.relu(latent) + elif self.opt.latent_norm: + latent = (latent - torch.min(latent))/(torch.max(latent)-torch.min(latent)) + output = latent + input*self.opt.skip + else: + output = latent + + if self.opt.linear: + output = output/torch.max(torch.abs(output)) + + output = pad_tensor_back(output, pad_left, pad_right, pad_top, pad_bottom) + latent = pad_tensor_back(latent, pad_left, pad_right, pad_top, pad_bottom) + gray = pad_tensor_back(gray, pad_left, pad_right, pad_top, pad_bottom) + if flag == 1: + output = F.upsample(output, scale_factor=2, mode='bilinear') + gray = F.upsample(gray, scale_factor=2, mode='bilinear') + if self.skip: + return output, latent + else: + return output + +class DnCNN(nn.Module): + def __init__(self, opt=None, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3): + super(DnCNN, self).__init__() + kernel_size = 3 + padding = 1 + layers = [] + + layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True)) + layers.append(nn.ReLU(inplace=True)) + for _ in range(depth-2): + layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False)) + layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95)) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False)) + self.dncnn = nn.Sequential(*layers) + self._initialize_weights() + + def forward(self, x): + y = x + out = self.dncnn(x) + return y+out + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.orthogonal_(m.weight) + print('init weight') + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + +class Vgg16(nn.Module): + def __init__(self): + super(Vgg16, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, 
kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + def forward(self, X, opt): + h = F.relu(self.conv1_1(X), inplace=True) + h = F.relu(self.conv1_2(h), inplace=True) + # relu1_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv2_1(h), inplace=True) + h = F.relu(self.conv2_2(h), inplace=True) + # relu2_2 = h + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv3_1(h), inplace=True) + h = F.relu(self.conv3_2(h), inplace=True) + h = F.relu(self.conv3_3(h), inplace=True) + # relu3_3 = h + if opt.vgg_choose != "no_maxpool": + h = F.max_pool2d(h, kernel_size=2, stride=2) + + h = F.relu(self.conv4_1(h), inplace=True) + relu4_1 = h + h = F.relu(self.conv4_2(h), inplace=True) + relu4_2 = h + conv4_3 = self.conv4_3(h) + h = F.relu(conv4_3, inplace=True) + relu4_3 = h + + if opt.vgg_choose != "no_maxpool": + if opt.vgg_maxpooling: + h = F.max_pool2d(h, kernel_size=2, stride=2) + + relu5_1 = F.relu(self.conv5_1(h), inplace=True) + relu5_2 = F.relu(self.conv5_2(relu5_1), inplace=True) + conv5_3 = self.conv5_3(relu5_2) + h = F.relu(conv5_3, inplace=True) + relu5_3 = h + if opt.vgg_choose == "conv4_3": + return conv4_3 + elif opt.vgg_choose == "relu4_2": + return relu4_2 + elif opt.vgg_choose == "relu4_1": + return relu4_1 + elif opt.vgg_choose == "relu4_3": + return relu4_3 + elif opt.vgg_choose == "conv5_3": + return conv5_3 + elif opt.vgg_choose == "relu5_1": + return relu5_1 + elif opt.vgg_choose == "relu5_2": + return relu5_2 + elif opt.vgg_choose == "relu5_3" or "maxpool": + return relu5_3 + +def vgg_preprocess(batch, opt): + tensortype = type(batch.data) + (r, g, b) = torch.chunk(batch, 3, dim = 1) + batch = torch.cat((b, g, r), dim = 1) # convert RGB to BGR + batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255] + if opt.vgg_mean: + mean = tensortype(batch.data.size()) + mean[:, 0, :, :] = 103.939 + mean[:, 1, :, :] = 116.779 + mean[:, 2, :, :] = 123.680 + batch = batch.sub(Variable(mean)) # subtract mean + return batch + +class PerceptualLoss(nn.Module): + def __init__(self, opt): + super(PerceptualLoss, self).__init__() + self.opt = opt + self.instancenorm = nn.InstanceNorm2d(512, affine=False) + + def compute_vgg_loss(self, vgg, img, target): + img_vgg = vgg_preprocess(img, self.opt) + target_vgg = vgg_preprocess(target, self.opt) + img_fea = vgg(img_vgg, self.opt) + target_fea = vgg(target_vgg, self.opt) + if self.opt.no_vgg_instance: + return torch.mean((img_fea - target_fea) ** 2) + else: + return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) + +def load_vgg16(model_dir, gpu_ids): + """ Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """ + if not os.path.exists(model_dir): + os.mkdir(model_dir) + # if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')): + # if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')): + # os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7')) + # vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7')) + # vgg = Vgg16() + # for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()): + # dst.data[:] = src + # torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight')) + vgg = Vgg16() + # vgg.cuda() + vgg.cuda(device=gpu_ids[0]) + vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight'))) + vgg = 
torch.nn.DataParallel(vgg, gpu_ids) + return vgg + + + +class FCN32s(nn.Module): + def __init__(self, n_class=21): + super(FCN32s, self).__init__() + # conv1 + self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100) + self.relu1_1 = nn.ReLU(inplace=True) + self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) + self.relu1_2 = nn.ReLU(inplace=True) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 + + # conv2 + self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) + self.relu2_1 = nn.ReLU(inplace=True) + self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) + self.relu2_2 = nn.ReLU(inplace=True) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 + + # conv3 + self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) + self.relu3_1 = nn.ReLU(inplace=True) + self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) + self.relu3_2 = nn.ReLU(inplace=True) + self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) + self.relu3_3 = nn.ReLU(inplace=True) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 + + # conv4 + self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) + self.relu4_1 = nn.ReLU(inplace=True) + self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) + self.relu4_2 = nn.ReLU(inplace=True) + self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) + self.relu4_3 = nn.ReLU(inplace=True) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 + + # conv5 + self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) + self.relu5_1 = nn.ReLU(inplace=True) + self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) + self.relu5_2 = nn.ReLU(inplace=True) + self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) + self.relu5_3 = nn.ReLU(inplace=True) + self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/32 + + # fc6 + self.fc6 = nn.Conv2d(512, 4096, 7) + self.relu6 = nn.ReLU(inplace=True) + self.drop6 = nn.Dropout2d() + + # fc7 + self.fc7 = nn.Conv2d(4096, 4096, 1) + self.relu7 = nn.ReLU(inplace=True) + self.drop7 = nn.Dropout2d() + + self.score_fr = nn.Conv2d(4096, n_class, 1) + self.upscore = nn.ConvTranspose2d(n_class, n_class, 64, stride=32, + bias=False) + + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + m.weight.data.zero_() + if m.bias is not None: + m.bias.data.zero_() + if isinstance(m, nn.ConvTranspose2d): + assert m.kernel_size[0] == m.kernel_size[1] + initial_weight = get_upsampling_weight( + m.in_channels, m.out_channels, m.kernel_size[0]) + m.weight.data.copy_(initial_weight) + + def forward(self, x): + h = x + h = self.relu1_1(self.conv1_1(h)) + h = self.relu1_2(self.conv1_2(h)) + h = self.pool1(h) + + h = self.relu2_1(self.conv2_1(h)) + h = self.relu2_2(self.conv2_2(h)) + h = self.pool2(h) + + h = self.relu3_1(self.conv3_1(h)) + h = self.relu3_2(self.conv3_2(h)) + h = self.relu3_3(self.conv3_3(h)) + h = self.pool3(h) + + h = self.relu4_1(self.conv4_1(h)) + h = self.relu4_2(self.conv4_2(h)) + h = self.relu4_3(self.conv4_3(h)) + h = self.pool4(h) + + h = self.relu5_1(self.conv5_1(h)) + h = self.relu5_2(self.conv5_2(h)) + h = self.relu5_3(self.conv5_3(h)) + h = self.pool5(h) + + h = self.relu6(self.fc6(h)) + h = self.drop6(h) + + h = self.relu7(self.fc7(h)) + h = self.drop7(h) + + h = self.score_fr(h) + + h = self.upscore(h) + h = h[:, :, 19:19 + x.size()[2], 19:19 + x.size()[3]].contiguous() + return h + +def load_fcn(model_dir): + fcn = FCN32s() + fcn.load_state_dict(torch.load(os.path.join(model_dir, 'fcn32s_from_caffe.pth'))) + fcn.cuda() + return fcn + +class SemanticLoss(nn.Module): + def __init__(self, opt): + super(SemanticLoss, self).__init__() + self.opt = opt + 
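+ # 21 channels match the FCN32s output below (20 PASCAL VOC classes plus background)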
self.instancenorm = nn.InstanceNorm2d(21, affine=False) + + def compute_fcn_loss(self, fcn, img, target): + img_fcn = vgg_preprocess(img, self.opt) + target_fcn = vgg_preprocess(target, self.opt) + img_fea = fcn(img_fcn) + target_fea = fcn(target_fcn) + return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) diff --git a/gimp-plugins/EnlightenGAN/models/single_model.py b/gimp-plugins/EnlightenGAN/models/single_model.py new file mode 100755 index 0000000..b1042c8 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/models/single_model.py @@ -0,0 +1,496 @@ +import numpy as np +import torch +from torch import nn +import os +from collections import OrderedDict +from torch.autograd import Variable +import util.util as util +from collections import OrderedDict +from torch.autograd import Variable +import itertools +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +import random +from . import networks +import sys + + +class SingleModel(BaseModel): + def name(self): + return 'SingleGANModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + nb = opt.batchSize + size = opt.fineSize + self.opt = opt + self.input_A = self.Tensor(nb, opt.input_nc, size, size) + self.input_B = self.Tensor(nb, opt.output_nc, size, size) + self.input_img = self.Tensor(nb, opt.input_nc, size, size) + self.input_A_gray = self.Tensor(nb, 1, size, size) + + if opt.vgg > 0: + self.vgg_loss = networks.PerceptualLoss(opt) + if self.opt.IN_vgg: + self.vgg_patch_loss = networks.PerceptualLoss(opt) + self.vgg_patch_loss.cuda() + self.vgg_loss.cuda() + self.vgg = networks.load_vgg16("./model", self.gpu_ids) + self.vgg.eval() + for param in self.vgg.parameters(): + param.requires_grad = False + elif opt.fcn > 0: + self.fcn_loss = networks.SemanticLoss(opt) + self.fcn_loss.cuda() + self.fcn = networks.load_fcn("./model") + self.fcn.eval() + for param in self.fcn.parameters(): + param.requires_grad = False + # load/define networks + # The naming conversion is different from those used in the paper + # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) + + skip = True if opt.skip > 0 else False + self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, + opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt) + # self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, + # opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=False, opt=opt) + + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD_A = networks.define_D(opt.output_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids, False) + if self.opt.patchD: + self.netD_P = networks.define_D(opt.input_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_patchD, opt.norm, use_sigmoid, self.gpu_ids, True) + if not self.isTrain or opt.continue_train: + which_epoch = opt.which_epoch + self.load_network(self.netG_A, 'G_A', which_epoch) + # self.load_network(self.netG_B, 'G_B', which_epoch) + if self.isTrain: + self.load_network(self.netD_A, 'D_A', which_epoch) + if self.opt.patchD: + self.load_network(self.netD_P, 'D_P', which_epoch) + + if self.isTrain: + self.old_lr = opt.lr + # self.fake_A_pool = ImagePool(opt.pool_size) + self.fake_B_pool = ImagePool(opt.pool_size) + # define loss functions + if opt.use_wgan: + self.criterionGAN = networks.DiscLossWGANGP() + else: + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + if opt.use_mse: + 
self.criterionCycle = torch.nn.MSELoss() + else: + self.criterionCycle = torch.nn.L1Loss() + self.criterionL1 = torch.nn.L1Loss() + self.criterionIdt = torch.nn.L1Loss() + # initialize optimizers + self.optimizer_G = torch.optim.Adam(self.netG_A.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + if self.opt.patchD: + self.optimizer_D_P = torch.optim.Adam(self.netD_P.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG_A) + # networks.print_network(self.netG_B) + if self.isTrain: + networks.print_network(self.netD_A) + if self.opt.patchD: + networks.print_network(self.netD_P) + # networks.print_network(self.netD_B) + if opt.isTrain: + self.netG_A.train() + # self.netG_B.train() + else: + self.netG_A.eval() + # self.netG_B.eval() + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + input_img = input['input_img'] + input_A_gray = input['A_gray'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_A_gray.resize_(input_A_gray.size()).copy_(input_A_gray) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.input_img.resize_(input_img.size()).copy_(input_img) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + + + + def test(self): + self.real_A = Variable(self.input_A, volatile=True) + self.real_A_gray = Variable(self.input_A_gray, volatile=True) + if self.opt.noise > 0: + self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.)) + self.real_A = self.real_A + self.noise + if self.opt.input_linear: + self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A)) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray) + else: + self.fake_B = self.netG_A.forward(self.real_A, self.real_A_gray) + # self.rec_A = self.netG_B.forward(self.fake_B) + + self.real_B = Variable(self.input_B, volatile=True) + + + def predict(self): + self.real_A = Variable(self.input_A, volatile=True) + self.real_A_gray = Variable(self.input_A_gray, volatile=True) + if self.opt.noise > 0: + self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.)) + self.real_A = self.real_A + self.noise + if self.opt.input_linear: + self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A)) + # print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:]) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray) + else: + self.fake_B = self.netG_A.forward(self.real_A, self.real_A_gray) + # self.rec_A = self.netG_B.forward(self.fake_B) + + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + A_gray = util.atten2im(self.real_A_gray.data) + # rec_A = util.tensor2im(self.rec_A.data) + # if self.opt.skip == 1: + # latent_real_A = util.tensor2im(self.latent_real_A.data) + # latent_show = util.latent2im(self.latent_real_A.data) + # max_image = util.max2im(self.fake_B.data, self.latent_real_A.data) + # return OrderedDict([('real_A', 
real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + # ('latent_show', latent_show), ('max_image', max_image), ('A_gray', A_gray)]) + # else: + # return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) + # return OrderedDict([('fake_B', fake_B)]) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) + + # get image paths + def get_image_paths(self): + return self.image_paths + + def backward_D_basic(self, netD, real, fake, use_ragan): + # Real + pred_real = netD.forward(real) + pred_fake = netD.forward(fake.detach()) + if self.opt.use_wgan: + loss_D_real = pred_real.mean() + loss_D_fake = pred_fake.mean() + loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty(netD, + real.data, fake.data) + elif self.opt.use_ragan and use_ragan: + loss_D = (self.criterionGAN(pred_real - torch.mean(pred_fake), True) + + self.criterionGAN(pred_fake - torch.mean(pred_real), False)) / 2 + else: + loss_D_real = self.criterionGAN(pred_real, True) + loss_D_fake = self.criterionGAN(pred_fake, False) + loss_D = (loss_D_real + loss_D_fake) * 0.5 + # loss_D.backward() + return loss_D + + def backward_D_A(self): + fake_B = self.fake_B_pool.query(self.fake_B) + fake_B = self.fake_B + self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, True) + self.loss_D_A.backward() + + def backward_D_P(self): + if self.opt.hybrid_loss: + loss_D_P = self.backward_D_basic(self.netD_P, self.real_patch, self.fake_patch, False) + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + loss_D_P += self.backward_D_basic(self.netD_P, self.real_patch_1[i], self.fake_patch_1[i], False) + self.loss_D_P = loss_D_P/float(self.opt.patchD_3 + 1) + else: + self.loss_D_P = loss_D_P + else: + loss_D_P = self.backward_D_basic(self.netD_P, self.real_patch, self.fake_patch, True) + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + loss_D_P += self.backward_D_basic(self.netD_P, self.real_patch_1[i], self.fake_patch_1[i], True) + self.loss_D_P = loss_D_P/float(self.opt.patchD_3 + 1) + else: + self.loss_D_P = loss_D_P + if self.opt.D_P_times2: + self.loss_D_P = self.loss_D_P*2 + self.loss_D_P.backward() + + # def backward_D_B(self): + # fake_A = self.fake_A_pool.query(self.fake_A) + # self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + def forward(self): + self.real_A = Variable(self.input_A) + self.real_B = Variable(self.input_B) + self.real_A_gray = Variable(self.input_A_gray) + self.real_img = Variable(self.input_img) + if self.opt.noise > 0: + self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.)) + self.real_A = self.real_A + self.noise + if self.opt.input_linear: + self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A)) + if self.opt.skip == 1: + self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_img, self.real_A_gray) + else: + self.fake_B = self.netG_A.forward(self.real_img, self.real_A_gray) + if self.opt.patchD: + w = self.real_A.size(3) + h = self.real_A.size(2) + w_offset = random.randint(0, max(0, w - self.opt.patchSize - 1)) + h_offset = random.randint(0, max(0, h - self.opt.patchSize - 1)) + + self.fake_patch = self.fake_B[:,:, h_offset:h_offset + self.opt.patchSize, + w_offset:w_offset + self.opt.patchSize] + self.real_patch = self.real_B[:,:, h_offset:h_offset + self.opt.patchSize, + w_offset:w_offset + self.opt.patchSize] + self.input_patch = self.real_A[:,:, h_offset:h_offset + self.opt.patchSize, + 
w_offset:w_offset + self.opt.patchSize] + if self.opt.patchD_3 > 0: + self.fake_patch_1 = [] + self.real_patch_1 = [] + self.input_patch_1 = [] + w = self.real_A.size(3) + h = self.real_A.size(2) + for i in range(self.opt.patchD_3): + w_offset_1 = random.randint(0, max(0, w - self.opt.patchSize - 1)) + h_offset_1 = random.randint(0, max(0, h - self.opt.patchSize - 1)) + self.fake_patch_1.append(self.fake_B[:,:, h_offset_1:h_offset_1 + self.opt.patchSize, + w_offset_1:w_offset_1 + self.opt.patchSize]) + self.real_patch_1.append(self.real_B[:,:, h_offset_1:h_offset_1 + self.opt.patchSize, + w_offset_1:w_offset_1 + self.opt.patchSize]) + self.input_patch_1.append(self.real_A[:,:, h_offset_1:h_offset_1 + self.opt.patchSize, + w_offset_1:w_offset_1 + self.opt.patchSize]) + + # w_offset_2 = random.randint(0, max(0, w - self.opt.patchSize - 1)) + # h_offset_2 = random.randint(0, max(0, h - self.opt.patchSize - 1)) + # self.fake_patch_2 = self.fake_B[:,:, h_offset_2:h_offset_2 + self.opt.patchSize, + # w_offset_2:w_offset_2 + self.opt.patchSize] + # self.real_patch_2 = self.real_B[:,:, h_offset_2:h_offset_2 + self.opt.patchSize, + # w_offset_2:w_offset_2 + self.opt.patchSize] + # self.input_patch_2 = self.real_A[:,:, h_offset_2:h_offset_2 + self.opt.patchSize, + # w_offset_2:w_offset_2 + self.opt.patchSize] + + def backward_G(self, epoch): + pred_fake = self.netD_A.forward(self.fake_B) + if self.opt.use_wgan: + self.loss_G_A = -pred_fake.mean() + elif self.opt.use_ragan: + pred_real = self.netD_A.forward(self.real_B) + + self.loss_G_A = (self.criterionGAN(pred_real - torch.mean(pred_fake), False) + + self.criterionGAN(pred_fake - torch.mean(pred_real), True)) / 2 + + else: + self.loss_G_A = self.criterionGAN(pred_fake, True) + + loss_G_A = 0 + if self.opt.patchD: + pred_fake_patch = self.netD_P.forward(self.fake_patch) + if self.opt.hybrid_loss: + loss_G_A += self.criterionGAN(pred_fake_patch, True) + else: + pred_real_patch = self.netD_P.forward(self.real_patch) + + loss_G_A += (self.criterionGAN(pred_real_patch - torch.mean(pred_fake_patch), False) + + self.criterionGAN(pred_fake_patch - torch.mean(pred_real_patch), True)) / 2 + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + pred_fake_patch_1 = self.netD_P.forward(self.fake_patch_1[i]) + if self.opt.hybrid_loss: + loss_G_A += self.criterionGAN(pred_fake_patch_1, True) + else: + pred_real_patch_1 = self.netD_P.forward(self.real_patch_1[i]) + + loss_G_A += (self.criterionGAN(pred_real_patch_1 - torch.mean(pred_fake_patch_1), False) + + self.criterionGAN(pred_fake_patch_1 - torch.mean(pred_real_patch_1), True)) / 2 + + if not self.opt.D_P_times2: + self.loss_G_A += loss_G_A/float(self.opt.patchD_3 + 1) + else: + self.loss_G_A += loss_G_A/float(self.opt.patchD_3 + 1)*2 + else: + if not self.opt.D_P_times2: + self.loss_G_A += loss_G_A + else: + self.loss_G_A += loss_G_A*2 + + if epoch < 0: + vgg_w = 0 + else: + vgg_w = 1 + if self.opt.vgg > 0: + self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(self.vgg, + self.fake_B, self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0 + if self.opt.patch_vgg: + if not self.opt.IN_vgg: + loss_vgg_patch = self.vgg_loss.compute_vgg_loss(self.vgg, + self.fake_patch, self.input_patch) * self.opt.vgg + else: + loss_vgg_patch = self.vgg_patch_loss.compute_vgg_loss(self.vgg, + self.fake_patch, self.input_patch) * self.opt.vgg + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + if not self.opt.IN_vgg: + loss_vgg_patch += self.vgg_loss.compute_vgg_loss(self.vgg, + self.fake_patch_1[i], 
self.input_patch_1[i]) * self.opt.vgg + else: + loss_vgg_patch += self.vgg_patch_loss.compute_vgg_loss(self.vgg, + self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.vgg + self.loss_vgg_b += loss_vgg_patch/float(self.opt.patchD_3 + 1) + else: + self.loss_vgg_b += loss_vgg_patch + self.loss_G = self.loss_G_A + self.loss_vgg_b*vgg_w + elif self.opt.fcn > 0: + self.loss_fcn_b = self.fcn_loss.compute_fcn_loss(self.fcn, + self.fake_B, self.real_A) * self.opt.fcn if self.opt.fcn > 0 else 0 + if self.opt.patchD: + loss_fcn_patch = self.fcn_loss.compute_vgg_loss(self.fcn, + self.fake_patch, self.input_patch) * self.opt.fcn + if self.opt.patchD_3 > 0: + for i in range(self.opt.patchD_3): + loss_fcn_patch += self.fcn_loss.compute_vgg_loss(self.fcn, + self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.fcn + self.loss_fcn_b += loss_fcn_patch/float(self.opt.patchD_3 + 1) + else: + self.loss_fcn_b += loss_fcn_patch + self.loss_G = self.loss_G_A + self.loss_fcn_b*vgg_w + # self.loss_G = self.L1_AB + self.L1_BA + self.loss_G.backward() + + + # def optimize_parameters(self, epoch): + # # forward + # self.forward() + # # G_A and G_B + # self.optimizer_G.zero_grad() + # self.backward_G(epoch) + # self.optimizer_G.step() + # # D_A + # self.optimizer_D_A.zero_grad() + # self.backward_D_A() + # self.optimizer_D_A.step() + # if self.opt.patchD: + # self.forward() + # self.optimizer_D_P.zero_grad() + # self.backward_D_P() + # self.optimizer_D_P.step() + # D_B + # self.optimizer_D_B.zero_grad() + # self.backward_D_B() + # self.optimizer_D_B.step() + def optimize_parameters(self, epoch): + # forward + self.forward() + # G_A and G_B + self.optimizer_G.zero_grad() + self.backward_G(epoch) + self.optimizer_G.step() + # D_A + self.optimizer_D_A.zero_grad() + self.backward_D_A() + if not self.opt.patchD: + self.optimizer_D_A.step() + else: + # self.forward() + self.optimizer_D_P.zero_grad() + self.backward_D_P() + self.optimizer_D_A.step() + self.optimizer_D_P.step() + + + def get_current_errors(self, epoch): + D_A = self.loss_D_A.data[0] + D_P = self.loss_D_P.data[0] if self.opt.patchD else 0 + G_A = self.loss_G_A.data[0] + if self.opt.vgg > 0: + vgg = self.loss_vgg_b.data[0]/self.opt.vgg if self.opt.vgg > 0 else 0 + return OrderedDict([('D_A', D_A), ('G_A', G_A), ("vgg", vgg), ("D_P", D_P)]) + elif self.opt.fcn > 0: + fcn = self.loss_fcn_b.data[0]/self.opt.fcn if self.opt.fcn > 0 else 0 + return OrderedDict([('D_A', D_A), ('G_A', G_A), ("fcn", fcn), ("D_P", D_P)]) + + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + real_B = util.tensor2im(self.real_B.data) + if self.opt.skip > 0: + latent_real_A = util.tensor2im(self.latent_real_A.data) + latent_show = util.latent2im(self.latent_real_A.data) + if self.opt.patchD: + fake_patch = util.tensor2im(self.fake_patch.data) + real_patch = util.tensor2im(self.real_patch.data) + if self.opt.patch_vgg: + input_patch = util.tensor2im(self.input_patch.data) + if not self.opt.self_attention: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + ('fake_patch', fake_patch), ('input_patch', input_patch)]) + else: + self_attention = util.atten2im(self.real_A_gray.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + ('fake_patch', fake_patch), 
('input_patch', input_patch), ('self_attention', self_attention)]) + else: + if not self.opt.self_attention: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + ('fake_patch', fake_patch)]) + else: + self_attention = util.atten2im(self.real_A_gray.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch), + ('fake_patch', fake_patch), ('self_attention', self_attention)]) + else: + if not self.opt.self_attention: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A), + ('latent_show', latent_show), ('real_B', real_B)]) + else: + self_attention = util.atten2im(self.real_A_gray.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B), + ('latent_real_A', latent_real_A), ('latent_show', latent_show), + ('self_attention', self_attention)]) + else: + if not self.opt.self_attention: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)]) + else: + self_attention = util.atten2im(self.real_A_gray.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B), + ('self_attention', self_attention)]) + + def save(self, label): + self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) + self.save_network(self.netD_A, 'D_A', label, self.gpu_ids) + if self.opt.patchD: + self.save_network(self.netD_P, 'D_P', label, self.gpu_ids) + # self.save_network(self.netG_B, 'G_B', label, self.gpu_ids) + # self.save_network(self.netD_B, 'D_B', label, self.gpu_ids) + + def update_learning_rate(self): + + if self.opt.new_lr: + lr = self.old_lr/2 + else: + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D_A.param_groups: + param_group['lr'] = lr + if self.opt.patchD: + for param_group in self.optimizer_D_P.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr diff --git a/gimp-plugins/EnlightenGAN/util/__init__.py b/gimp-plugins/EnlightenGAN/util/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/gimp-plugins/EnlightenGAN/util/image_pool.py b/gimp-plugins/EnlightenGAN/util/image_pool.py new file mode 100755 index 0000000..152ef5b --- /dev/null +++ b/gimp-plugins/EnlightenGAN/util/image_pool.py @@ -0,0 +1,32 @@ +import random +import numpy as np +import torch +from torch.autograd import Variable +class ImagePool(): + def __init__(self, pool_size): + self.pool_size = pool_size + if self.pool_size > 0: + self.num_imgs = 0 + self.images = [] + + def query(self, images): + if self.pool_size == 0: + return images + return_images = [] + for image in images.data: + image = torch.unsqueeze(image, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: + random_id = random.randint(0, self.pool_size-1) + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: + return_images.append(image) + return_images = Variable(torch.cat(return_images, 0)) + return return_images diff --git a/gimp-plugins/EnlightenGAN/util/util.py b/gimp-plugins/EnlightenGAN/util/util.py new file mode 100755 
index 0000000..5d499e1 --- /dev/null +++ b/gimp-plugins/EnlightenGAN/util/util.py @@ -0,0 +1,182 @@ +# from __future__ import print_function +import numpy as np +from PIL import Image +import inspect, re +import numpy as np +import torch +import os +import collections +from torch.optim import lr_scheduler +import torch.nn.init as init + + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8): + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + image_numpy = np.maximum(image_numpy, 0) + image_numpy = np.minimum(image_numpy, 255) + return image_numpy.astype(imtype) + +def atten2im(image_tensor, imtype=np.uint8): + image_tensor = image_tensor[0] + image_tensor = torch.cat((image_tensor, image_tensor, image_tensor), 0) + image_numpy = image_tensor.cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 + image_numpy = image_numpy/(image_numpy.max()/255.0) + return image_numpy.astype(imtype) + +def latent2im(image_tensor, imtype=np.uint8): + # image_tensor = (image_tensor - torch.min(image_tensor))/(torch.max(image_tensor)-torch.min(image_tensor)) + image_numpy = image_tensor[0].cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0 + image_numpy = np.maximum(image_numpy, 0) + image_numpy = np.minimum(image_numpy, 255) + return image_numpy.astype(imtype) + +def max2im(image_1, image_2, imtype=np.uint8): + image_1 = image_1[0].cpu().float().numpy() + image_2 = image_2[0].cpu().float().numpy() + image_1 = (np.transpose(image_1, (1, 2, 0)) + 1) / 2.0 * 255.0 + image_2 = (np.transpose(image_2, (1, 2, 0))) * 255.0 + output = np.maximum(image_1, image_2) + output = np.maximum(output, 0) + output = np.minimum(output, 255) + return output.astype(imtype) + +def variable2im(image_tensor, imtype=np.uint8): + image_numpy = image_tensor[0].data.cpu().float().numpy() + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + +def info(object, spacing=10, collapse=1): + """Print methods and doc strings. 
+ Takes module, class, list, dictionary, or string.""" + methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)] + processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s) + print( "\n".join(["%s %s" % + (method.ljust(spacing), + processFunc(str(getattr(object, method).__doc__))) + for method in methodList]) ) + +def varname(p): + for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]: + m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line) + if m: + return m.group(1) + +def print_numpy(x, val=True, shp=False): + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + +def get_model_list(dirname, key): + if os.path.exists(dirname) is False: + return None + gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if + os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f] + if gen_models is None: + return None + gen_models.sort() + last_model_name = gen_models[-1] + return last_model_name + + +def load_vgg16(model_dir): + """ Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """ + if not os.path.exists(model_dir): + os.mkdir(model_dir) + if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')): + if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')): + os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7')) + vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7')) + vgg = Vgg16() + for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()): + dst.data[:] = src + torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight')) + vgg = Vgg16() + vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight'))) + return vgg + + +def vgg_preprocess(batch): + tensortype = type(batch.data) + (r, g, b) = torch.chunk(batch, 3, dim = 1) + batch = torch.cat((b, g, r), dim = 1) # convert RGB to BGR + batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255] + mean = tensortype(batch.data.size()) + mean[:, 0, :, :] = 103.939 + mean[:, 1, :, :] = 116.779 + mean[:, 2, :, :] = 123.680 + batch = batch.sub(Variable(mean)) # subtract mean + return batch + + +def get_scheduler(optimizer, hyperparameters, iterations=-1): + if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant': + scheduler = None # constant scheduler + elif hyperparameters['lr_policy'] == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'], + gamma=hyperparameters['gamma'], last_epoch=iterations) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', hyperparameters['lr_policy']) + return scheduler + + +def weights_init(init_type='gaussian'): + def init_fun(m): + classname = m.__class__.__name__ + if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): + # print m.__class__.__name__ + if init_type == 'gaussian': + init.normal(m.weight.data, 0.0, 0.02) + elif init_type == 'xavier': + init.xavier_normal(m.weight.data, gain=math.sqrt(2)) + elif init_type == 'kaiming': + 
init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal(m.weight.data, gain=math.sqrt(2)) + elif init_type == 'default': + pass + else: + assert 0, "Unsupported initialization: {}".format(init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant(m.bias.data, 0.0) + + return init_fun \ No newline at end of file diff --git a/gimp-plugins/enlighten.py b/gimp-plugins/enlighten.py new file mode 100755 index 0000000..a13d918 --- /dev/null +++ b/gimp-plugins/enlighten.py @@ -0,0 +1,107 @@ +import os + +baseLoc = os.path.dirname(os.path.realpath(__file__)) + '/' + +from gimpfu import * +import sys + +sys.path.extend([baseLoc + 'gimpenv/lib/python2.7', baseLoc + 'gimpenv/lib/python2.7/site-packages', + baseLoc + 'gimpenv/lib/python2.7/site-packages/setuptools', baseLoc + 'EnlightenGAN']) + +from argparse import Namespace +import cv2 +import numpy as np +import torch +from models.models import create_model +from data.base_dataset import get_transform + + + + +def getEnlighten(input_image,cFlag): + + opt = Namespace(D_P_times2=False, IN_vgg=False, aspect_ratio=1.0, batchSize=1, + checkpoints_dir=baseLoc+'weights/', dataroot='test_dataset', + dataset_mode='unaligned', display_id=1, display_port=8097, + display_single_pane_ncols=0, display_winsize=256, fcn=0, + fineSize=256, gpu_ids=[0], high_times=400, how_many=50, + hybrid_loss=False, identity=0.0, input_linear=False, input_nc=3, + instance_norm=0.0, isTrain=False, l1=10.0, lambda_A=10.0, + lambda_B=10.0, latent_norm=False, latent_threshold=False, + lighten=False, linear=False, linear_add=False, loadSize=286, + low_times=200, max_dataset_size='inf', model='single', + multiply=False, nThreads=1, n_layers_D=3, n_layers_patchD=3, + name='enlightening', ndf=64, new_lr=False, ngf=64, no_dropout=True, + no_flip=True, no_vgg_instance=False, noise=0, norm='instance', + norm_attention=False, ntest='inf', output_nc=3, patchD=False, + patchD_3=0, patchSize=64, patch_vgg=False, phase='test', + resize_or_crop='no', results_dir='./results/', self_attention=True, + serial_batches=True, skip=1.0, syn_norm=False, tanh=False, + times_residual=True, use_avgpool=0, use_mse=False, use_norm=1.0, + use_ragan=False, use_wgan=0.0, vary=1, vgg=0, vgg_choose='relu5_3', + vgg_maxpooling=False, vgg_mean=False, which_direction='AtoB', + which_epoch='200', which_model_netD='basic', which_model_netG='sid_unet_resize', cFlag=cFlag) + + im = cv2.cvtColor(input_image,cv2.COLOR_RGB2BGR) + transform = get_transform(opt) + A_img = transform(im) + r, g, b = A_img[0] + 1, A_img[1] + 1, A_img[2] + 1 + A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2. 
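+ # attention map: one minus the luminance of the [0, 2]-shifted channels, so darker pixels get weights close to 1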
+ A_gray = torch.unsqueeze(A_gray, 0) + data = {'A': A_img.unsqueeze(0), 'B': A_img.unsqueeze(0), 'A_gray': A_gray.unsqueeze(0), 'input_img': A_img.unsqueeze(0), 'A_paths': 'A_path', 'B_paths': 'B_path'} + + model = create_model(opt) + model.set_input(data) + visuals = model.predict() + out = visuals['fake_B'].astype(np.uint8) + out = cv2.cvtColor(out,cv2.COLOR_BGR2RGB) + # cv2.imwrite("/Users/kritiksoman/PycharmProjects/new/out.png", out) + return out + + +def channelData(layer): # convert gimp image to numpy + region = layer.get_pixel_rgn(0, 0, layer.width, layer.height) + pixChars = region[:, :] # Take whole layer + bpp = region.bpp + # return np.frombuffer(pixChars,dtype=np.uint8).reshape(len(pixChars)/bpp,bpp) + return np.frombuffer(pixChars, dtype=np.uint8).reshape(layer.height, layer.width, bpp) + + +def createResultLayer(image, name, result): + rlBytes = np.uint8(result).tobytes(); + rl = gimp.Layer(image, name, image.width, image.height, 0, 100, NORMAL_MODE) + region = rl.get_pixel_rgn(0, 0, rl.width, rl.height, True) + region[:, :] = rlBytes + image.add_layer(rl, 0) + gimp.displays_flush() + + +def Enlighten(img, layer,cFlag): + if torch.cuda.is_available() and not cFlag: + gimp.progress_init("(Using GPU) Enlighten " + layer.name + "...") + else: + gimp.progress_init("(Using CPU) Enlighten " + layer.name + "...") + imgmat = channelData(layer) + if imgmat.shape[2] == 4: # get rid of alpha channel + imgmat = imgmat[:,:,0:3] + cpy = getEnlighten(imgmat,cFlag) + createResultLayer(img, 'new_output', cpy) + + +register( + "enlighten", + "enlighten", + "Enlighten image based on deep learning.", + "Kritik Soman", + "Your", + "2020", + "enlighten...", + "*", # Alternately use RGB, RGB*, GRAY*, INDEXED etc. + [(PF_IMAGE, "image", "Input image", None), + (PF_DRAWABLE, "drawable", "Input drawable", None), + (PF_BOOL, "fcpu", "Force CPU", False) + ], + [], + Enlighten, menu="/Layer/GIML-ML") + +main() \ No newline at end of file diff --git a/gimp-plugins/syncWeights.py b/gimp-plugins/syncWeights.py index 090f0c7..64e5a3c 100755 --- a/gimp-plugins/syncWeights.py +++ b/gimp-plugins/syncWeights.py @@ -203,5 +203,16 @@ def sync(path,flag): gimp.progress_init("Downloading " + model +"(~" + str(fileSize) + "MB)...") download_file_from_google_drive(file_id, destination,fileSize) + #enlighten + model = 'enlightening' + file_id = '1V8ARc2tDgUUpc11xiT5Y9HFQgC6Ug2T6' + fileSize = 0.035 #in MB + mFName = '200_net_G_A.pth' + if not os.path.isdir(path + '/' + model): + os.mkdir(path + '/' + model) + destination = path + '/' + model + '/' + mFName + if not os.path.isfile(destination): + gimp.progress_init("Downloading " + model +"(~" + str(fileSize) + "MB)...") + download_file_from_google_drive(file_id, destination,fileSize)
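For reference, the input dictionary that getEnlighten assembles above can be reproduced in isolation; the sketch below uses a random tensor in place of the output of get_transform(opt) and is purely illustrative, not part of the patch:

import torch

# stand-in for a transformed RGB image in [-1, 1], shape (3, H, W)
A_img = torch.rand(3, 256, 256) * 2 - 1
r, g, b = A_img[0] + 1, A_img[1] + 1, A_img[2] + 1      # shift each channel to [0, 2]
A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.  # 1 - luminance: dark pixels get attention near 1
A_gray = torch.unsqueeze(A_gray, 0)                     # (1, H, W)
data = {'A': A_img.unsqueeze(0), 'B': A_img.unsqueeze(0),
        'A_gray': A_gray.unsqueeze(0), 'input_img': A_img.unsqueeze(0),
        'A_paths': 'A_path', 'B_paths': 'B_path'}       # batch of one, as SingleModel.set_input expects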