mirror of
https://github.com/kritiksoman/GIMP-ML
synced 2024-10-31 09:20:18 +00:00
74 lines
3.1 KiB
Python
74 lines
3.1 KiB
Python
'''
|
|
Minor modification of https://github.com/SaoYan/DnCNN-PyTorch (original author: SaoYan)
|
|
Re-implemented by Yuqian Zhou
|
|
'''
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
class DnCNN(nn.Module):
    """Plain DnCNN network without any conditioning input.

    The network is a stack of ``num_of_layers`` 3x3 convolutions, each with
    64 feature maps, padding 1 and no bias:

    * first conv: ``channels`` -> 64, followed by ReLU;
    * ``num_of_layers - 2`` middle blocks of Conv2d -> BatchNorm2d -> ReLU;
    * last conv: 64 -> ``channels``, with no activation afterwards.

    With a 3x3 kernel, padding 1 and PyTorch's default stride of 1, the
    spatial size is preserved, so the output has the same shape as the input.

    Args:
        channels: Number of input channels, which is also the number of
            output channels.
        num_of_layers: Total number of conv layers, counting the first and
            the last one. Values below 3 produce no middle blocks.
    """

    def __init__(self, channels, num_of_layers=17):
        super().__init__()
        width, ksize, pad = 64, 3, 1

        def conv(c_in, c_out):
            # Every conv in the stack is bias-free.
            return nn.Conv2d(in_channels=c_in, out_channels=c_out,
                             kernel_size=ksize, padding=pad, bias=False)

        # Layer order fixes the Sequential indices and thus the state_dict
        # keys (``dncnn.<index>.*``); checkpoints of this module depend on it.
        stack = [conv(channels, width), nn.ReLU(inplace=True)]
        for _ in range(num_of_layers - 2):
            stack += [conv(width, width), nn.BatchNorm2d(width), nn.ReLU(inplace=True)]
        stack.append(conv(width, channels))
        self.dncnn = nn.Sequential(*stack)

    def forward(self, input_x):
        """Run ``input_x`` through the conv stack and return the raw output."""
        return self.dncnn(input_x)
|
|
|
|
|
|
class Estimation_direct(nn.Module):
    """Noise estimator: a shallow conv stack with 3 layers by default.

    The layout matches the DnCNN stack (3x3 convs, 64 features, padding 1,
    no bias), but the input and output channel counts are independent:

    * conv ``input_channels`` -> 64, followed by ReLU;
    * ``num_of_layers - 2`` blocks of Conv2d -> BatchNorm2d -> ReLU;
    * conv 64 -> ``output_channels``.

    Args:
        input_channels: Channels of the input tensor (default 1).
        output_channels: Channels of the produced estimate (default 3).
        num_of_layers: Total number of conv layers (default 3).
    """

    def __init__(self, input_channels=1, output_channels=3, num_of_layers=3):
        super().__init__()
        n_feat = 64
        conv_kw = dict(kernel_size=3, padding=1, bias=False)

        modules = [
            nn.Conv2d(in_channels=input_channels, out_channels=n_feat, **conv_kw),
            nn.ReLU(inplace=True),
        ]
        for _ in range(num_of_layers - 2):
            modules.extend((
                nn.Conv2d(in_channels=n_feat, out_channels=n_feat, **conv_kw),
                nn.BatchNorm2d(n_feat),
                nn.ReLU(inplace=True),
            ))
        modules.append(
            nn.Conv2d(in_channels=n_feat, out_channels=output_channels, **conv_kw)
        )
        # Attribute name and ordering determine the state_dict keys.
        self.dncnn = nn.Sequential(*modules)

    def forward(self, input):
        """Return the conv stack's output for ``input``."""
        # NOTE: ``input`` shadows the builtin. It is kept because callers may
        # pass it by keyword.
        return self.dncnn(input)
|
|
|
|
|
|
class DnCNN_c(nn.Module):
    """DnCNN network conditioned on an extra input map.

    ``forward`` concatenates the image ``x`` with the condition ``c`` along
    dim 1 before the conv stack. The first conv therefore expects
    ``channels + num_of_est`` input channels, and the output has ``channels``
    channels. The stack is otherwise the plain DnCNN layout: 3x3 convs with
    64 features, padding 1 and no bias, and ``num_of_layers - 2`` middle
    Conv2d -> BatchNorm2d -> ReLU blocks.

    NOTE(review): ``c`` is presumably the per-pixel estimate produced by
    ``Estimation_direct`` (whose default output has 3 channels, matching
    ``num_of_est=3``). Confirm against the caller.

    Args:
        channels: Channels of the image ``x`` and of the output.
        num_of_layers: Total number of conv layers (default 17).
        num_of_est: Channels of the condition tensor ``c`` (default 3).
    """

    def __init__(self, channels, num_of_layers=17, num_of_est=3):
        super().__init__()
        feats = 64
        in_ch = channels + num_of_est  # image channels + condition channels

        def make_conv(cin, cout):
            return nn.Conv2d(in_channels=cin, out_channels=cout,
                             kernel_size=3, padding=1, bias=False)

        head = [make_conv(in_ch, feats), nn.ReLU(inplace=True)]
        middle = [
            module
            for _ in range(num_of_layers - 2)
            for module in (make_conv(feats, feats),
                           nn.BatchNorm2d(feats),
                           nn.ReLU(inplace=True))
        ]
        tail = [make_conv(feats, channels)]
        self.dncnn = nn.Sequential(*head, *middle, *tail)

    def forward(self, x, c):
        """Concatenate ``x`` and ``c`` on dim 1 and run the conv stack."""
        stacked = torch.cat([x, c], dim=1)
        return self.dncnn(stacked)
|