import torch
import torch.nn as nn
import torch.nn.functional as functional

try:
    from queue import Queue
except ImportError:
    from Queue import Queue

from .functions import *


class ABN(nn.Module):
    """Activated Batch Normalization

    This gathers a `BatchNorm2d` and an activation function in a single module
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Creates an Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor applied to compute running statistics.
        affine : bool
            If `True`, apply a learned scale and shift transformation after normalization.
        activation : str
            Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(ABN, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.activation = activation
        self.slope = slope
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.running_mean, 0)
        nn.init.constant_(self.running_var, 1)
        if self.affine:
            nn.init.constant_(self.weight, 1)
            nn.init.constant_(self.bias, 0)

    def forward(self, x):
        x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                                  self.training, self.momentum, self.eps)

        if self.activation == ACT_RELU:
            return functional.relu(x, inplace=True)
        elif self.activation == ACT_LEAKY_RELU:
            return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
        elif self.activation == ACT_ELU:
            return functional.elu(x, inplace=True)
        else:
            return x

    def __repr__(self):
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, activation={activation}'
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        return rep.format(name=self.__class__.__name__, **self.__dict__)
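
# Usage sketch (an added illustration, not part of the original file):
# `ABN` behaves like an `nn.BatchNorm2d` immediately followed by the
# configured activation, so it can replace a BatchNorm2d + activation pair:
#
#     abn = ABN(64, activation="leaky_relu", slope=0.01)
#     out = abn(torch.randn(8, 64, 32, 32))  # output shape equals input shape

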
class InPlaceABN(ABN):
    """InPlace Activated Batch Normalization"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Creates an InPlace Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor applied to compute running statistics.
        affine : bool
            If `True`, apply a learned scale and shift transformation after normalization.
        activation : str
            Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)

    def forward(self, x):
        return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
                           self.training, self.momentum, self.eps, self.activation, self.slope)
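
# Usage sketch (an added illustration, not part of the original file):
# `InPlaceABN` is a drop-in replacement for `ABN` that fuses normalization and
# activation in-place via the `inplace_abn` function imported from
# `.functions`; that function is assumed to come from the package's compiled
# CUDA/C++ extension, which trades recomputation for activation memory:
#
#     iabn = InPlaceABN(64, activation="leaky_relu")
#     out = iabn(features)  # overwrites the input's storage to save memory

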
class InPlaceABNSync(ABN):
    """InPlace Activated Batch Normalization with cross-GPU synchronization

    This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`.
    """

    def forward(self, x):
        return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
                                self.training, self.momentum, self.eps, self.activation, self.slope)

    def __repr__(self):
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, activation={activation}'
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        return rep.format(name=self.__class__.__name__, **self.__dict__)
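
# Usage sketch (an added illustration, not part of the original file):
# Per its docstring, `InPlaceABNSync` differs from `InPlaceABN` in that batch
# statistics are synchronized across replicas, so it is meant to live inside a
# model wrapped with `torch.nn.parallel.DistributedDataParallel`, e.g.:
#
#     model = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), InPlaceABNSync(64))
#     model = nn.parallel.DistributedDataParallel(model.cuda())

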
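if __name__ == "__main__":
    # Minimal smoke test (an added illustration, not part of the original
    # file). It exercises only the pure-PyTorch `ABN` path; note that the
    # relative `.functions` import above means this module must be run in
    # package context (`python -m ...`), not as a standalone script, and the
    # in-place variants additionally need the compiled extension.
    abn = ABN(16, activation="leaky_relu", slope=0.01)
    abn.eval()
    out = abn(torch.randn(2, 16, 8, 8))
    assert out.shape == (2, 16, 8, 8)  # normalization preserves the shape
    print(abn)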