You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

74 lines
3.1 KiB
Python

'''
Minor Modification from https://github.com/SaoYan/DnCNN-PyTorch SaoYan
Re-implemented by Yuqian Zhou
'''
import torch
import torch.nn as nn
class DnCNN(nn.Module):
'''
Original DnCNN model without input conditions
'''
def __init__(self, channels, num_of_layers=17):
super(DnCNN, self).__init__()
kernel_size = 3
padding = 1
features = 64
layers = []
layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
layers.append(nn.ReLU(inplace=True))
for _ in range(num_of_layers-2):
layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
layers.append(nn.BatchNorm2d(features))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
self.dncnn = nn.Sequential(*layers)
def forward(self, input_x):
out = self.dncnn(input_x)
return out
class Estimation_direct(nn.Module):
'''
Noise estimator, with original 3 layers
'''
def __init__(self, input_channels = 1, output_channels = 3, num_of_layers=3):
super(Estimation_direct, self).__init__()
kernel_size = 3
padding = 1
features = 64
layers = []
layers.append(nn.Conv2d(in_channels=input_channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
layers.append(nn.ReLU(inplace=True))
for _ in range(num_of_layers-2):
layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
layers.append(nn.BatchNorm2d(features))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Conv2d(in_channels=features, out_channels=output_channels, kernel_size=kernel_size, padding=padding, bias=False))
self.dncnn = nn.Sequential(*layers)
def forward(self, input):
x = self.dncnn(input)
return x
class DnCNN_c(nn.Module):
def __init__(self, channels, num_of_layers=17, num_of_est=3):
super(DnCNN_c, self).__init__()
kernel_size = 3
padding = 1
features = 64
layers = []
layers.append(nn.Conv2d(in_channels=channels+ num_of_est, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
layers.append(nn.ReLU(inplace=True))
for _ in range(num_of_layers-2):
layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
layers.append(nn.BatchNorm2d(features))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
self.dncnn = nn.Sequential(*layers)
def forward(self, x, c):
input_x = torch.cat([x, c], dim=1)
out = self.dncnn(input_x)
return out