denoiseCPUfix

pull/30/head
Kritik Soman 4 years ago
parent e62ce973ac
commit e606b8212e

@ -68,7 +68,7 @@ def denoiser(Img, c, pss, model, model_est, opt, cFlag):
elif opt.cond == 1: #if we use the estimated map directly
NM_tensor = torch.clamp(model_est(INoisy), 0., 1.)
if opt.refine == 1: #if we need to refine the map before putting it to the denoiser
NM_tensor_bundle = level_refine(NM_tensor, opt.refine_opt, 2*c) #refine_opt can be max, freq and their average
NM_tensor_bundle = level_refine(NM_tensor, opt.refine_opt, 2*c, cFlag) #refine_opt can be max, freq and their average
NM_tensor = NM_tensor_bundle[0]
noise_estimation_table = np.reshape(NM_tensor_bundle[1], (2 * c,))
if opt.zeroout == 1:
@ -81,7 +81,7 @@ def denoiser(Img, c, pss, model, model_est, opt, cFlag):
Out = torch.clamp(INoisy-Res, 0., 1.) #Output image after denoising
#get the maximum denoising result
max_NM_tensor = level_refine(NM_tensor, 1, 2*c)[0]
max_NM_tensor = level_refine(NM_tensor, 1, 2*c, cFlag)[0]
max_Res = model(INoisy, max_NM_tensor)
max_Out = torch.clamp(INoisy - max_Res, 0., 1.)
max_out_numpy = visual_va2np(max_Out, opt.color, opt.ps, pss, 1, opt.rescale, w, h, c)
@ -102,14 +102,14 @@ def denoiser(Img, c, pss, model, model_est, opt, cFlag):
if opt.color == 0: #if gray image
re_test = np.expand_dims(re_test[:, :, :, 0], 3)
re_test_tensor = torch.from_numpy(np.transpose(re_test, (0,3,1,2))).type(torch.FloatTensor)
if torch.cuda.is_available():
if torch.cuda.is_available() and not cFlag:
re_test_tensor = Variable(re_test_tensor.cuda(),volatile=True)
else:
re_test_tensor = Variable(re_test_tensor, volatile=True)
re_NM_tensor = torch.clamp(model_est(re_test_tensor), 0., 1.)
if opt.refine == 1: #if we need to refine the map before putting it to the denoiser
re_NM_tensor_bundle = level_refine(re_NM_tensor, opt.refine_opt, 2*c) #refine_opt can be max, freq and their average
re_NM_tensor_bundle = level_refine(re_NM_tensor, opt.refine_opt, 2*c, cFlag) #refine_opt can be max, freq and their average
re_NM_tensor = re_NM_tensor_bundle[0]
re_Res = model(re_test_tensor, re_NM_tensor)
Out2 = torch.clamp(re_test_tensor - re_Res, 0., 1.)

@ -277,7 +277,7 @@ def zeroing_out_maps(lm, keep=0):
RF_tensor = Variable(RF_tensor.cuda(),volatile=True)
return RF_tensor
def level_refine(NM_tensor, ref_mode, chn=3):
def level_refine(NM_tensor, ref_mode, chn=3,cFlag=False):
'''
Description: To refine the estimated noise level maps
[Input] the noise map tensor, and a refinement mode
@ -308,7 +308,7 @@ def level_refine(NM_tensor, ref_mode, chn=3):
noise_map[n,:,:,:] = np.reshape(np.tile(nl_list[n], NM_tensor.size()[2] * NM_tensor.size()[3]),
(chn, NM_tensor.size()[2], NM_tensor.size()[3]))
RF_tensor = torch.from_numpy(noise_map).type(torch.FloatTensor)
if torch.cuda.is_available():
if torch.cuda.is_available() and not cFlag:
RF_tensor = Variable(RF_tensor.cuda(),volatile=True)
else:
RF_tensor = Variable(RF_tensor,volatile=True)

@ -20,11 +20,11 @@ def clrImg(Img,cFlag):
rescale=1, scale=1, spat_n=0, test_data='real_night', test_data_gnd='Set12',
test_noise_level=None, wbin=512, zeroout=0)
c = 1 if opt.color == 0 else 3
net = DnCNN_c(channels=c, num_of_layers=opt.num_of_layers, num_of_est=2 * c)
est_net = Estimation_direct(c, 2 * c)
device_ids = [0]
model = nn.DataParallel(net, device_ids=device_ids)
model_est = nn.DataParallel(est_net, device_ids=device_ids)# Estimator Model
model = DnCNN_c(channels=c, num_of_layers=opt.num_of_layers, num_of_est=2 * c)
model_est = Estimation_direct(c, 2 * c)
# device_ids = [0]
# model = nn.DataParallel(net, device_ids=device_ids)
# model_est = nn.DataParallel(est_net, device_ids=device_ids)# Estimator Model
if torch.cuda.is_available() and not cFlag:
ckpt_est = torch.load(baseLoc+'weights/deepdenoise/est_net.pth')
ckpt = torch.load(baseLoc+'weights/deepdenoise/net.pth')
@ -33,6 +33,10 @@ def clrImg(Img,cFlag):
else:
ckpt = torch.load(baseLoc+'weights/deepdenoise/net.pth',map_location=torch.device("cpu"))
ckpt_est = torch.load(baseLoc+'weights/deepdenoise/est_net.pth',map_location=torch.device("cpu"))
ckpt = {key.replace("module.",""):value for key,value in ckpt.items()}
ckpt_est = {key.replace("module.",""):value for key,value in ckpt_est.items()}
model.load_state_dict(ckpt)
model.eval()
model_est.load_state_dict(ckpt_est)

Loading…
Cancel
Save