You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
GIMP-ML/gimp-plugins/pytorch-deep-image-matting/data.py

195 lines
7.5 KiB
Python

import torch
import cv2
import os
import random
import numpy as np
from torchvision import transforms
import logging
def gen_trimap(alpha):
    """Build a randomised trimap from an alpha matte.

    The matte is dilated and eroded with an elliptical kernel of random size
    (2..4) for a random number of iterations (5..14). Pixels whose eroded
    value is >= 255 become 255, pixels whose dilated value is <= 0 become 0,
    and every other pixel stays at 128.

    NOTE(review): the 0/255 thresholds suggest ``alpha`` is expected to be an
    8-bit matte in the 0-255 range -- confirm against the caller.

    Returns:
        Array with the same shape as ``alpha`` (NumPy default float dtype)
        holding only the values 0, 128 and 255.
    """
    # Draw from ``random`` first and ``np.random`` second so both RNG streams
    # are consumed in the same order as before.
    kernel_size = random.choice(range(2, 5))
    n_iter = np.random.randint(5, 15)

    structuring = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
    )
    grown = cv2.dilate(alpha, structuring, iterations=n_iter)
    shrunk = cv2.erode(alpha, structuring, iterations=n_iter)

    # Start from "everything unknown" and carve out the certain regions.
    trimap = np.full(alpha.shape, 128.0)
    trimap[shrunk >= 255] = 255
    trimap[grown <= 0] = 0
    return trimap
def compute_gradient(img):
    """Return a single-channel edge-strength map of ``img``.

    Horizontal and vertical Sobel responses are computed with 16-bit signed
    output, passed through ``cv2.convertScaleAbs``, blended with equal
    weights (0.5 / 0.5) and finally converted with ``COLOR_BGR2GRAY``.

    NOTE(review): the final colour conversion implies ``img`` is a 3-channel
    BGR image -- confirm against the caller.
    """
    sobel_x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    sobel_y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    blended = cv2.addWeighted(
        cv2.convertScaleAbs(sobel_x), 0.5,
        cv2.convertScaleAbs(sobel_y), 0.5,
        0,
    )
    return cv2.cvtColor(blended, cv2.COLOR_BGR2GRAY)
class MatTransform(object):
    """Random crop around an "unknown" alpha pixel, plus optional mirror flip.

    Args:
        flip: when true, the cropped arrays are mirrored horizontally with
            probability 0.5.
    """

    def __init__(self, flip=False):
        self.flip = flip

    def __call__(self, img, alpha, fg, bg, crop_h, crop_w):
        """Crop ``img``, ``alpha``, ``fg`` and ``bg`` with one shared window.

        The window is centred on a randomly chosen pixel with
        ``0 < alpha < 255`` and shifted as needed to stay inside the image.
        If no such pixel exists the window is anchored at the top-left
        corner. When the image is smaller than the requested crop along an
        axis, the whole axis is returned instead of an arbitrary sub-window.

        Args:
            img, fg, bg: arrays sliced on their first two axes. They are
                assumed to share ``alpha``'s height and width; this is not
                checked here.
            alpha: 2-D matte used to locate unknown pixels.
            crop_h, crop_w: requested crop height and width.

        Returns:
            Tuple ``(img, alpha, fg, bg)`` of the cropped (and possibly
            flipped) arrays.
        """
        h, w = alpha.shape

        # The trimap is produced by dilation later on, so a crop may still
        # contain pure background; only the centre is biased to the unknown
        # region.
        unknown_rows, unknown_cols = np.where((alpha > 0) & (alpha < 255))

        # Divide by 2.0 so odd crop sizes keep their half pixel regardless of
        # the interpreter's integer-division semantics.
        half_h = crop_h / 2.0
        half_w = crop_w / 2.0
        center_h, center_w = half_h, half_w

        if len(unknown_rows) > 0:
            pick = np.random.randint(len(unknown_rows))
            # The chosen unknown pixel is the crop centre (not its top-left
            # corner), pushed inwards so the window fits the image.
            center_h = min(max(unknown_rows[pick], half_h), h - half_h)
            center_w = min(max(unknown_cols[pick], half_w), w - half_w)

        # Clamp at 0: a negative start would be interpreted by NumPy as an
        # offset from the end and silently select the wrong region when the
        # image is smaller than the crop.
        start_h = max(int(center_h - half_h), 0)
        start_w = max(int(center_w - half_w), 0)
        end_h = int(center_h + half_h)
        end_w = int(center_w + half_w)

        img = img[start_h:end_h, start_w:end_w]
        fg = fg[start_h:end_h, start_w:end_w]
        bg = bg[start_h:end_h, start_w:end_w]
        alpha = alpha[start_h:end_h, start_w:end_w]

        # Random horizontal flip, applied to all four arrays together.
        if self.flip and random.random() < 0.5:
            img = cv2.flip(img, 1)
            alpha = cv2.flip(alpha, 1)
            fg = cv2.flip(fg, 1)
            bg = cv2.flip(bg, 1)

        return img, alpha, fg, bg
