mirror of
https://github.com/kritiksoman/GIMP-ML
synced 2024-10-31 09:20:18 +00:00
91 lines
2.6 KiB
Python
91 lines
2.6 KiB
Python
from __future__ import print_function
|
|
import argparse
|
|
import numpy as np
|
|
import torch
|
|
import cv2
|
|
import yaml
|
|
import os
|
|
from torchvision import models, transforms
|
|
from torch.autograd import Variable
|
|
import shutil
|
|
import glob
|
|
import tqdm
|
|
from util.metrics import PSNR
|
|
from albumentations import Compose, CenterCrop, PadIfNeeded
|
|
from PIL import Image
|
|
from ssim.ssimlib import SSIM
|
|
from models.networks import get_generator
|
|
|
|
|
|
def get_args():
|
|
parser = argparse.ArgumentParser('Test an image')
|
|
parser.add_argument('--img_folder', required=True, help='GoPRO Folder')
|
|
parser.add_argument('--weights_path', required=True, help='Weights path')
|
|
|
|
return parser.parse_args()
|
|
|
|
|
|
def prepare_dirs(path):
|
|
if os.path.exists(path):
|
|
shutil.rmtree(path)
|
|
os.makedirs(path)
|
|
|
|
|
|
def get_gt_image(path):
|
|
dir, filename = os.path.split(path)
|
|
base, seq = os.path.split(dir)
|
|
base, _ = os.path.split(base)
|
|
img = cv2.cvtColor(cv2.imread(os.path.join(base, 'sharp', seq, filename)), cv2.COLOR_BGR2RGB)
|
|
return img
|
|
|
|
|
|
def test_image(model, image_path):
|
|
img_transforms = transforms.Compose([
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
|
])
|
|
size_transform = Compose([
|
|
PadIfNeeded(736, 1280)
|
|
])
|
|
crop = CenterCrop(720, 1280)
|
|
img = cv2.imread(image_path)
|
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
img_s = size_transform(image=img)['image']
|
|
img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
|
|
img_tensor = img_transforms(img_tensor)
|
|
with torch.no_grad():
|
|
img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
|
|
result_image = model(img_tensor)
|
|
result_image = result_image[0].cpu().float().numpy()
|
|
result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
|
|
result_image = crop(image=result_image)['image']
|
|
result_image = result_image.astype('uint8')
|
|
gt_image = get_gt_image(image_path)
|
|
_, filename = os.path.split(image_path)
|
|
psnr = PSNR(result_image, gt_image)
|
|
pilFake = Image.fromarray(result_image)
|
|
pilReal = Image.fromarray(gt_image)
|
|
ssim = SSIM(pilFake).cw_ssim_value(pilReal)
|
|
return psnr, ssim
|
|
|
|
|
|
def test(model, files):
|
|
psnr = 0
|
|
ssim = 0
|
|
for file in tqdm.tqdm(files):
|
|
cur_psnr, cur_ssim = test_image(model, file)
|
|
psnr += cur_psnr
|
|
ssim += cur_ssim
|
|
print("PSNR = {}".format(psnr / len(files)))
|
|
print("SSIM = {}".format(ssim / len(files)))
|
|
|
|
|
|
if __name__ == '__main__':
|
|
args = get_args()
|
|
with open('config/config.yaml') as cfg:
|
|
config = yaml.load(cfg)
|
|
model = get_generator(config['model'])
|
|
model.load_state_dict(torch.load(args.weights_path)['model'])
|
|
model = model.cuda()
|
|
filenames = sorted(glob.glob(args.img_folder + '/test' + '/blur/**/*.png', recursive=True))
|
|
test(model, filenames)
|