You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
GIMP-ML/gimp-plugins/DeblurGANv2/test_metrics.py

91 lines
2.6 KiB
Python

from __future__ import print_function
import argparse
import numpy as np
import torch
import cv2
import yaml
import os
from torchvision import models, transforms
from torch.autograd import Variable
import shutil
import glob
import tqdm
from util.metrics import PSNR
from albumentations import Compose, CenterCrop, PadIfNeeded
from PIL import Image
from ssim.ssimlib import SSIM
from models.networks import get_generator
def get_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, which is what the script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``img_folder`` and
        ``weights_path`` (both required options).
    """
    # The first positional parameter of ArgumentParser is ``prog``; the text
    # must be passed as ``description`` so it shows up in --help instead of
    # replacing the program name in the usage line.
    parser = argparse.ArgumentParser(description='Test an image')
    parser.add_argument('--img_folder', required=True, help='GoPRO Folder')
    parser.add_argument('--weights_path', required=True, help='Weights path')
    return parser.parse_args(argv)
def prepare_dirs(path):
    """Leave ``path`` as a freshly created, empty directory.

    Anything that already exists at ``path`` is removed with
    ``shutil.rmtree`` first, then the directory (and any missing parents)
    is created again.

    Args:
        path: Directory to reset.
    """
    already_present = os.path.exists(path)
    if already_present:
        # Wipe the previous contents so stale files never survive a rerun.
        shutil.rmtree(path)
    os.makedirs(path)
def get_gt_image(path):
    """Load the sharp ground-truth frame that corresponds to a blurred frame.

    For an input path of the form ``<base>/<kind>/<seq>/<filename>`` the
    ground truth is read from ``<base>/sharp/<seq>/<filename>``.

    Args:
        path: Path of the blurred image.

    Returns:
        The ground-truth image converted from OpenCV's BGR order to RGB.

    Raises:
        OSError: If the ground-truth image cannot be read.
    """
    seq_dir, filename = os.path.split(path)
    kind_dir, seq = os.path.split(seq_dir)
    base, _ = os.path.split(kind_dir)
    gt_path = os.path.join(base, 'sharp', seq, filename)

    bgr = cv2.imread(gt_path)
    if bgr is None:
        # imread signals a missing/unreadable file by returning None; fail
        # here with the offending path instead of inside cvtColor.
        raise OSError('Could not read ground-truth image: {}'.format(gt_path))
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def test_image(model, image_path):
    """Deblur a single image and score the result against its ground truth.

    The image is read, padded to at least 736x1280, scaled to [-1, 1] and
    run through ``model`` on the GPU. The output is mapped back to
    [0, 255], center-cropped to 720x1280 and compared with the image
    returned by ``get_gt_image``.

    Args:
        model: Callable network taking a ``(1, 3, H, W)`` float tensor on the
            CUDA device and returning a batch whose first element is the
            restored image in channel-first layout.
        image_path: Path of the blurred input image.

    Returns:
        Tuple ``(psnr, ssim)`` for this image, where ``ssim`` is the value
        returned by ``SSIM(...).cw_ssim_value``.

    Raises:
        OSError: If the input image cannot be read.
    """
    # Normalize with mean/std 0.5 maps the [0, 1] input to [-1, 1].
    img_transforms = transforms.Compose([
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # NOTE(review): 736 is presumably chosen so the padded height suits the
    # network's downsampling factor -- confirm against the generator.
    size_transform = Compose([
        PadIfNeeded(736, 1280)
    ])
    crop = CenterCrop(720, 1280)

    img = cv2.imread(image_path)
    if img is None:
        # imread returns None instead of raising on a missing/corrupt file.
        raise OSError('Could not read image: {}'.format(image_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_s = size_transform(image=img)['image']

    # HWC uint8 -> CHW float32 in [0, 1], then normalized to [-1, 1].
    img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
    img_tensor = img_transforms(img_tensor)

    with torch.no_grad():
        # The deprecated Variable wrapper is a no-op on every torch version
        # that provides no_grad, so the tensor is passed directly.
        img_tensor = img_tensor.unsqueeze(0).cuda()
        result_image = model(img_tensor)

    result_image = result_image[0].cpu().float().numpy()
    # CHW in [-1, 1] -> HWC in [0, 255].
    result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
    result_image = crop(image=result_image)['image']
    # Clip before the cast: values outside [0, 255] would otherwise wrap
    # around in uint8 and distort both metrics.
    result_image = np.clip(result_image, 0, 255).astype('uint8')

    gt_image = get_gt_image(image_path)
    psnr = PSNR(result_image, gt_image)
    pilFake = Image.fromarray(result_image)
    pilReal = Image.fromarray(gt_image)
    ssim = SSIM(pilFake).cw_ssim_value(pilReal)
    return psnr, ssim
def test(model, files):
    """Evaluate ``model`` on every image and print the mean PSNR and SSIM.

    Args:
        model: Network forwarded unchanged to ``test_image``.
        files: Sized collection of blurred image paths.

    Raises:
        ValueError: If ``files`` is empty, since no mean can be computed.
    """
    count = len(files)
    if count == 0:
        # An empty list would otherwise surface as a bare ZeroDivisionError
        # after the loop, hiding the real problem (no images were found).
        raise ValueError('No images to evaluate: the file list is empty')

    psnr_total = 0
    ssim_total = 0
    for image_path in tqdm.tqdm(files):
        cur_psnr, cur_ssim = test_image(model, image_path)
        psnr_total += cur_psnr
        ssim_total += cur_ssim

    print("PSNR = {}".format(psnr_total / count))
    print("SSIM = {}".format(ssim_total / count))
# Script entry point: load the generator described by config/config.yaml,
# restore its weights and report the mean PSNR/SSIM over
# <img_folder>/test/blur/**/*.png.
if __name__ == '__main__':
    args = get_args()
    with open('config/config.yaml') as cfg:
        # yaml.load without an explicit Loader is deprecated in PyYAML 5.1+
        # and rejected by PyYAML 6; safe_load parses plain config data and
        # cannot construct arbitrary Python objects.
        config = yaml.safe_load(cfg)
    model = get_generator(config['model'])
    # The checkpoint is a dict whose 'model' entry holds the state dict.
    model.load_state_dict(torch.load(args.weights_path)['model'])
    model = model.cuda()
    # sorted() keeps the evaluation order deterministic across runs.
    pattern = args.img_folder + '/test/blur/**/*.png'
    filenames = sorted(glob.glob(pattern, recursive=True))
    test(model, filenames)