import argparse
import os
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

import net


def get_args():
    # Inference settings
    parser = argparse.ArgumentParser(description='PyTorch Deep Image Matting Inference')
    parser.add_argument('--size_h', type=int, default=320, help="height of the network input")
    parser.add_argument('--size_w', type=int, default=320, help="width of the network input")
    parser.add_argument('--imgDir', type=str, required=True, help="directory of images")
    parser.add_argument('--trimapDir', type=str, required=True, help="directory of trimaps")
    parser.add_argument('--cuda', action='store_true', help='use cuda?')
    parser.add_argument('--resume', type=str, required=True, help="checkpoint to resume the model from")
    parser.add_argument('--saveDir', type=str, required=True, help="where prediction results are saved")
    parser.add_argument('--alphaDir', type=str, default='', help="directory of ground-truth alpha mattes")
    parser.add_argument('--stage', type=int, required=True, choices=[0, 1, 2, 3], help="backbone stage")
    parser.add_argument('--not_strict', action='store_true', help="load the checkpoint without strict key matching")
    parser.add_argument('--crop_or_resize', type=str, default="whole", choices=["resize", "crop", "whole"],
                        help="how to manipulate the image before testing")
    parser.add_argument('--max_size', type=int, default=1600, help="max size of a test image")
    args = parser.parse_args()
    print(args)
    return args


def gen_dataset(imgdir, trimapdir):
    # Pair each image with the trimap of the same file name.
    sample_set = []
    img_ids = os.listdir(imgdir)
    img_ids.sort()
    for img_id in img_ids:
        img_name = os.path.join(imgdir, img_id)
        trimap_name = os.path.join(trimapdir, img_id)
        assert os.path.exists(img_name)
        assert os.path.exists(trimap_name)
        sample_set.append((img_name, trimap_name))
    return sample_set


def compute_gradient(img):
    # Gradient magnitude approximated by averaging absolute Sobel responses
    # in x and y, then collapsing to a single channel.
    x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    absX = cv2.convertScaleAbs(x)
    absY = cv2.convertScaleAbs(y)
    grad = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
    grad = cv2.cvtColor(grad, cv2.COLOR_BGR2GRAY)
    return grad


# inference once for an image, returns the predicted alpha matte as numpy
def inference_once(args, model, scale_img, scale_trimap, aligned=True):
    if aligned:
        assert scale_img.shape[0] == args.size_h
        assert scale_img.shape[1] == args.size_w

    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    scale_img_rgb = cv2.cvtColor(scale_img, cv2.COLOR_BGR2RGB)
    # ToTensor: 0-255 -> 0-1 and HWC -> CHW; Normalize: (x - mean) / std
    tensor_img = normalize(scale_img_rgb).unsqueeze(0)

    # image gradient (kept from the original pipeline; not consumed by the
    # forward pass below)
    scale_grad = compute_gradient(scale_img)
    tensor_trimap = torch.from_numpy(scale_trimap.astype(np.float32)[np.newaxis, np.newaxis, :, :])
    tensor_grad = torch.from_numpy(scale_grad.astype(np.float32)[np.newaxis, np.newaxis, :, :])

    if args.cuda:
        tensor_img = tensor_img.cuda()
        tensor_trimap = tensor_trimap.cuda()
        tensor_grad = tensor_grad.cuda()

    # network input: 4 channels = normalized RGB + trimap scaled to [0, 1]
    input_t = torch.cat((tensor_img, tensor_trimap / 255.), 1)

    # forward
    if args.stage <= 1:
        # stage 0, 1: raw alpha prediction
        pred_mattes, _ = model(input_t)
    else:
        # stage 2, 3: refined alpha prediction
        _, pred_mattes = model(input_t)

    pred_mattes = pred_mattes.data
    if args.cuda:
        pred_mattes = pred_mattes.cpu()
    pred_mattes = pred_mattes.numpy()[0, 0, :, :]
    return pred_mattes


# forward a full image by the crop method: crop patches and forward one by one
def inference_img_by_crop(args, model, img, trimap):
    h, w, c = img.shape
    origin_pred_mattes = np.zeros((h, w), dtype=np.float32)
    marks = np.zeros((h, w), dtype=np.float32)

    for start_h in range(0, h, args.size_h):
        end_h = start_h + args.size_h
        for start_w in range(0, w, args.size_w):
            end_w = start_w + args.size_w
            crop_img = img[start_h: end_h, start_w: end_w, :]
            crop_trimap = trimap[start_h: end_h, start_w: end_w]
            crop_origin_h = crop_img.shape[0]
            crop_origin_w = crop_img.shape[1]
            # skip patches without an unknown region
            if not (crop_trimap == 128).any():
                continue
            # edge patch on the right or bottom border: resize up to the input size
            if crop_origin_h != args.size_h or crop_origin_w != args.size_w:
                crop_img = cv2.resize(crop_img, (args.size_w, args.size_h), interpolation=cv2.INTER_LINEAR)
                crop_trimap = cv2.resize(crop_trimap, (args.size_w, args.size_h), interpolation=cv2.INTER_LINEAR)
            # inference for each cropped patch
            pred_mattes = inference_once(args, model, crop_img, crop_trimap)
            if crop_origin_h != args.size_h or crop_origin_w != args.size_w:
                pred_mattes = cv2.resize(pred_mattes, (crop_origin_w, crop_origin_h), interpolation=cv2.INTER_LINEAR)
            origin_pred_mattes[start_h: end_h, start_w: end_w] += pred_mattes
            marks[start_h: end_h, start_w: end_w] += 1

    # average the overlapping parts
    marks[marks <= 0] = 1.
    origin_pred_mattes /= marks
    return origin_pred_mattes


# forward a full image by the resize method
def inference_img_by_resize(args, model, img, trimap):
    h, w, c = img.shape
    # resize to the network input size
    scale_img = cv2.resize(img, (args.size_w, args.size_h), interpolation=cv2.INTER_LINEAR)
    scale_trimap = cv2.resize(trimap, (args.size_w, args.size_h), interpolation=cv2.INTER_LINEAR)

    pred_mattes = inference_once(args, model, scale_img, scale_trimap)

    # resize back to the original size
    origin_pred_mattes = cv2.resize(pred_mattes, (w, h), interpolation=cv2.INTER_LINEAR)
    assert origin_pred_mattes.shape == trimap.shape
    return origin_pred_mattes


# forward a whole image at once
def inference_img_whole(args, model, img, trimap):
    h, w, c = img.shape
    # the backbone downsamples by 32, so snap the input size to a multiple of 32,
    # capped at max_size
    new_h = min(args.max_size, h - (h % 32))
    new_w = min(args.max_size, w - (w % 32))

    scale_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    scale_trimap = cv2.resize(trimap, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    pred_mattes = inference_once(args, model, scale_img, scale_trimap, aligned=False)

    # resize back to the original size
    origin_pred_mattes = cv2.resize(pred_mattes, (w, h), interpolation=cv2.INTER_LINEAR)
    assert origin_pred_mattes.shape == trimap.shape
    return origin_pred_mattes
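
# A minimal single-image sketch (not part of the original pipeline; the file
# names below are placeholders) showing how the whole-image path above is driven:
#
#   args = get_args()                                # e.g. --stage 1 --crop_or_resize whole
#   model = net.VGG16(args)
#   model.load_state_dict(torch.load(args.resume)['state_dict'], strict=not args.not_strict)
#   model.eval()
#   img = cv2.imread("demo.png")                     # BGR, HxWx3
#   trimap = cv2.imread("demo_trimap.png")[:, :, 0]  # single channel, values {0, 128, 255}
#   with torch.no_grad():
#       alpha = inference_img_whole(args, model, img, trimap)  # float32 in [0, 1], HxW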


def main():
    print("===> Loading args")
    args = get_args()

    print("===> Environment init")
    if args.cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    model = net.VGG16(args)
    # map the checkpoint to CPU when CUDA is not requested
    ckpt = torch.load(args.resume, map_location=None if args.cuda else 'cpu')
    model.load_state_dict(ckpt['state_dict'], strict=not args.not_strict)
    if args.cuda:
        model = model.cuda()
    model.eval()

    print("===> Load dataset")
    dataset = gen_dataset(args.imgDir, args.trimapDir)

    mse_diffs = 0.
    sad_diffs = 0.
    cnt = len(dataset)
    cur = 0
    t0 = time.time()
    for img_path, trimap_path in dataset:
        img = cv2.imread(img_path)
        trimap = cv2.imread(trimap_path)[:, :, 0]
        assert img.shape[:2] == trimap.shape[:2]

        img_info = (img_path.split('/')[-1], img.shape[0], img.shape[1])
        cur += 1
        print('[{}/{}] {}'.format(cur, cnt, img_info[0]))

        with torch.no_grad():
            if args.cuda:
                torch.cuda.empty_cache()
            if args.crop_or_resize == "whole":
                origin_pred_mattes = inference_img_whole(args, model, img, trimap)
            elif args.crop_or_resize == "crop":
                origin_pred_mattes = inference_img_by_crop(args, model, img, trimap)
            else:
                origin_pred_mattes = inference_img_by_resize(args, model, img, trimap)

        # only the unknown region is predicted; known regions come from the trimap
        origin_pred_mattes[trimap == 255] = 1.
        origin_pred_mattes[trimap == 0] = 0.

        # number of unknown pixels in the original trimap
        pixel = float((trimap == 128).sum())

        # evaluate when a ground-truth alpha is given
        if args.alphaDir != '':
            alpha_name = os.path.join(args.alphaDir, img_info[0])
            assert os.path.exists(alpha_name)
            alpha = cv2.imread(alpha_name)[:, :, 0] / 255.
            assert alpha.shape == origin_pred_mattes.shape

            # MSE is averaged over unknown pixels only; SAD sums over the image,
            # but known regions were overwritten from the trimap above
            mse_diff = ((origin_pred_mattes - alpha) ** 2).sum() / pixel
            sad_diff = np.abs(origin_pred_mattes - alpha).sum()
            mse_diffs += mse_diff
            sad_diffs += sad_diff
            print("sad:{} mse:{}".format(sad_diff, mse_diff))

        origin_pred_mattes = (origin_pred_mattes * 255).astype(np.uint8)
        res = origin_pred_mattes.copy()

        # write known regions from the trimap into the result
        res[trimap == 255] = 255
        res[trimap == 0] = 0

        if not os.path.exists(args.saveDir):
            os.makedirs(args.saveDir)
        cv2.imwrite(os.path.join(args.saveDir, img_info[0]), res)

    print("Avg-Cost: {} s/image".format((time.time() - t0) / cnt))
    if args.alphaDir != '':
        print("Eval-MSE: {}".format(mse_diffs / cur))
        print("Eval-SAD: {}".format(sad_diffs / cur))


if __name__ == "__main__":
    main()
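
# Example invocation (a sketch; the script name and the paths are assumptions,
# adjust them to your checkout):
#
#   python deploy.py \
#       --imgDir    data/test/image \
#       --trimapDir data/test/trimap \
#       --resume    model/stage1.pth \
#       --saveDir   result \
#       --stage     1 \
#       --crop_or_resize whole \
#       --cuda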