octUpdate

pull/30/head
Kritik Soman 4 years ago
parent 4e14d7291e
commit a3fcc0cfc4

@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 2.7 (gimpenv)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Python 2.7 (gimpenv)" project-jdk-type="Python SDK" />
</project>

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/EnlightenGAN.iml" filepath="$PROJECT_DIR$/.idea/EnlightenGAN.iml" />
</modules>
</component>
</project>

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/../.." vcs="Git" />
</component>
</project>

@ -0,0 +1,58 @@
Copyright (c) 2019, Yifan Jiang and Zhangyang Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------- LICENSE FOR EnlightenGAN --------------------------------
BSD License
For EnlightenGAN software
Copyright (c) 2019, Yifan Jiang and Zhangyang Wang
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License
For dcgan.torch software
Copyright (c) 2015, Facebook, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -0,0 +1,50 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import random
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
def get_transform(opt):
transform_list = []
if opt.resize_or_crop == 'resize_and_crop':
zoom = 1 + 0.1*random.randint(0,4)
osize = [int(400*zoom), int(600*zoom)]
transform_list.append(transforms.Scale(osize, Image.BICUBIC))
transform_list.append(transforms.RandomCrop(opt.fineSize))
elif opt.resize_or_crop == 'crop':
transform_list.append(transforms.RandomCrop(opt.fineSize))
elif opt.resize_or_crop == 'scale_width':
transform_list.append(transforms.Lambda(
lambda img: __scale_width(img, opt.fineSize)))
elif opt.resize_or_crop == 'scale_width_and_crop':
transform_list.append(transforms.Lambda(
lambda img: __scale_width(img, opt.loadSize)))
transform_list.append(transforms.RandomCrop(opt.fineSize))
# elif opt.resize_or_crop == 'no':
# osize = [384, 512]
# transform_list.append(transforms.Scale(osize, Image.BICUBIC))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.RandomHorizontalFlip())
transform_list += [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def __scale_width(img, target_width):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), Image.BICUBIC)
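A minimal usage sketch for `get_transform` (hypothetical: in the repo, `opt` comes from the option parsers, and `loadSize` is only consulted in the `scale_width_and_crop` mode):

from argparse import Namespace
from PIL import Image

opt = Namespace(resize_or_crop='crop', fineSize=256, isTrain=True, no_flip=False)
transform = get_transform(opt)
img = Image.open('input.png').convert('RGB')  # hypothetical input, assumed at least 256x256
x = transform(img)  # FloatTensor in [-1, 1] with shape (3, 256, 256)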

@ -0,0 +1,2 @@
from .modules import *
from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to

@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback

@ -0,0 +1,329 @@
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import collections
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
from .comm import SyncMaster
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
def _sum_ft(tensor):
"""sum over the first and last dimention"""
return tensor.sum(dim=0).sum(dim=-1)
def _unsqueeze_ft(tensor):
"""add new dementions at the front and the tail"""
return tensor.unsqueeze(0).unsqueeze(-1)
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
class _SynchronizedBatchNorm(_BatchNorm):
def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
self._sync_master = SyncMaster(self._data_parallel_master)
self._is_parallel = False
self._parallel_id = None
self._slave_pipe = None
# custom batch norm statistics
self._moving_average_fraction = 1. - momentum
self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
self.register_buffer('_running_iter', torch.ones(1))
self._tmp_running_mean = self.running_mean.clone() * self._running_iter
self._tmp_running_var = self.running_var.clone() * self._running_iter
def forward(self, input):
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
if not (self._is_parallel and self.training):
return F.batch_norm(
input, self.running_mean, self.running_var, self.weight, self.bias,
self.training, self.momentum, self.eps)
# Resize the input to (B, C, -1).
input_shape = input.size()
input = input.view(input.size(0), self.num_features, -1)
# Compute the sum and square-sum.
sum_size = input.size(0) * input.size(2)
input_sum = _sum_ft(input)
input_ssum = _sum_ft(input ** 2)
# Reduce-and-broadcast the statistics.
if self._parallel_id == 0:
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
else:
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
# Compute the output.
if self.affine:
# MJY:: Fuse the multiplication for speed.
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
else:
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
# Reshape it.
return output.view(input_shape)
def __data_parallel_replicate__(self, ctx, copy_id):
self._is_parallel = True
self._parallel_id = copy_id
# parallel_id == 0 means master device.
if self._parallel_id == 0:
ctx.sync_master = self._sync_master
else:
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
def _data_parallel_master(self, intermediates):
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
to_reduce = [i[1][:2] for i in intermediates]
to_reduce = [j for i in to_reduce for j in i] # flatten
target_gpus = [i[1].sum.get_device() for i in intermediates]
sum_size = sum([i[1].sum_size for i in intermediates])
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
outputs = []
for i, rec in enumerate(intermediates):
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
return outputs
def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
"""return *dest* by `dest := dest*alpha + delta*beta + bias`"""
return dest * alpha + delta * beta + bias
def _compute_mean_std(self, sum_, ssum, size):
"""Compute the mean and standard-deviation with sum and square-sum. This method
also maintains the moving average on the master device."""
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
mean = sum_ / size
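# Note: ssum - sum_ * mean == sum((x - mean)^2), the summed squared deviation.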
sumvar = ssum - sum_ * mean
unbias_var = sumvar / (size - 1)
bias_var = sumvar / size
self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
self.running_mean = self._tmp_running_mean / self._running_iter
self.running_var = self._tmp_running_var / self._running_iter
return mean, bias_var.clamp(self.eps) ** -0.5
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
mini-batch.
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm1d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, in the one-GPU or CPU-only case, this module behaves exactly the
same as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.001.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.
Args:
num_features: num_features from an expected input of size
`batch_size x num_features [x width]`
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.001
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C)` or :math:`(N, C, L)`
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm1d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 2 and input.dim() != 3:
raise ValueError('expected 2D or 3D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
of 3d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm2d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, in the one-GPU or CPU-only case, this module behaves exactly the
same as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.001.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.
Args:
num_features: num_features from an expected input of
size batch_size x num_features x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.001
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C, H, W)`
- Output: :math:`(N, C, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm2d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
of 4d inputs
.. math::
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
This module differs from the built-in PyTorch BatchNorm3d as the mean and
standard-deviation are reduced across all devices during training.
For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
the statistics only on that device, which accelerates the computation and
is also easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics will be computed
over all training samples distributed on multiple devices.
Note that, in the one-GPU or CPU-only case, this module behaves exactly the
same as the built-in PyTorch implementation.
The mean and standard-deviation are calculated per-dimension over
the mini-batches and gamma and beta are learnable parameter vectors
of size C (where C is the input size).
During training, this layer keeps a running estimate of its computed mean
and variance. The running sum is kept with a default momentum of 0.001.
During evaluation, this running mean/variance is used for normalization.
Because the BatchNorm is done over the `C` dimension, computing statistics
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
or Spatio-temporal BatchNorm.
Args:
num_features: num_features from an expected input of
size batch_size x num_features x depth x height x width
eps: a value added to the denominator for numerical stability.
Default: 1e-5
momentum: the value used for the running_mean and running_var
computation. Default: 0.001
affine: a boolean value that when set to ``True``, gives the layer learnable
affine parameters. Default: ``True``
Shape:
- Input: :math:`(N, C, D, H, W)`
- Output: :math:`(N, C, D, H, W)` (same shape as input)
Examples:
>>> # With Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = SynchronizedBatchNorm3d(100, affine=False)
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
>>> output = m(input)
"""
def _check_input_dim(self, input):
if input.dim() != 5:
raise ValueError('expected 5D input (got {}D input)'
.format(input.dim()))
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
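A quick sanity check (a sketch, assuming a PyTorch build with `torch.allclose`): outside data-parallel training, `forward` above falls back to `F.batch_norm`, so in train mode the output matches the built-in layer on the same batch:

import torch
import torch.nn as nn

sbn = SynchronizedBatchNorm2d(3, affine=False)
bn = nn.BatchNorm2d(3, affine=False)
sbn.train()
bn.train()
x = torch.randn(4, 3, 8, 8)
print(torch.allclose(sbn(x), bn(x), atol=1e-6))  # True: both normalize with batch statistics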

@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
# File : comm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import queue
import collections
import threading
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
class FutureResult(object):
"""A thread-safe future implementation. Used only as one-to-one pipe."""
def __init__(self):
self._result = None
self._lock = threading.Lock()
self._cond = threading.Condition(self._lock)
def put(self, result):
with self._lock:
assert self._result is None, 'Previous result hasn\'t been fetched.'
self._result = result
self._cond.notify()
def get(self):
with self._lock:
if self._result is None:
self._cond.wait()
res = self._result
self._result = None
return res
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
"""Pipe for master-slave communication."""
def run_slave(self, msg):
self.queue.put((self.identifier, msg))
ret = self.result.get()
self.queue.put(True)
return ret
class SyncMaster(object):
"""An abstract `SyncMaster` object.
- During replication, as data parallel triggers a callback on each module, every slave device should
call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
- During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
and passed to the registered callback.
- After receiving the messages, the master device gathers the information and determines the message to be passed
back to each slave device.
"""
def __init__(self, master_callback):
"""
Args:
master_callback: a callback to be invoked after having collected messages from slave devices.
"""
self._master_callback = master_callback
self._queue = queue.Queue()
self._registry = collections.OrderedDict()
self._activated = False
def register_slave(self, identifier):
"""
Register a slave device.
Args:
identifier: an identifier, usually the device id.
Returns: a `SlavePipe` object which can be used to communicate with the master device.
"""
if self._activated:
assert self._queue.empty(), 'Queue is not clean before next initialization.'
self._activated = False
self._registry.clear()
future = FutureResult()
self._registry[identifier] = _MasterRegistry(future)
return SlavePipe(identifier, self._queue, future)
def run_master(self, master_msg):
"""
Main entry for the master device in each forward pass.
The messages are first collected from each device (including the master device), and then
the callback is invoked to compute the message to be sent back to each device
(including the master device).
Args:
master_msg: the message that the master wants to send to itself. This will be placed as the first
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
Returns: the message to be sent back to the master device.
"""
self._activated = True
intermediates = [(0, master_msg)]
for i in range(self.nr_slaves):
intermediates.append(self._queue.get())
results = self._master_callback(intermediates)
assert results[0][0] == 0, 'The first result should belong to the master.'
for i, res in results:
if i == 0:
continue
self._registry[i].result.put(res)
for i in range(self.nr_slaves):
assert self._queue.get() is True
return results[0][1]
@property
def nr_slaves(self):
return len(self._registry)
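The handshake can be exercised without any GPUs; a minimal sketch using plain threads, with a callback that just sums the collected messages:

import threading

def sum_callback(intermediates):
    # intermediates is [(identifier, message), ...] with the master's entry first
    total = sum(msg for _, msg in intermediates)
    return [(identifier, total) for identifier, _ in intermediates]

master = SyncMaster(sum_callback)
pipes = [master.register_slave(i) for i in (1, 2)]
threads = [threading.Thread(target=pipe.run_slave, args=(1,)) for pipe in pipes]
for t in threads:
    t.start()
print(master.run_master(1))  # 3: one message from the master plus one from each slave
for t in threads:
    t.join()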

@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# File : replicate.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import functools
from torch.nn.parallel.data_parallel import DataParallel
__all__ = [
'CallbackContext',
'execute_replication_callbacks',
'DataParallelWithCallback',
'patch_replication_callback'
]
class CallbackContext(object):
pass
def execute_replication_callbacks(modules):
"""
Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
Note that, as all module copies are isomorphic, we assign each sub-module a context
(shared among multiple copies of this module on different devices).
Through this context, different copies can share some information.
We guarantee that the callback on the master copy (the first copy) will be invoked before the callbacks
of any slave copies.
"""
master_copy = modules[0]
nr_modules = len(list(master_copy.modules()))
ctxs = [CallbackContext() for _ in range(nr_modules)]
for i, module in enumerate(modules):
for j, m in enumerate(module.modules()):
if hasattr(m, '__data_parallel_replicate__'):
m.__data_parallel_replicate__(ctxs[j], i)
class DataParallelWithCallback(DataParallel):
"""
Data Parallel with a replication callback.
A replication callback `__data_parallel_replicate__` on each module will be invoked after the replicas are created by
the original `replicate` function.
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
# sync_bn.__data_parallel_replicate__ will be invoked.
"""
def replicate(self, module, device_ids):
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
def patch_replication_callback(data_parallel):
"""
Monkey-patch an existing `DataParallel` object. Add the replication callback.
Useful when you have a customized `DataParallel` implementation.
Examples:
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
> patch_replication_callback(sync_bn)
# this is equivalent to
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
"""
assert isinstance(data_parallel, DataParallel)
old_replicate = data_parallel.replicate
@functools.wraps(old_replicate)
def new_replicate(module, device_ids):
modules = old_replicate(module, device_ids)
execute_replication_callbacks(modules)
return modules
data_parallel.replicate = new_replicate

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# File : test_numeric_batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
import unittest
import torch
import torch.nn as nn
from torch.autograd import Variable
from sync_batchnorm.unittest import TorchTestCase
def handy_var(a, unbias=True):
n = a.size(0)
asum = a.sum(dim=0)
as_sum = (a ** 2).sum(dim=0) # a square sum
sumvar = as_sum - asum * asum / n
if unbias:
return sumvar / (n - 1)
else:
return sumvar / n
class NumericTestCase(TorchTestCase):
def testNumericBatchNorm(self):
a = torch.rand(16, 10)
bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)  # 1d: the input below is 2D (16, 10)
bn.train()
a_var1 = Variable(a, requires_grad=True)
b_var1 = bn(a_var1)
loss1 = b_var1.sum()
loss1.backward()
a_var2 = Variable(a, requires_grad=True)
a_mean2 = a_var2.mean(dim=0, keepdim=True)
a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
# a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
b_var2 = (a_var2 - a_mean2) / a_std2
loss2 = b_var2.sum()
loss2.backward()
self.assertTensorClose(bn.running_mean, a.mean(dim=0))
self.assertTensorClose(bn.running_var, handy_var(a))
self.assertTensorClose(a_var1.data, a_var2.data)
self.assertTensorClose(b_var1.data, b_var2.data)
self.assertTensorClose(a_var1.grad, a_var2.grad)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
# File : test_sync_batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
import unittest
import torch
import torch.nn as nn
from torch.autograd import Variable
from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
from sync_batchnorm.unittest import TorchTestCase
def handy_var(a, unbias=True):
n = a.size(0)
asum = a.sum(dim=0)
as_sum = (a ** 2).sum(dim=0) # a square sum
sumvar = as_sum - asum * asum / n
if unbias:
return sumvar / (n - 1)
else:
return sumvar / n
def _find_bn(module):
for m in module.modules():
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
return m
class SyncTestCase(TorchTestCase):
def _syncParameters(self, bn1, bn2):
bn1.reset_parameters()
bn2.reset_parameters()
if bn1.affine and bn2.affine:
bn2.weight.data.copy_(bn1.weight.data)
bn2.bias.data.copy_(bn1.bias.data)
def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
"""Check the forward and backward for the customized batch normalization."""
bn1.train(mode=is_train)
bn2.train(mode=is_train)
if cuda:
input = input.cuda()
self._syncParameters(_find_bn(bn1), _find_bn(bn2))
input1 = Variable(input, requires_grad=True)
output1 = bn1(input1)
output1.sum().backward()
input2 = Variable(input, requires_grad=True)
output2 = bn2(input2)
output2.sum().backward()
self.assertTensorClose(input1.data, input2.data)
self.assertTensorClose(output1.data, output2.data)
self.assertTensorClose(input1.grad, input2.grad)
self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
def testSyncBatchNormNormalTrain(self):
bn = nn.BatchNorm1d(10)
sync_bn = SynchronizedBatchNorm1d(10)
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
def testSyncBatchNormNormalEval(self):
bn = nn.BatchNorm1d(10)
sync_bn = SynchronizedBatchNorm1d(10)
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
def testSyncBatchNormSyncTrain(self):
bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
bn.cuda()
sync_bn.cuda()
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
def testSyncBatchNormSyncEval(self):
bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
bn.cuda()
sync_bn.cuda()
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
def testSyncBatchNorm2DSyncTrain(self):
bn = nn.BatchNorm2d(10)
sync_bn = SynchronizedBatchNorm2d(10)
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
bn.cuda()
sync_bn.cuda()
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
if __name__ == '__main__':
unittest.main()

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# File : unittest.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.
import unittest
import numpy as np
from torch.autograd import Variable
def as_numpy(v):
if isinstance(v, Variable):
v = v.data
return v.cpu().numpy()
class TorchTestCase(unittest.TestCase):
def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
npa, npb = as_numpy(a), as_numpy(b)
self.assertTrue(
np.allclose(npa, npb, atol=atol),
'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
)
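A hypothetical test built on the helper; `assertTensorClose` compares with an absolute tolerance of 1e-3:

import torch

class SanityTest(TorchTestCase):
    def test_identity(self):
        self.assertTensorClose(torch.ones(4), torch.ones(4) + 1e-4)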

@ -0,0 +1 @@
from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to

@ -0,0 +1,112 @@
# -*- coding: utf8 -*-
import torch.cuda as cuda
import torch.nn as nn
import torch
import collections
from torch.nn.parallel._functions import Gather
__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
def async_copy_to(obj, dev, main_stream=None):
if torch.is_tensor(obj):
v = obj.cuda(dev, non_blocking=True)
if main_stream is not None:
v.data.record_stream(main_stream)
return v
elif isinstance(obj, collections.Mapping):
return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
elif isinstance(obj, collections.Sequence):
return [async_copy_to(o, dev, main_stream) for o in obj]
else:
return obj
def dict_gather(outputs, target_device, dim=0):
"""
Gathers variables from different GPUs on a specified device
(-1 means the CPU), with dictionary support.
"""
def gather_map(outputs):
out = outputs[0]
if torch.is_tensor(out):
# MJY(20180330) HACK:: force nr_dims > 0
if out.dim() == 0:
outputs = [o.unsqueeze(0) for o in outputs]
return Gather.apply(target_device, dim, *outputs)
elif out is None:
return None
elif isinstance(out, collections.Mapping):
return {k: gather_map([o[k] for o in outputs]) for k in out}
elif isinstance(out, collections.Sequence):
return type(out)(map(gather_map, zip(*outputs)))
return gather_map(outputs)
class DictGatherDataParallel(nn.DataParallel):
def gather(self, outputs, output_device):
return dict_gather(outputs, output_device, dim=self.dim)
class UserScatteredDataParallel(DictGatherDataParallel):
def scatter(self, inputs, kwargs, device_ids):
assert len(inputs) == 1
inputs = inputs[0]
inputs = _async_copy_stream(inputs, device_ids)
inputs = [[i] for i in inputs]
assert len(kwargs) == 0
kwargs = [{} for _ in range(len(inputs))]
return inputs, kwargs
def user_scattered_collate(batch):
return batch
def _async_copy(inputs, device_ids):
nr_devs = len(device_ids)
assert type(inputs) in (tuple, list)
assert len(inputs) == nr_devs
outputs = []
for i, dev in zip(inputs, device_ids):
with cuda.device(dev):
outputs.append(async_copy_to(i, dev))
return tuple(outputs)
def _async_copy_stream(inputs, device_ids):
nr_devs = len(device_ids)
assert type(inputs) in (tuple, list)
assert len(inputs) == nr_devs
outputs = []
streams = [_get_stream(d) for d in device_ids]
for i, dev, stream in zip(inputs, device_ids, streams):
with cuda.device(dev):
main_stream = cuda.current_stream()
with cuda.stream(stream):
outputs.append(async_copy_to(i, dev, main_stream=main_stream))
main_stream.wait_stream(stream)
return outputs
"""Adapted from: torch/nn/parallel/_functions.py"""
# background streams used for copying
_streams = None
def _get_stream(device):
"""Gets a background stream for copying between CPU and GPU"""
global _streams
if device == -1:
return None
if _streams is None:
_streams = [None] * cuda.device_count()
if _streams[device] is None: _streams[device] = cuda.Stream(device)
return _streams[device]
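A usage sketch (the `dataset` and `model` objects are assumed): `user_scattered_collate` keeps each batch as a plain list, and `scatter` above expects that list to already hold one element per device:

import torch.utils.data

loader = torch.utils.data.DataLoader(dataset, batch_size=2,
                                     collate_fn=user_scattered_collate)
net = UserScatteredDataParallel(model, device_ids=[0, 1])
for batch in loader:   # batch is a plain list: [sample_for_gpu0, sample_for_gpu1]
    out = net(batch)   # each replica's forward() receives one list element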

@ -0,0 +1,61 @@
import os
import torch
class BaseModel():
def name(self):
return 'BaseModel'
def initialize(self, opt):
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.Tensor = torch.cuda.FloatTensor if (torch.cuda.is_available() and not opt.cFlag) else torch.Tensor
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
def set_input(self, input):
self.input = input
def forward(self):
pass
# used in test time, no backprop
def test(self):
pass
def get_image_paths(self):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
return self.input
def get_current_errors(self):
return {}
def save(self, label):
pass
# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if len(gpu_ids) and torch.cuda.is_available():
network.cuda(device=gpu_ids[0])
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
if torch.cuda.is_available():
ckpt = torch.load(save_path)
else:
ckpt = torch.load(save_path, map_location=torch.device("cpu"))
ckpt = {key.replace("module.", ""): value for key, value in ckpt.items()}
network.load_state_dict(ckpt)
def update_learning_rate(self):
pass
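Checkpoints follow the `'%s_net_%s.pth'` naming convention under `opt.checkpoints_dir/opt.name`; a sketch of how a subclass uses the helpers (the `netG` attribute is hypothetical):

# inside a BaseModel subclass:
self.save_network(self.netG, 'G', '200', self.gpu_ids)  # -> <save_dir>/200_net_G.pth
self.load_network(self.netG, 'G', '200')                # strips any 'module.' prefix first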

@ -0,0 +1,38 @@
def create_model(opt):
model = None
print(opt.model)
if opt.model == 'cycle_gan':
assert(opt.dataset_mode == 'unaligned')
from .cycle_gan_model import CycleGANModel
model = CycleGANModel()
elif opt.model == 'pix2pix':
assert(opt.dataset_mode == 'pix2pix')
from .pix2pix_model import Pix2PixModel
model = Pix2PixModel()
elif opt.model == 'pair':
# assert(opt.dataset_mode == 'pair')
# from .pair_model import PairModel
from .Unet_L1 import PairModel
model = PairModel()
elif opt.model == 'single':
# assert(opt.dataset_mode == 'unaligned')
from .single_model import SingleModel
model = SingleModel()
elif opt.model == 'temp':
# assert(opt.dataset_mode == 'unaligned')
from .temp_model import TempModel
model = TempModel()
elif opt.model == 'UNIT':
assert(opt.dataset_mode == 'unaligned')
from .unit_model import UNITModel
model = UNITModel()
elif opt.model == 'test':
assert(opt.dataset_mode == 'single')
from .test_model import TestModel
model = TestModel()
else:
raise ValueError("Model [%s] not recognized." % opt.model)
model.initialize(opt)
print("model [%s] was created" % (model.name()))
return model
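Typical call site (a sketch; `opt` normally comes from the option parsers and carries the many extra fields consumed by `initialize`, such as `gpu_ids`, `isTrain`, and `checkpoints_dir`):

opt = parse_options()      # hypothetical helper returning the full option namespace
opt.model = 'single'
model = create_model(opt)  # prints 'single', then 'model [SingleGANModel] was created'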

File diff suppressed because it is too large

@ -0,0 +1,496 @@
import numpy as np
import torch
from torch import nn
import os
from collections import OrderedDict
from torch.autograd import Variable
import util.util as util
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
import random
from . import networks
import sys
class SingleModel(BaseModel):
def name(self):
return 'SingleGANModel'
def initialize(self, opt):
BaseModel.initialize(self, opt)
nb = opt.batchSize
size = opt.fineSize
self.opt = opt
self.input_A = self.Tensor(nb, opt.input_nc, size, size)
self.input_B = self.Tensor(nb, opt.output_nc, size, size)
self.input_img = self.Tensor(nb, opt.input_nc, size, size)
self.input_A_gray = self.Tensor(nb, 1, size, size)
if opt.vgg > 0:
self.vgg_loss = networks.PerceptualLoss(opt)
if self.opt.IN_vgg:
self.vgg_patch_loss = networks.PerceptualLoss(opt)
self.vgg_patch_loss.cuda()
self.vgg_loss.cuda()
self.vgg = networks.load_vgg16("./model", self.gpu_ids)
self.vgg.eval()
for param in self.vgg.parameters():
param.requires_grad = False
elif opt.fcn > 0:
self.fcn_loss = networks.SemanticLoss(opt)
self.fcn_loss.cuda()
self.fcn = networks.load_fcn("./model")
self.fcn.eval()
for param in self.fcn.parameters():
param.requires_grad = False
# load/define networks
# The naming convention is different from the one used in the paper
# Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
skip = True if opt.skip > 0 else False
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt)
# self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
# opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=False, opt=opt)
if self.isTrain:
use_sigmoid = opt.no_lsgan
self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids, False)
if self.opt.patchD:
self.netD_P = networks.define_D(opt.input_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_patchD, opt.norm, use_sigmoid, self.gpu_ids, True)
if not self.isTrain or opt.continue_train:
which_epoch = opt.which_epoch
self.load_network(self.netG_A, 'G_A', which_epoch)
# self.load_network(self.netG_B, 'G_B', which_epoch)
if self.isTrain:
self.load_network(self.netD_A, 'D_A', which_epoch)
if self.opt.patchD:
self.load_network(self.netD_P, 'D_P', which_epoch)
if self.isTrain:
self.old_lr = opt.lr
# self.fake_A_pool = ImagePool(opt.pool_size)
self.fake_B_pool = ImagePool(opt.pool_size)
# define loss functions
if opt.use_wgan:
self.criterionGAN = networks.DiscLossWGANGP()
else:
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
if opt.use_mse:
self.criterionCycle = torch.nn.MSELoss()
else:
self.criterionCycle = torch.nn.L1Loss()
self.criterionL1 = torch.nn.L1Loss()
self.criterionIdt = torch.nn.L1Loss()
# initialize optimizers
self.optimizer_G = torch.optim.Adam(self.netG_A.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
if self.opt.patchD:
self.optimizer_D_P = torch.optim.Adam(self.netD_P.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
print('---------- Networks initialized -------------')
networks.print_network(self.netG_A)
# networks.print_network(self.netG_B)
if self.isTrain:
networks.print_network(self.netD_A)
if self.opt.patchD:
networks.print_network(self.netD_P)
# networks.print_network(self.netD_B)
if opt.isTrain:
self.netG_A.train()
# self.netG_B.train()
else:
self.netG_A.eval()
# self.netG_B.eval()
print('-----------------------------------------------')
def set_input(self, input):
AtoB = self.opt.which_direction == 'AtoB'
input_A = input['A' if AtoB else 'B']
input_B = input['B' if AtoB else 'A']
input_img = input['input_img']
input_A_gray = input['A_gray']
self.input_A.resize_(input_A.size()).copy_(input_A)
self.input_A_gray.resize_(input_A_gray.size()).copy_(input_A_gray)
self.input_B.resize_(input_B.size()).copy_(input_B)
self.input_img.resize_(input_img.size()).copy_(input_img)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def test(self):
self.real_A = Variable(self.input_A, volatile=True)
self.real_A_gray = Variable(self.input_A_gray, volatile=True)
if self.opt.noise > 0:
self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.))
self.real_A = self.real_A + self.noise
if self.opt.input_linear:
self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A))
# print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:])
if self.opt.skip == 1:
self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray)
else:
self.fake_B = self.netG_A.forward(self.real_A, self.real_A_gray)
# self.rec_A = self.netG_B.forward(self.fake_B)
self.real_B = Variable(self.input_B, volatile=True)
def predict(self):
self.real_A = Variable(self.input_A, volatile=True)
self.real_A_gray = Variable(self.input_A_gray, volatile=True)
if self.opt.noise > 0:
self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.))
self.real_A = self.real_A + self.noise
if self.opt.input_linear:
self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A))
# print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:])
if self.opt.skip == 1:
self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray)
else:
self.fake_B = self.netG_A.forward(self.real_A, self.real_A_gray)
# self.rec_A = self.netG_B.forward(self.fake_B)
real_A = util.tensor2im(self.real_A.data)
fake_B = util.tensor2im(self.fake_B.data)
A_gray = util.atten2im(self.real_A_gray.data)
# rec_A = util.tensor2im(self.rec_A.data)
# if self.opt.skip == 1:
# latent_real_A = util.tensor2im(self.latent_real_A.data)
# latent_show = util.latent2im(self.latent_real_A.data)
# max_image = util.max2im(self.fake_B.data, self.latent_real_A.data)
# return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
# ('latent_show', latent_show), ('max_image', max_image), ('A_gray', A_gray)])
# else:
# return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
# return OrderedDict([('fake_B', fake_B)])
return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
# get image paths
def get_image_paths(self):
return self.image_paths
def backward_D_basic(self, netD, real, fake, use_ragan):
# Real
pred_real = netD.forward(real)
pred_fake = netD.forward(fake.detach())
if self.opt.use_wgan:
loss_D_real = pred_real.mean()
loss_D_fake = pred_fake.mean()
loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty(netD,
real.data, fake.data)
elif self.opt.use_ragan and use_ragan:
loss_D = (self.criterionGAN(pred_real - torch.mean(pred_fake), True) +
self.criterionGAN(pred_fake - torch.mean(pred_real), False)) / 2
else:
loss_D_real = self.criterionGAN(pred_real, True)
loss_D_fake = self.criterionGAN(pred_fake, False)
loss_D = (loss_D_real + loss_D_fake) * 0.5
# loss_D.backward()
return loss_D
def backward_D_A(self):
fake_B = self.fake_B_pool.query(self.fake_B)
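# note: the next line overrides the pooled sample, effectively bypassing the image pool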
fake_B = self.fake_B
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, True)
self.loss_D_A.backward()
def backward_D_P(self):
if self.opt.hybrid_loss:
loss_D_P = self.backward_D_basic(self.netD_P, self.real_patch, self.fake_patch, False)
if self.opt.patchD_3 > 0:
for i in range(self.opt.patchD_3):
loss_D_P += self.backward_D_basic(self.netD_P, self.real_patch_1[i], self.fake_patch_1[i], False)
self.loss_D_P = loss_D_P/float(self.opt.patchD_3 + 1)
else:
self.loss_D_P = loss_D_P
else:
loss_D_P = self.backward_D_basic(self.netD_P, self.real_patch, self.fake_patch, True)
if self.opt.patchD_3 > 0:
for i in range(self.opt.patchD_3):
loss_D_P += self.backward_D_basic(self.netD_P, self.real_patch_1[i], self.fake_patch_1[i], True)
self.loss_D_P = loss_D_P/float(self.opt.patchD_3 + 1)
else:
self.loss_D_P = loss_D_P
if self.opt.D_P_times2:
self.loss_D_P = self.loss_D_P*2
self.loss_D_P.backward()
# def backward_D_B(self):
# fake_A = self.fake_A_pool.query(self.fake_A)
# self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
def forward(self):
self.real_A = Variable(self.input_A)
self.real_B = Variable(self.input_B)
self.real_A_gray = Variable(self.input_A_gray)
self.real_img = Variable(self.input_img)
if self.opt.noise > 0:
self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.))
self.real_A = self.real_A + self.noise
if self.opt.input_linear:
self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A))
if self.opt.skip == 1:
self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_img, self.real_A_gray)
else:
self.fake_B = self.netG_A.forward(self.real_img, self.real_A_gray)
if self.opt.patchD:
w = self.real_A.size(3)
h = self.real_A.size(2)
w_offset = random.randint(0, max(0, w - self.opt.patchSize - 1))
h_offset = random.randint(0, max(0, h - self.opt.patchSize - 1))
self.fake_patch = self.fake_B[:,:, h_offset:h_offset + self.opt.patchSize,
w_offset:w_offset + self.opt.patchSize]
self.real_patch = self.real_B[:,:, h_offset:h_offset + self.opt.patchSize,
w_offset:w_offset + self.opt.patchSize]
self.input_patch = self.real_A[:,:, h_offset:h_offset + self.opt.patchSize,
w_offset:w_offset + self.opt.patchSize]
if self.opt.patchD_3 > 0:
self.fake_patch_1 = []
self.real_patch_1 = []
self.input_patch_1 = []
w = self.real_A.size(3)
h = self.real_A.size(2)
for i in range(self.opt.patchD_3):
w_offset_1 = random.randint(0, max(0, w - self.opt.patchSize - 1))
h_offset_1 = random.randint(0, max(0, h - self.opt.patchSize - 1))
self.fake_patch_1.append(self.fake_B[:,:, h_offset_1:h_offset_1 + self.opt.patchSize,
w_offset_1:w_offset_1 + self.opt.patchSize])
self.real_patch_1.append(self.real_B[:,:, h_offset_1:h_offset_1 + self.opt.patchSize,
w_offset_1:w_offset_1 + self.opt.patchSize])
self.input_patch_1.append(self.real_A[:,:, h_offset_1:h_offset_1 + self.opt.patchSize,
w_offset_1:w_offset_1 + self.opt.patchSize])
# w_offset_2 = random.randint(0, max(0, w - self.opt.patchSize - 1))
# h_offset_2 = random.randint(0, max(0, h - self.opt.patchSize - 1))
# self.fake_patch_2 = self.fake_B[:,:, h_offset_2:h_offset_2 + self.opt.patchSize,
# w_offset_2:w_offset_2 + self.opt.patchSize]
# self.real_patch_2 = self.real_B[:,:, h_offset_2:h_offset_2 + self.opt.patchSize,
# w_offset_2:w_offset_2 + self.opt.patchSize]
# self.input_patch_2 = self.real_A[:,:, h_offset_2:h_offset_2 + self.opt.patchSize,
# w_offset_2:w_offset_2 + self.opt.patchSize]
def backward_G(self, epoch):
pred_fake = self.netD_A.forward(self.fake_B)
if self.opt.use_wgan:
self.loss_G_A = -pred_fake.mean()
elif self.opt.use_ragan:
pred_real = self.netD_A.forward(self.real_B)
self.loss_G_A = (self.criterionGAN(pred_real - torch.mean(pred_fake), False) +
self.criterionGAN(pred_fake - torch.mean(pred_real), True)) / 2
else:
self.loss_G_A = self.criterionGAN(pred_fake, True)
loss_G_A = 0
if self.opt.patchD:
pred_fake_patch = self.netD_P.forward(self.fake_patch)
if self.opt.hybrid_loss:
loss_G_A += self.criterionGAN(pred_fake_patch, True)
else:
pred_real_patch = self.netD_P.forward(self.real_patch)
loss_G_A += (self.criterionGAN(pred_real_patch - torch.mean(pred_fake_patch), False) +
self.criterionGAN(pred_fake_patch - torch.mean(pred_real_patch), True)) / 2
if self.opt.patchD_3 > 0:
for i in range(self.opt.patchD_3):
pred_fake_patch_1 = self.netD_P.forward(self.fake_patch_1[i])
if self.opt.hybrid_loss:
loss_G_A += self.criterionGAN(pred_fake_patch_1, True)
else:
pred_real_patch_1 = self.netD_P.forward(self.real_patch_1[i])
loss_G_A += (self.criterionGAN(pred_real_patch_1 - torch.mean(pred_fake_patch_1), False) +
self.criterionGAN(pred_fake_patch_1 - torch.mean(pred_real_patch_1), True)) / 2
if not self.opt.D_P_times2:
self.loss_G_A += loss_G_A/float(self.opt.patchD_3 + 1)
else:
self.loss_G_A += loss_G_A/float(self.opt.patchD_3 + 1)*2
else:
if not self.opt.D_P_times2:
self.loss_G_A += loss_G_A
else:
self.loss_G_A += loss_G_A*2
if epoch < 0:
vgg_w = 0
else:
vgg_w = 1
if self.opt.vgg > 0:
self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(self.vgg,
self.fake_B, self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0
if self.opt.patch_vgg:
if not self.opt.IN_vgg:
loss_vgg_patch = self.vgg_loss.compute_vgg_loss(self.vgg,
self.fake_patch, self.input_patch) * self.opt.vgg
else:
loss_vgg_patch = self.vgg_patch_loss.compute_vgg_loss(self.vgg,
self.fake_patch, self.input_patch) * self.opt.vgg
if self.opt.patchD_3 > 0:
for i in range(self.opt.patchD_3):
if not self.opt.IN_vgg:
loss_vgg_patch += self.vgg_loss.compute_vgg_loss(self.vgg,
self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.vgg
else:
loss_vgg_patch += self.vgg_patch_loss.compute_vgg_loss(self.vgg,
self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.vgg
self.loss_vgg_b += loss_vgg_patch/float(self.opt.patchD_3 + 1)
else:
self.loss_vgg_b += loss_vgg_patch
self.loss_G = self.loss_G_A + self.loss_vgg_b*vgg_w
elif self.opt.fcn > 0:
self.loss_fcn_b = self.fcn_loss.compute_fcn_loss(self.fcn,
self.fake_B, self.real_A) * self.opt.fcn if self.opt.fcn > 0 else 0
if self.opt.patchD:
loss_fcn_patch = self.fcn_loss.compute_vgg_loss(self.fcn,
self.fake_patch, self.input_patch) * self.opt.fcn
if self.opt.patchD_3 > 0:
for i in range(self.opt.patchD_3):
loss_fcn_patch += self.fcn_loss.compute_vgg_loss(self.fcn,
self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.fcn
self.loss_fcn_b += loss_fcn_patch/float(self.opt.patchD_3 + 1)
else:
self.loss_fcn_b += loss_fcn_patch
self.loss_G = self.loss_G_A + self.loss_fcn_b*vgg_w
# self.loss_G = self.L1_AB + self.L1_BA
self.loss_G.backward()
# def optimize_parameters(self, epoch):
# # forward
# self.forward()
# # G_A and G_B
# self.optimizer_G.zero_grad()
# self.backward_G(epoch)
# self.optimizer_G.step()
# # D_A
# self.optimizer_D_A.zero_grad()
# self.backward_D_A()
# self.optimizer_D_A.step()
# if self.opt.patchD:
# self.forward()
# self.optimizer_D_P.zero_grad()
# self.backward_D_P()
# self.optimizer_D_P.step()
# D_B
# self.optimizer_D_B.zero_grad()
# self.backward_D_B()
# self.optimizer_D_B.step()
def optimize_parameters(self, epoch):
# forward
self.forward()
# G_A and G_B
self.optimizer_G.zero_grad()
self.backward_G(epoch)
self.optimizer_G.step()
# D_A
self.optimizer_D_A.zero_grad()
self.backward_D_A()
if not self.opt.patchD:
self.optimizer_D_A.step()
else:
# self.forward()
self.optimizer_D_P.zero_grad()
self.backward_D_P()
self.optimizer_D_A.step()
self.optimizer_D_P.step()
def get_current_errors(self, epoch):
D_A = self.loss_D_A.data[0]
D_P = self.loss_D_P.data[0] if self.opt.patchD else 0
G_A = self.loss_G_A.data[0]
if self.opt.vgg > 0:
vgg = self.loss_vgg_b.data[0]/self.opt.vgg if self.opt.vgg > 0 else 0
return OrderedDict([('D_A', D_A), ('G_A', G_A), ("vgg", vgg), ("D_P", D_P)])
elif self.opt.fcn > 0:
fcn = self.loss_fcn_b.data[0]/self.opt.fcn if self.opt.fcn > 0 else 0
return OrderedDict([('D_A', D_A), ('G_A', G_A), ("fcn", fcn), ("D_P", D_P)])
def get_current_visuals(self):
real_A = util.tensor2im(self.real_A.data)
fake_B = util.tensor2im(self.fake_B.data)
real_B = util.tensor2im(self.real_B.data)
if self.opt.skip > 0:
latent_real_A = util.tensor2im(self.latent_real_A.data)
latent_show = util.latent2im(self.latent_real_A.data)
if self.opt.patchD:
fake_patch = util.tensor2im(self.fake_patch.data)
real_patch = util.tensor2im(self.real_patch.data)
if self.opt.patch_vgg:
input_patch = util.tensor2im(self.input_patch.data)
if not self.opt.self_attention:
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch),
('fake_patch', fake_patch), ('input_patch', input_patch)])
else:
self_attention = util.atten2im(self.real_A_gray.data)
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch),
('fake_patch', fake_patch), ('input_patch', input_patch), ('self_attention', self_attention)])
else:
if not self.opt.self_attention:
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch),
('fake_patch', fake_patch)])
else:
self_attention = util.atten2im(self.real_A_gray.data)
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch),
('fake_patch', fake_patch), ('self_attention', self_attention)])
else:
if not self.opt.self_attention:
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
('latent_show', latent_show), ('real_B', real_B)])
else:
self_attention = util.atten2im(self.real_A_gray.data)
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B),
('latent_real_A', latent_real_A), ('latent_show', latent_show),
('self_attention', self_attention)])
else:
if not self.opt.self_attention:
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])
else:
self_attention = util.atten2im(self.real_A_gray.data)
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B),
('self_attention', self_attention)])
def save(self, label):
self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
if self.opt.patchD:
self.save_network(self.netD_P, 'D_P', label, self.gpu_ids)
# self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
# self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
def update_learning_rate(self):
if self.opt.new_lr:
lr = self.old_lr/2
else:
lrd = self.opt.lr / self.opt.niter_decay
lr = self.old_lr - lrd
for param_group in self.optimizer_D_A.param_groups:
param_group['lr'] = lr
if self.opt.patchD:
for param_group in self.optimizer_D_P.param_groups:
param_group['lr'] = lr
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = lr
print('update learning rate: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
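A condensed training-loop sketch for this model (`opt` and `dataset` are assumed to come from the option parser and data loader; `niter`/`niter_decay` follow the usual CycleGAN-style schedule):

model = SingleModel()
model.initialize(opt)
for epoch in range(1, opt.niter + opt.niter_decay + 1):
    for data in dataset:
        model.set_input(data)
        model.optimize_parameters(epoch)
    if epoch > opt.niter:
        model.update_learning_rate()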

@ -0,0 +1,32 @@
import random
import numpy as np
import torch
from torch.autograd import Variable
class ImagePool():
def __init__(self, pool_size):
self.pool_size = pool_size
if self.pool_size > 0:
self.num_imgs = 0
self.images = []
def query(self, images):
if self.pool_size == 0:
return images
return_images = []
for image in images.data:
image = torch.unsqueeze(image, 0)
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1
self.images.append(image)
return_images.append(image)
else:
p = random.uniform(0, 1)
if p > 0.5:
random_id = random.randint(0, self.pool_size-1)
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
else:
return_images.append(image)
return_images = Variable(torch.cat(return_images, 0))
return return_images
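Usage sketch, as in `backward_D_A` above: with `pool_size > 0`, `query` returns, with probability 0.5, a previously generated image in place of the current one, which stabilizes the discriminator by training it on a history of fakes:

pool = ImagePool(pool_size=50)
fake_for_D = pool.query(fake_B)  # fake_B: a (N, C, H, W) Variable of generated images
# feed fake_for_D (rather than fake_B) to the discriminator loss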

@ -0,0 +1,182 @@
# from __future__ import print_function
import numpy as np
from PIL import Image
import inspect, re
import torch
import os
import collections
from torch.optim import lr_scheduler
import torch.nn.init as init
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8):
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
image_numpy = np.maximum(image_numpy, 0)
image_numpy = np.minimum(image_numpy, 255)
return image_numpy.astype(imtype)
def atten2im(image_tensor, imtype=np.uint8):
image_tensor = image_tensor[0]
image_tensor = torch.cat((image_tensor, image_tensor, image_tensor), 0)
image_numpy = image_tensor.cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0
image_numpy = image_numpy/(image_numpy.max()/255.0)
return image_numpy.astype(imtype)
def latent2im(image_tensor, imtype=np.uint8):
# image_tensor = (image_tensor - torch.min(image_tensor))/(torch.max(image_tensor)-torch.min(image_tensor))
image_numpy = image_tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0
image_numpy = np.maximum(image_numpy, 0)
image_numpy = np.minimum(image_numpy, 255)
return image_numpy.astype(imtype)
def max2im(image_1, image_2, imtype=np.uint8):
image_1 = image_1[0].cpu().float().numpy()
image_2 = image_2[0].cpu().float().numpy()
image_1 = (np.transpose(image_1, (1, 2, 0)) + 1) / 2.0 * 255.0
image_2 = (np.transpose(image_2, (1, 2, 0))) * 255.0
output = np.maximum(image_1, image_2)
output = np.maximum(output, 0)
output = np.minimum(output, 255)
return output.astype(imtype)
def variable2im(image_tensor, imtype=np.uint8):
image_numpy = image_tensor[0].data.cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
return image_numpy.astype(imtype)
def diagnose_network(net, name='network'):
mean = 0.0
count = 0
for param in net.parameters():
if param.grad is not None:
mean += torch.mean(torch.abs(param.grad.data))
count += 1
if count > 0:
mean = mean / count
print(name)
print(mean)
def save_image(image_numpy, image_path):
image_pil = Image.fromarray(image_numpy)
image_pil.save(image_path)
def info(object, spacing=10, collapse=1):
"""Print methods and doc strings.
Takes module, class, list, dictionary, or string."""
methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)]
processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
print( "\n".join(["%s %s" %
(method.ljust(spacing),
processFunc(str(getattr(object, method).__doc__)))
for method in methodList]) )
def varname(p):
for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
if m:
return m.group(1)
def print_numpy(x, val=True, shp=False):
x = x.astype(np.float64)
if shp:
print('shape,', x.shape)
if val:
x = x.flatten()
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
def get_model_list(dirname, key):
    if not os.path.exists(dirname):
        return None
    gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
                  os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
    if len(gen_models) == 0:
        return None
    gen_models.sort()
    last_model_name = gen_models[-1]
    return last_model_name
def load_vgg16(model_dir):
""" Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """
if not os.path.exists(model_dir):
os.mkdir(model_dir)
if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')):
if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')):
os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7'))
vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7'))
vgg = Vgg16()
for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
dst.data[:] = src
torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight'))
vgg = Vgg16()
vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight')))
return vgg
def vgg_preprocess(batch):
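    # Caffe-style VGG preprocessing: swap RGB to BGR, rescale [-1, 1] to
    # [0, 255], then subtract the ImageNet channel means (BGR order)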
tensortype = type(batch.data)
(r, g, b) = torch.chunk(batch, 3, dim = 1)
batch = torch.cat((b, g, r), dim = 1) # convert RGB to BGR
batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255]
mean = tensortype(batch.data.size())
mean[:, 0, :, :] = 103.939
mean[:, 1, :, :] = 116.779
mean[:, 2, :, :] = 123.680
batch = batch.sub(Variable(mean)) # subtract mean
return batch
def get_scheduler(optimizer, hyperparameters, iterations=-1):
if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
scheduler = None # constant scheduler
elif hyperparameters['lr_policy'] == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'],
gamma=hyperparameters['gamma'], last_epoch=iterations)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % hyperparameters['lr_policy'])
return scheduler
def weights_init(init_type='gaussian'):
def init_fun(m):
classname = m.__class__.__name__
if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
# print m.__class__.__name__
if init_type == 'gaussian':
init.normal(m.weight.data, 0.0, 0.02)
elif init_type == 'xavier':
init.xavier_normal(m.weight.data, gain=math.sqrt(2))
elif init_type == 'kaiming':
init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal(m.weight.data, gain=math.sqrt(2))
elif init_type == 'default':
pass
else:
assert 0, "Unsupported initialization: {}".format(init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant(m.bias.data, 0.0)
return init_fun
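
# Usage sketch (illustrative only; netG and input_var are assumed names):
#
#   netG.apply(weights_init('gaussian'))      # init_fun runs on every submodule
#   img = tensor2im(netG(input_var).data)     # [-1, 1] tensor -> uint8 HxWxC
#   save_image(img, 'out.png')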

@ -0,0 +1,107 @@
import os
baseLoc = os.path.dirname(os.path.realpath(__file__)) + '/'
from gimpfu import *
import sys
sys.path.extend([baseLoc + 'gimpenv/lib/python2.7', baseLoc + 'gimpenv/lib/python2.7/site-packages',
baseLoc + 'gimpenv/lib/python2.7/site-packages/setuptools', baseLoc + 'EnlightenGAN'])
from argparse import Namespace
import cv2
import numpy as np
import torch
from models.models import create_model
from data.base_dataset import get_transform
def getEnlighten(input_image,cFlag):
opt = Namespace(D_P_times2=False, IN_vgg=False, aspect_ratio=1.0, batchSize=1,
checkpoints_dir=baseLoc+'weights/', dataroot='test_dataset',
dataset_mode='unaligned', display_id=1, display_port=8097,
display_single_pane_ncols=0, display_winsize=256, fcn=0,
fineSize=256, gpu_ids=[0], high_times=400, how_many=50,
hybrid_loss=False, identity=0.0, input_linear=False, input_nc=3,
instance_norm=0.0, isTrain=False, l1=10.0, lambda_A=10.0,
lambda_B=10.0, latent_norm=False, latent_threshold=False,
lighten=False, linear=False, linear_add=False, loadSize=286,
low_times=200, max_dataset_size='inf', model='single',
multiply=False, nThreads=1, n_layers_D=3, n_layers_patchD=3,
name='enlightening', ndf=64, new_lr=False, ngf=64, no_dropout=True,
no_flip=True, no_vgg_instance=False, noise=0, norm='instance',
norm_attention=False, ntest='inf', output_nc=3, patchD=False,
patchD_3=0, patchSize=64, patch_vgg=False, phase='test',
resize_or_crop='no', results_dir='./results/', self_attention=True,
serial_batches=True, skip=1.0, syn_norm=False, tanh=False,
times_residual=True, use_avgpool=0, use_mse=False, use_norm=1.0,
use_ragan=False, use_wgan=0.0, vary=1, vgg=0, vgg_choose='relu5_3',
vgg_maxpooling=False, vgg_mean=False, which_direction='AtoB',
which_epoch='200', which_model_netD='basic', which_model_netG='sid_unet_resize', cFlag=cFlag)
im = cv2.cvtColor(input_image,cv2.COLOR_RGB2BGR)
transform = get_transform(opt)
A_img = transform(im)
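    # attention input: shift channels from [-1, 1] to [0, 2], estimate
    # luminance, and invert so dark pixels receive the highest weights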
r, g, b = A_img[0] + 1, A_img[1] + 1, A_img[2] + 1
A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
A_gray = torch.unsqueeze(A_gray, 0)
data = {'A': A_img.unsqueeze(0), 'B': A_img.unsqueeze(0), 'A_gray': A_gray.unsqueeze(0), 'input_img': A_img.unsqueeze(0), 'A_paths': 'A_path', 'B_paths': 'B_path'}
model = create_model(opt)
model.set_input(data)
visuals = model.predict()
out = visuals['fake_B'].astype(np.uint8)
out = cv2.cvtColor(out,cv2.COLOR_BGR2RGB)
# cv2.imwrite("/Users/kritiksoman/PycharmProjects/new/out.png", out)
return out
def channelData(layer): # convert gimp image to numpy
region = layer.get_pixel_rgn(0, 0, layer.width, layer.height)
pixChars = region[:, :] # Take whole layer
bpp = region.bpp
# return np.frombuffer(pixChars,dtype=np.uint8).reshape(len(pixChars)/bpp,bpp)
return np.frombuffer(pixChars, dtype=np.uint8).reshape(layer.height, layer.width, bpp)
def createResultLayer(image, name, result):
    rlBytes = np.uint8(result).tobytes()
rl = gimp.Layer(image, name, image.width, image.height, 0, 100, NORMAL_MODE)
region = rl.get_pixel_rgn(0, 0, rl.width, rl.height, True)
region[:, :] = rlBytes
image.add_layer(rl, 0)
gimp.displays_flush()
def Enlighten(img, layer,cFlag):
if torch.cuda.is_available() and not cFlag:
gimp.progress_init("(Using GPU) Enlighten " + layer.name + "...")
else:
gimp.progress_init("(Using CPU) Enlighten " + layer.name + "...")
imgmat = channelData(layer)
if imgmat.shape[2] == 4: # get rid of alpha channel
imgmat = imgmat[:,:,0:3]
cpy = getEnlighten(imgmat,cFlag)
createResultLayer(img, 'new_output', cpy)
register(
"enlighten",
"enlighten",
"Enlighten image based on deep learning.",
"Kritik Soman",
"Your",
"2020",
"enlighten...",
"*", # Alternately use RGB, RGB*, GRAY*, INDEXED etc.
[(PF_IMAGE, "image", "Input image", None),
(PF_DRAWABLE, "drawable", "Input drawable", None),
(PF_BOOL, "fcpu", "Force CPU", False)
],
[],
    Enlighten, menu="<Image>/Layer/GIMP-ML")
main()
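
# Usage sketch (illustrative only; assumes this module can be imported with
# the gimpfu-specific lines stubbed out, e.g. from GIMP's Python console):
#
#   import cv2
#   rgb = cv2.cvtColor(cv2.imread('dark.jpg'), cv2.COLOR_BGR2RGB)
#   out = getEnlighten(rgb, cFlag=True)   # cFlag=True forces CPU inference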

@ -203,5 +203,16 @@ def sync(path,flag):
        gimp.progress_init("Downloading " + model + " (~" + str(fileSize) + "MB)...")
download_file_from_google_drive(file_id, destination,fileSize)
#enlighten
model = 'enlightening'
file_id = '1V8ARc2tDgUUpc11xiT5Y9HFQgC6Ug2T6'
fileSize = 0.035 #in MB
mFName = '200_net_G_A.pth'
if not os.path.isdir(path + '/' + model):
os.mkdir(path + '/' + model)
destination = path + '/' + model + '/' + mFName
    if not os.path.isfile(destination):
        gimp.progress_init("Downloading " + model + " (~" + str(fileSize) + "MB)...")
        download_file_from_google_drive(file_id, destination, fileSize)
