"""Generator, discriminator and residual-block modules for 4x image upscaling.

NOTE(review): the layout (residual trunk + pixel-shuffle upsampling, and a
strided-conv discriminator) looks like an SRGAN-style pair -- confirm against
the training script that uses these classes.
"""
import math

import torch
import torch.nn as nn


class _Residual_Block(nn.Module):
    """Two 3x3 conv + instance-norm stages wrapped in an identity skip.

    Operates on 64-channel feature maps and preserves the spatial size
    (stride 1, padding 1 on both convolutions).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.in1 = nn.InstanceNorm2d(64, affine=True)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.in2 = nn.InstanceNorm2d(64, affine=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + in2(conv2(lrelu(in1(conv1(x)))))``."""
        identity_data = x
        output = self.relu(self.in1(self.conv1(x)))
        output = self.in2(self.conv2(output))
        # No activation after the second norm; the skip is a plain sum.
        return torch.add(output, identity_data)


class _NetG(nn.Module):
    """Generator: 3-channel image in, 3-channel image out at 4x the resolution.

    Pipeline: 9x9 input conv -> 16 residual blocks -> 3x3 conv + instance norm
    -> long skip from the input conv -> two (conv, PixelShuffle(2), LeakyReLU)
    upsampling stages -> 9x9 output conv. All convolutions are bias-free and
    size-preserving, so the only change in spatial size comes from the two
    pixel-shuffle stages (2 * 2 = 4x).
    """

    def __init__(self):
        super().__init__()

        self.conv_input = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9,
                                    stride=1, padding=4, bias=False)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

        self.residual = self.make_layer(_Residual_Block, 16)

        self.conv_mid = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3,
                                  stride=1, padding=1, bias=False)
        # Named "bn_mid" but it is an instance norm; the attribute name is kept
        # so existing checkpoints keep loading.
        self.bn_mid = nn.InstanceNorm2d(64, affine=True)

        # Each stage expands 64 -> 256 channels, then PixelShuffle(2) folds the
        # extra channels into a 2x larger 64-channel map.
        self.upscale4x = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.conv_output = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=9,
                                     stride=1, padding=4, bias=False)

        # Re-initialise every convolution with a zero-mean normal whose std is
        # sqrt(2 / (k_h * k_w * out_channels)). Norm layers are not touched by
        # this loop and keep their PyTorch defaults.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / n))
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def make_layer(self, block, num_of_layer):
        """Stack ``num_of_layer`` fresh instances of ``block`` sequentially.

        Args:
            block: zero-argument callable returning an ``nn.Module``.
            num_of_layer: number of instances to create.

        Returns:
            An ``nn.Sequential`` holding the instances in creation order.
        """
        return nn.Sequential(*(block() for _ in range(num_of_layer)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upscale ``x`` of shape (N, 3, H, W) to (N, 3, 4H, 4W)."""
        out = self.relu(self.conv_input(x))
        residual = out  # long skip around the whole residual trunk
        out = self.residual(out)
        out = self.bn_mid(self.conv_mid(out))
        out = torch.add(out, residual)
        out = self.upscale4x(out)
        return self.conv_output(out)


class _NetD(nn.Module):
    """Discriminator: maps a 3x96x96 image to a single probability.

    Eight convolutions alternate between stride 1 (channel growth) and
    stride 2 (spatial halving), followed by two fully connected layers and a
    sigmoid. ``fc1`` is sized for a 512x6x6 feature map, which the four
    stride-2 convolutions produce only for 96x96 inputs; other input sizes
    fail at ``fc1`` with a shape mismatch.
    """

    def __init__(self):
        super().__init__()

        self.features = nn.Sequential(
            # input is (3) x 96 x 96
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (64) x 96 x 96
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (64) x 48 x 48
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (128) x 48 x 48
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (128) x 24 x 24
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (256) x 24 x 24
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (256) x 12 x 12
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3,
                      stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            # state size. (512) x 12 x 12
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4,
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (512) x 6 x 6
        )

        self.LeakyReLU = nn.LeakyReLU(0.2, inplace=True)
        self.fc1 = nn.Linear(512 * 6 * 6, 1024)
        self.fc2 = nn.Linear(1024, 1)
        self.sigmoid = nn.Sigmoid()

        # Conv weights ~ N(0, 0.02); batch-norm scale ~ N(1, 0.02) with zero
        # shift. The Linear layers are not touched by this loop.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, mean=0.0, std=0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, mean=1.0, std=0.02)
                nn.init.zeros_(m.bias)

    # The parameter is called ``input`` (shadowing the builtin) in the original
    # interface; the name is kept so keyword calls keep working.
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return one sigmoid score per sample as a 1-D tensor of length N."""
        out = self.features(input)

        # state size. (512) x 6 x 6
        out = out.view(out.size(0), -1)

        # state size. (512 x 6 x 6)
        out = self.fc1(out)

        # state size. (1024)
        out = self.LeakyReLU(out)

        out = self.fc2(out)
        out = self.sigmoid(out)
        return out.view(-1, 1).squeeze(1)