You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

141 lines
5.4 KiB
Python

import torch
import torch.nn as nn
import math
class _Residual_Block(nn.Module):
    """Residual block: conv -> InstanceNorm -> LeakyReLU -> conv -> InstanceNorm, plus identity skip.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the channel
    count, so the output has exactly the same shape as the input (required for
    the element-wise skip addition).

    Args:
        channels: number of input/output feature channels. Defaults to 64, the
            width the original network was built with; parameter names (and
            therefore state_dict keys) do not depend on this value.
    """

    def __init__(self, channels=64):
        super(_Residual_Block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        # inplace is safe here: it only overwrites the fresh output of in1.
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

    def forward(self, x):
        """Return ``x + in2(conv2(relu(in1(conv1(x)))))``; same shape as ``x``."""
        identity_data = x
        output = self.relu(self.in1(self.conv1(x)))
        output = self.in2(self.conv2(output))
        # Out-of-place add: ``x`` itself is never modified.
        output = torch.add(output, identity_data)
        return output
class _NetG(nn.Module):
    """Generator that upsamples a 3-channel image by a factor of 4.

    Pipeline:
      1. 9x9 conv (3 -> 64 channels) + LeakyReLU(0.2)
      2. 16 stacked ``_Residual_Block`` modules
      3. 3x3 conv + InstanceNorm, added back onto the output of step 1
      4. two stages of (3x3 conv 64 -> 256, PixelShuffle(2), LeakyReLU(0.2)),
         each doubling height and width
      5. 9x9 conv back to 3 channels

    All convolutions are bias-free. After construction every Conv2d weight is
    re-drawn from N(0, sqrt(2 / fan_out)), fan_out = kh * kw * out_channels.
    """

    def __init__(self):
        super().__init__()
        width = 64  # feature channels used throughout the trunk

        # NOTE: attribute assignment order is significant — it fixes both the
        # state_dict keys and the order in which modules are constructed/initialized.
        self.conv_input = nn.Conv2d(3, width, kernel_size=9, stride=1, padding=4, bias=False)
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        self.residual = self.make_layer(_Residual_Block, 16)
        self.conv_mid = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=False)
        # Despite the "bn" prefix this is an InstanceNorm; the name is kept
        # because it determines the checkpoint (state_dict) keys.
        self.bn_mid = nn.InstanceNorm2d(width, affine=True)

        # Each stage expands to 4*width channels, then PixelShuffle(2) trades
        # them back down to width channels at twice the spatial resolution.
        stages = []
        for _ in range(2):
            stages += [
                nn.Conv2d(width, width * 4, kernel_size=3, stride=1, padding=1, bias=False),
                nn.PixelShuffle(2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.upscale4x = nn.Sequential(*stages)
        self.conv_output = nn.Conv2d(width, 3, kernel_size=9, stride=1, padding=4, bias=False)

        self._init_conv_weights()

    def _init_conv_weights(self):
        """Re-draw every Conv2d weight from N(0, sqrt(2 / fan_out)) and zero any bias."""
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            kh, kw = module.kernel_size
            fan_out = kh * kw * module.out_channels
            module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            if module.bias is not None:
                module.bias.data.zero_()

    def make_layer(self, block, num_of_layer):
        """Return an nn.Sequential of ``num_of_layer`` freshly constructed ``block()`` modules."""
        return nn.Sequential(*(block() for _ in range(num_of_layer)))

    def forward(self, x):
        """Map a 3-channel image batch of spatial size H x W to 3 channels at 4H x 4W."""
        features = self.relu(self.conv_input(x))
        trunk = self.bn_mid(self.conv_mid(self.residual(features)))
        # Long skip connection around the whole residual trunk.
        upsampled = self.upscale4x(torch.add(trunk, features))
        return self.conv_output(upsampled)
class _NetD(nn.Module):
    """Discriminator: strided conv feature extractor + two-layer MLP with a sigmoid output.

    The feature extractor alternates 3x3 stride-1 convs (which change the
    channel count) with 4x4 stride-2 padding-1 convs (each mapping a side
    length ``s`` to ``s // 2``). Channels go 3 -> 64 -> 128 -> 256 -> 512 and
    the spatial side shrinks from S to S // 16, where S = ``input_size``.

    Args:
        input_size: side length of the square input images. Defaults to 96,
            the size the original network was hard-coded for; fc1 then takes
            ``512 * (input_size // 16) ** 2`` features (512 * 6 * 6 for 96).

    Raises:
        ValueError: if ``input_size`` < 16, which would leave an empty feature map.
    """

    def __init__(self, input_size=96):
        super(_NetD, self).__init__()
        if input_size < 16:
            raise ValueError(f"input_size must be >= 16, got {input_size}")
        spatial = input_size // 16  # side length after four stride-2 convs

        # Shape comments: (channels) x height x width, with S = input_size
        # (values in brackets are for the default S = 96).
        self.features = nn.Sequential(
            # input is (3) x S x S  [96 x 96]
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (64) x S x S  [96 x 96]
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (64) x S//2 x S//2  [48 x 48]
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (128) x S//2 x S//2  [48 x 48]
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (128) x S//4 x S//4  [24 x 24]
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (256) x S//4 x S//4  [24 x 24]
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (256) x S//8 x S//8  [12 x 12]
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (512) x S//8 x S//8  [12 x 12]
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (512) x S//16 x S//16  [6 x 6]
        )

        self.LeakyReLU = nn.LeakyReLU(0.2, inplace=True)
        self.fc1 = nn.Linear(512 * spatial * spatial, 1024)
        self.fc2 = nn.Linear(1024, 1)
        self.sigmoid = nn.Sigmoid()

        # Re-initialize conv weights ~ N(0, 0.02) and BatchNorm scale ~ N(1, 0.02),
        # shift = 0. The Linear layers keep PyTorch's default initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)

    def forward(self, input):
        """Return a 1-D tensor of per-image sigmoid scores.

        Args:
            input: image batch of shape (N, 3, input_size, input_size). The
                parameter name shadows the builtin but is kept for keyword callers.

        Returns:
            Tensor of shape (N,).
        """
        out = self.features(input)
        # (N, 512, S//16, S//16) -> (N, 512 * (S//16)**2)
        out = out.view(out.size(0), -1)
        out = self.LeakyReLU(self.fc1(out))  # (N, 1024)
        out = self.sigmoid(self.fc2(out))  # (N, 1)
        return out.view(-1, 1).squeeze(1)