You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

36 lines
1.0 KiB
Python

import numpy as np
import torch
import torch.nn as nn
from skimage.measure import compare_ssim as SSIM

from util.metrics import PSNR
class DeblurModel(nn.Module):
    """Helper module for unpacking batches and computing image metrics.

    The class defines no layers or parameters of its own; it only groups
    the batch-unpacking, tensor-to-image and metric helpers below.
    """

    def __init__(self):
        super(DeblurModel, self).__init__()

    def get_input(self, data):
        """Return the ``(inputs, targets)`` pair stored in a batch dict.

        Args:
            data: Mapping with the input tensor under ``'a'`` and the
                target tensor under ``'b'``.

        Returns:
            Tuple ``(inputs, targets)``. Both tensors are moved to the
            current CUDA device when CUDA is available; otherwise they
            are returned on the device they already live on.
        """
        inputs = data['a']
        targets = data['b']
        # An unconditional .cuda() raises on CPU-only hosts; only move
        # the tensors when a CUDA device actually exists.
        if torch.cuda.is_available():
            inputs, targets = inputs.cuda(), targets.cuda()
        return inputs, targets

    def tensor2im(self, image_tensor, imtype=np.uint8):
        """Convert the first item of a batched tensor to an HWC image array.

        The values are mapped linearly so that -1 becomes 0 and 1 becomes
        255, and the first axis of the selected item is moved last.

        Args:
            image_tensor: Tensor whose item at index 0 is 3-dimensional.
            imtype: NumPy dtype of the returned array.

        Returns:
            ``numpy.ndarray`` of dtype ``imtype``. For integer dtypes the
            values are clipped to the representable range before the cast.
        """
        # detach() lets tensors that still require grad be converted.
        image_numpy = image_tensor[0].detach().cpu().float().numpy()
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
        target_dtype = np.dtype(imtype)
        if np.issubdtype(target_dtype, np.integer):
            # Without clipping, values outside [-1, 1] in the tensor would
            # wrap around in the integer cast and show up as wrong pixels.
            limits = np.iinfo(target_dtype)
            image_numpy = np.clip(image_numpy, limits.min, limits.max)
        return image_numpy.astype(imtype)

    def get_images_and_metrics(self, inp, output, target):
        """Compute PSNR/SSIM between output and target and build a preview.

        Args:
            inp: Batched input tensor.
            output: Batched model output tensor.
            target: Batched ground-truth tensor.

        Returns:
            Tuple ``(psnr, ssim, vis_img)`` where ``vis_img`` is the input,
            output and target images stacked horizontally in that order.
        """
        inp = self.tensor2im(inp)
        fake = self.tensor2im(output.data)
        real = self.tensor2im(target.data)
        psnr = PSNR(fake, real)
        # NOTE(review): SSIM is skimage's ``compare_ssim``; newer
        # scikit-image releases appear to have renamed it and replaced
        # ``multichannel`` -- confirm against the pinned version.
        ssim = SSIM(fake, real, multichannel=True)
        vis_img = np.hstack((inp, fake, real))
        return psnr, ssim, vis_img
def get_model(model_config):
    """Build and return a new ``DeblurModel``.

    ``model_config`` is accepted for interface compatibility but is not
    read by this function.
    """
    model = DeblurModel()
    return model