def get_files(mydir, extensions=(".jpg", ".jpeg", ".png")):
    """Recursively list the image files below ``mydir``.

    Symbolic links to directories are followed. File names are matched
    against ``extensions`` case-insensitively, so ``.JPG``, ``.PNG`` and
    ``.Jpeg`` are all accepted.

    Args:
        mydir: root directory to scan.
        extensions: iterable of file-name suffixes (including the dot) that
            count as images.

    Returns:
        Sorted list of paths, each the ``os.walk`` root joined with the file
        name. Sorting makes the order independent of the filesystem's
        directory enumeration order. The list is empty when ``mydir`` does
        not exist or holds no matching file.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    found = []
    for root, _dirs, files in os.walk(mydir, followlinks=True):
        for name in files:
            if name.lower().endswith(suffixes):
                found.append(os.path.join(root, name))
    found.sort()
    return found
# Offline dataset: the composited images are read from disk rather than being
# composited on the fly.
class MatDatasetOffline(torch.utils.data.Dataset):
    """Matting dataset whose composited images already exist on disk.

    Every foreground file found under ``args.fgDir`` must have a companion
    file at the same relative path under ``args.alphaDir``, ``args.bgDir``
    and ``args.imgDir``.

    Args:
        args: configuration object providing ``size_h``, ``size_w``,
            ``crop_h``, ``crop_w`` (two sequences of equal length), ``fgDir``,
            ``alphaDir``, ``bgDir`` and ``imgDir``.
        transform: optional callable invoked as
            ``transform(img, alpha, fg, bg, crop_h, crop_w)`` and returning
            ``(img, alpha, fg, bg)``.
        normalize: optional callable applied to the RGB version of the final
            image; its result is returned as ``img_norm``.

    Raises:
        ValueError: if ``crop_h`` and ``crop_w`` differ in length or no
            foreground image is found.
        IOError: if a companion file is missing.
    """

    def __init__(self, args, transform=None, normalize=None):
        self.samples = []
        self.transform = transform
        self.normalize = normalize
        self.args = args
        self.size_h = args.size_h
        self.size_w = args.size_w
        self.crop_h = args.crop_h
        self.crop_w = args.crop_w
        self.logger = logging.getLogger("DeepImageMatting")

        # Crop sizes are drawn as (crop_h[i], crop_w[i]) pairs.
        if len(self.crop_h) != len(self.crop_w):
            raise ValueError(
                "crop_h and crop_w must have the same length, got {} and {}".format(
                    len(self.crop_h), len(self.crop_w)
                )
            )

        fg_paths = get_files(self.args.fgDir)
        self.cnt = len(fg_paths)
        for fg_path in fg_paths:
            alpha_path = self._sibling_path(fg_path, self.args.fgDir, self.args.alphaDir)
            img_path = self._sibling_path(fg_path, self.args.fgDir, self.args.imgDir)
            bg_path = self._sibling_path(fg_path, self.args.fgDir, self.args.bgDir)
            for path in (alpha_path, fg_path, bg_path, img_path):
                if not os.path.exists(path):
                    raise IOError("Missing dataset file: {}".format(path))
            self.samples.append((alpha_path, fg_path, bg_path, img_path))

        self.logger.info("MatDatasetOffline Samples: {}".format(self.cnt))
        if self.cnt <= 0:
            raise ValueError(
                "No foreground images found under {}".format(self.args.fgDir)
            )

    @staticmethod
    def _sibling_path(path, src_dir, dst_dir):
        """Map ``path`` from below ``src_dir`` to the same place below ``dst_dir``."""
        # Only the directory prefix is swapped. A plain str.replace would also
        # rewrite occurrences of src_dir inside sub-directory or file names.
        return os.path.join(dst_dir, os.path.relpath(path, src_dir))

    @staticmethod
    def _read_image(path):
        """Read ``path`` with ``cv2.imread``, raising ``IOError`` when it fails."""
        image = cv2.imread(path)
        if image is None:
            raise IOError("cv2.imread could not read image: {}".format(path))
        return image

    def __getitem__(self, index):
        """Load, crop, resize and tensorise the sample at ``index``.

        Returns:
            Tuple ``(img, alpha, fg, bg, trimap, grad, img_norm, img_info)``:
            ``img``/``fg``/``bg`` are float32 tensors permuted from HWC to
            CHW; ``alpha``/``trimap``/``grad`` are float32 tensors with a
            leading singleton axis; ``img_norm`` is ``normalize``'s output
            (``None`` without a normalizer); ``img_info`` is
            ``[fg_path, alpha_path, bg_path, img_path, original_fg_shape]``.

        Raises:
            IOError: if an image cannot be read.
            ValueError: if fg, bg and img do not share one shape.
        """
        alpha_path, fg_path, bg_path, img_path = self.samples[index]
        img_info = [fg_path, alpha_path, bg_path, img_path]

        # Keep the first three channels of the colour images and the first
        # channel of the alpha file.
        fg = self._read_image(fg_path)[:, :, :3]
        bg = self._read_image(bg_path)[:, :, :3]
        img = self._read_image(img_path)[:, :, :3]
        alpha = self._read_image(alpha_path)[:, :, 0]

        if not (bg.shape == fg.shape and bg.shape == img.shape):
            raise ValueError(
                "Shape mismatch for {}: fg {}, bg {}, img {}".format(
                    fg_path, fg.shape, bg.shape, img.shape
                )
            )
        img_info.append(fg.shape)
        bh, bw = fg.shape[:2]

        rand_ind = random.randint(0, len(self.crop_h) - 1)
        cur_crop_h = self.crop_h[rand_ind]
        cur_crop_w = self.crop_w[rand_ind]

        # Upscale (never downscale) so that the image is at least as large as
        # the crop on both axes while keeping its aspect ratio.
        wratio = float(cur_crop_w) / bw
        hratio = float(cur_crop_h) / bh
        ratio = max(wratio, hratio)
        if ratio > 1:
            nbw = int(bw * ratio + 1.0)
            nbh = int(bh * ratio + 1.0)
            fg = cv2.resize(fg, (nbw, nbh), interpolation=cv2.INTER_LINEAR)
            bg = cv2.resize(bg, (nbw, nbh), interpolation=cv2.INTER_LINEAR)
            img = cv2.resize(img, (nbw, nbh), interpolation=cv2.INTER_LINEAR)
            alpha = cv2.resize(alpha, (nbw, nbh), interpolation=cv2.INTER_LINEAR)

        # Random crop (cur_crop_h, cur_crop_w) and flip.
        if self.transform:
            img, alpha, fg, bg = self.transform(img, alpha, fg, bg, cur_crop_h, cur_crop_w)

        # Bring every array to the network input size (size_h, size_w).
        if self.size_h != img.shape[0] or self.size_w != img.shape[1]:
            target_size = (self.size_w, self.size_h)
            img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
            fg = cv2.resize(fg, target_size, interpolation=cv2.INTER_LINEAR)
            bg = cv2.resize(bg, target_size, interpolation=cv2.INTER_LINEAR)
            alpha = cv2.resize(alpha, target_size, interpolation=cv2.INTER_LINEAR)

        trimap = gen_trimap(alpha)
        grad = compute_gradient(img)

        if self.normalize:
            # The normalizer receives the image in RGB channel order.
            # NOTE(review): presumably it scales 0-255 to 0-1, applies
            # mean/std and converts HWC to CHW -- confirm against the caller.
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_norm = self.normalize(img_rgb)
        else:
            img_norm = None

        alpha = torch.from_numpy(alpha.astype(np.float32)[np.newaxis, :, :])
        trimap = torch.from_numpy(trimap.astype(np.float32)[np.newaxis, :, :])
        grad = torch.from_numpy(grad.astype(np.float32)[np.newaxis, :, :])
        img = torch.from_numpy(img.astype(np.float32)).permute(2, 0, 1)
        fg = torch.from_numpy(fg.astype(np.float32)).permute(2, 0, 1)
        bg = torch.from_numpy(bg.astype(np.float32)).permute(2, 0, 1)
        return img, alpha, fg, bg, trimap, grad, img_norm, img_info

    def __len__(self):
        """Return the number of (alpha, fg, bg, img) samples."""
        return len(self.samples)