mirror of
https://github.com/kritiksoman/GIMP-ML
synced 2024-11-06 03:20:34 +00:00
43 lines
1.4 KiB
Python
Executable File
43 lines
1.4 KiB
Python
Executable File
from collections import OrderedDict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from .bn import ABN
|
|
|
|
|
|
class DenseModule(nn.Module):
    """Densely connected stack of bottleneck convolution layers.

    Every layer receives the concatenation along ``dim=1`` of the module input
    and the outputs of all earlier layers. It applies a 1x1 convolution with
    ``growth * bottleneck_factor`` output channels, then a 3x3 (optionally
    dilated) convolution with ``growth`` output channels. Each convolution is
    preceded by a ``norm_act`` module. The result is the input concatenated
    with every layer output, i.e. ``in_channels + growth * layers`` channels.

    Args:
        in_channels: Channel count of the tensor passed to ``forward``.
        growth: Channels emitted by each layer.
        layers: Number of stacked layers.
        bottleneck_factor: Multiplier giving the width of the 1x1 bottleneck
            (``growth * bottleneck_factor`` channels).
        norm_act: Factory called with a channel count. The module it returns
            runs before each convolution. Defaults to ``ABN`` (presumably
            normalization + activation; confirm in ``.bn``).
        dilation: Dilation of the 3x3 convolutions. Padding is set to the same
            value, which keeps the spatial size unchanged at stride 1.
    """

    def __init__(self, in_channels, growth, layers, bottleneck_factor=4,
                 norm_act=ABN, dilation=1):
        super().__init__()
        self.in_channels = in_channels
        self.growth = growth
        self.layers = layers

        width = growth * bottleneck_factor

        # Registration order (convs1, then convs3) fixes the state_dict key order.
        self.convs1 = nn.ModuleList()
        self.convs3 = nn.ModuleList()

        for depth in range(layers):
            # Layer `depth` sees the module input plus `depth` earlier outputs.
            layer_in = in_channels + depth * growth

            # Build the 1x1 branch before the 3x3 branch for each layer. This
            # keeps parameter initialisation (RNG consumption) in the same order.
            reduce_branch = nn.Sequential(OrderedDict([
                ("bn", norm_act(layer_in)),
                ("conv", nn.Conv2d(layer_in, width, 1, bias=False)),
            ]))
            self.convs1.append(reduce_branch)

            expand_branch = nn.Sequential(OrderedDict([
                ("bn", norm_act(width)),
                ("conv", nn.Conv2d(width, growth, 3, padding=dilation,
                                   bias=False, dilation=dilation)),
            ]))
            self.convs3.append(expand_branch)

    @property
    def out_channels(self):
        """Channel count of the ``forward`` output: input plus all layer outputs."""
        return self.in_channels + self.layers * self.growth

    def forward(self, x):
        """Run the dense stack on ``x`` and return all features concatenated on ``dim=1``.

        ``x`` is passed to ``nn.Conv2d``, so it is expected to be laid out as
        (N, C, H, W) with ``C == self.in_channels``.
        """
        features = [x]
        for depth in range(self.layers):
            stacked = torch.cat(features, dim=1)
            bottleneck = self.convs1[depth](stacked)
            features.append(self.convs3[depth](bottleneck))
        return torch.cat(features, dim=1)
|