mirror of
https://github.com/kritiksoman/GIMP-ML
synced 2024-11-06 03:20:34 +00:00
235 lines
7.5 KiB
Python
Executable File
235 lines
7.5 KiB
Python
Executable File
from os import path
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.autograd as autograd
|
|
import torch.cuda.comm as comm
|
|
from torch.autograd.function import once_differentiable
|
|
from torch.utils.cpp_extension import load
|
|
|
|
_src_path = path.join(path.dirname(path.abspath(__file__)), "src")
|
|
_backend = load(name="inplace_abn",
|
|
extra_cflags=["-O3"],
|
|
sources=[path.join(_src_path, f) for f in [
|
|
"inplace_abn.cpp",
|
|
"inplace_abn_cpu.cpp",
|
|
"inplace_abn_cuda.cu",
|
|
"inplace_abn_cuda_half.cu"
|
|
]],
|
|
extra_cuda_cflags=["--expt-extended-lambda"])
|
|
|
|
# Activation names
|
|
ACT_RELU = "relu"
|
|
ACT_LEAKY_RELU = "leaky_relu"
|
|
ACT_ELU = "elu"
|
|
ACT_NONE = "none"
|
|
|
|
|
|
def _check(fn, *args, **kwargs):
|
|
success = fn(*args, **kwargs)
|
|
if not success:
|
|
raise RuntimeError("CUDA Error encountered in {}".format(fn))
|
|
|
|
|
|
def _broadcast_shape(x):
|
|
out_size = []
|
|
for i, s in enumerate(x.size()):
|
|
if i != 1:
|
|
out_size.append(1)
|
|
else:
|
|
out_size.append(s)
|
|
return out_size
|
|
|
|
|
|
def _reduce(x):
|
|
if len(x.size()) == 2:
|
|
return x.sum(dim=0)
|
|
else:
|
|
n, c = x.size()[0:2]
|
|
return x.contiguous().view((n, c, -1)).sum(2).sum(0)
|
|
|
|
|
|
def _count_samples(x):
|
|
count = 1
|
|
for i, s in enumerate(x.size()):
|
|
if i != 1:
|
|
count *= s
|
|
return count
|
|
|
|
|
|
def _act_forward(ctx, x):
|
|
if ctx.activation == ACT_LEAKY_RELU:
|
|
_backend.leaky_relu_forward(x, ctx.slope)
|
|
elif ctx.activation == ACT_ELU:
|
|
_backend.elu_forward(x)
|
|
elif ctx.activation == ACT_NONE:
|
|
pass
|
|
|
|
|
|
def _act_backward(ctx, x, dx):
|
|
if ctx.activation == ACT_LEAKY_RELU:
|
|
_backend.leaky_relu_backward(x, dx, ctx.slope)
|
|
elif ctx.activation == ACT_ELU:
|
|
_backend.elu_backward(x, dx)
|
|
elif ctx.activation == ACT_NONE:
|
|
pass
|
|
|
|
|
|
class InPlaceABN(autograd.Function):
|
|
@staticmethod
|
|
def forward(ctx, x, weight, bias, running_mean, running_var,
|
|
training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
|
|
# Save context
|
|
ctx.training = training
|
|
ctx.momentum = momentum
|
|
ctx.eps = eps
|
|
ctx.activation = activation
|
|
ctx.slope = slope
|
|
ctx.affine = weight is not None and bias is not None
|
|
|
|
# Prepare inputs
|
|
count = _count_samples(x)
|
|
x = x.contiguous()
|
|
weight = weight.contiguous() if ctx.affine else x.new_empty(0)
|
|
bias = bias.contiguous() if ctx.affine else x.new_empty(0)
|
|
|
|
if ctx.training:
|
|
mean, var = _backend.mean_var(x)
|
|
|
|
# Update running stats
|
|
running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
|
|
running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
|
|
|
|
# Mark in-place modified tensors
|
|
ctx.mark_dirty(x, running_mean, running_var)
|
|
else:
|
|
mean, var = running_mean.contiguous(), running_var.contiguous()
|
|
ctx.mark_dirty(x)
|
|
|
|
# BN forward + activation
|
|
_backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
|
|
_act_forward(ctx, x)
|
|
|
|
# Output
|
|
ctx.var = var
|
|
ctx.save_for_backward(x, var, weight, bias)
|
|
return x
|
|
|
|
@staticmethod
|
|
@once_differentiable
|
|
def backward(ctx, dz):
|
|
z, var, weight, bias = ctx.saved_tensors
|
|
dz = dz.contiguous()
|
|
|
|
# Undo activation
|
|
_act_backward(ctx, z, dz)
|
|
|
|
if ctx.training:
|
|
edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
|
|
else:
|
|
# TODO: implement simplified CUDA backward for inference mode
|
|
edz = dz.new_zeros(dz.size(1))
|
|
eydz = dz.new_zeros(dz.size(1))
|
|
|
|
dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
|
|
dweight = eydz * weight.sign() if ctx.affine else None
|
|
dbias = edz if ctx.affine else None
|
|
|
|
return dx, dweight, dbias, None, None, None, None, None, None, None
|
|
|
|
class InPlaceABNSync(autograd.Function):
|
|
@classmethod
|
|
def forward(cls, ctx, x, weight, bias, running_mean, running_var,
|
|
training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True):
|
|
# Save context
|
|
ctx.training = training
|
|
ctx.momentum = momentum
|
|
ctx.eps = eps
|
|
ctx.activation = activation
|
|
ctx.slope = slope
|
|
ctx.affine = weight is not None and bias is not None
|
|
|
|
# Prepare inputs
|
|
ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
|
|
|
#count = _count_samples(x)
|
|
batch_size = x.new_tensor([x.shape[0]],dtype=torch.long)
|
|
|
|
x = x.contiguous()
|
|
weight = weight.contiguous() if ctx.affine else x.new_empty(0)
|
|
bias = bias.contiguous() if ctx.affine else x.new_empty(0)
|
|
|
|
if ctx.training:
|
|
mean, var = _backend.mean_var(x)
|
|
if ctx.world_size>1:
|
|
# get global batch size
|
|
if equal_batches:
|
|
batch_size *= ctx.world_size
|
|
else:
|
|
dist.all_reduce(batch_size, dist.ReduceOp.SUM)
|
|
|
|
ctx.factor = x.shape[0]/float(batch_size.item())
|
|
|
|
mean_all = mean.clone() * ctx.factor
|
|
dist.all_reduce(mean_all, dist.ReduceOp.SUM)
|
|
|
|
var_all = (var + (mean - mean_all) ** 2) * ctx.factor
|
|
dist.all_reduce(var_all, dist.ReduceOp.SUM)
|
|
|
|
mean = mean_all
|
|
var = var_all
|
|
|
|
# Update running stats
|
|
running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
|
|
count = batch_size.item() * x.view(x.shape[0],x.shape[1],-1).shape[-1]
|
|
running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1)))
|
|
|
|
# Mark in-place modified tensors
|
|
ctx.mark_dirty(x, running_mean, running_var)
|
|
else:
|
|
mean, var = running_mean.contiguous(), running_var.contiguous()
|
|
ctx.mark_dirty(x)
|
|
|
|
# BN forward + activation
|
|
_backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
|
|
_act_forward(ctx, x)
|
|
|
|
# Output
|
|
ctx.var = var
|
|
ctx.save_for_backward(x, var, weight, bias)
|
|
return x
|
|
|
|
@staticmethod
|
|
@once_differentiable
|
|
def backward(ctx, dz):
|
|
z, var, weight, bias = ctx.saved_tensors
|
|
dz = dz.contiguous()
|
|
|
|
# Undo activation
|
|
_act_backward(ctx, z, dz)
|
|
|
|
if ctx.training:
|
|
edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
|
|
edz_local = edz.clone()
|
|
eydz_local = eydz.clone()
|
|
|
|
if ctx.world_size>1:
|
|
edz *= ctx.factor
|
|
dist.all_reduce(edz, dist.ReduceOp.SUM)
|
|
|
|
eydz *= ctx.factor
|
|
dist.all_reduce(eydz, dist.ReduceOp.SUM)
|
|
else:
|
|
edz_local = edz = dz.new_zeros(dz.size(1))
|
|
eydz_local = eydz = dz.new_zeros(dz.size(1))
|
|
|
|
dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
|
|
dweight = eydz_local * weight.sign() if ctx.affine else None
|
|
dbias = edz_local if ctx.affine else None
|
|
|
|
return dx, dweight, dbias, None, None, None, None, None, None, None
|
|
|
|
inplace_abn = InPlaceABN.apply
|
|
inplace_abn_sync = InPlaceABNSync.apply
|
|
|
|
__all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
|