mirror of https://github.com/kritiksoman/GIMP-ML
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
195 lines
7.5 KiB
Python
195 lines
7.5 KiB
Python
import torch
|
|
import cv2
|
|
import os
|
|
import random
|
|
import numpy as np
|
|
from torchvision import transforms
|
|
import logging
|
|
|
|
def gen_trimap(alpha):
    """Derive a trimap from an alpha matte using random dilation/erosion.

    The matte is dilated and eroded with the same elliptical kernel, whose
    size (2-4) and iteration count (5-14) are drawn at random on every call.
    Pixels whose eroded value is >= 255 are marked 255, pixels whose dilated
    value is <= 0 are marked 0, and everything else stays at 128.

    Args:
        alpha: matte array handed to cv2.dilate / cv2.erode
            (presumably uint8 in [0, 255] -- confirm against callers).

    Returns:
        A float64 array with the same shape as ``alpha`` containing only the
        values 0, 128 and 255.
    """
    # Keep the draw order (stdlib RNG first, then NumPy RNG) so seeded runs
    # stay reproducible.
    kernel_size = random.choice(range(2, 5))
    num_iterations = np.random.randint(5, 15)
    structuring = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

    grown = cv2.dilate(alpha, structuring, iterations=num_iterations)
    shrunk = cv2.erode(alpha, structuring, iterations=num_iterations)

    # Start with every pixel at 128, then stamp the two certain regions.
    trimap = np.full(alpha.shape, 128, dtype=np.float64)
    trimap[shrunk >= 255] = 255
    trimap[grown <= 0] = 0
    return trimap
|
|
|
|
|
|
def compute_gradient(img):
    """Compute a single-channel Sobel gradient map of a colour image.

    Horizontal and vertical Sobel responses are taken in 16-bit signed
    precision, converted to absolute 8-bit values, blended 50/50, and the
    result is converted with ``cv2.COLOR_BGR2GRAY``.

    Args:
        img: image array; the final BGR->gray conversion means a
            multi-channel colour image is expected.

    Returns:
        The grayscale gradient image produced by ``cv2.cvtColor``.
    """
    # (dx, dy) derivative orders: x-gradient first, then y-gradient.
    abs_responses = [
        cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, dx, dy))
        for dx, dy in ((1, 0), (0, 1))
    ]
    blended = cv2.addWeighted(abs_responses[0], 0.5, abs_responses[1], 0.5, 0)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2GRAY)
|
|
|
|
|
|
class MatTransform(object):
    """Joint crop (and optional horizontal flip) for one matting sample."""

    def __init__(self, flip=False):
        """Remember whether crops may be mirrored.

        Args:
            flip: when true, each call mirrors all four arrays horizontally
                with probability 0.5.
        """
        self.flip = flip

    def __call__(self, img, alpha, fg, bg, crop_h, crop_w):
        """Crop the same window out of img, alpha, fg and bg.

        The window is centred on a randomly chosen pixel whose alpha lies
        strictly between 0 and 255, shifted as needed so it stays inside the
        image. If no such pixel exists the window is anchored at the
        top-left corner.

        Args:
            img: composited image array (rows, cols, ...).
            alpha: 2-D alpha matte; its shape defines the image bounds.
            fg: foreground array, same spatial size as ``alpha``.
            bg: background array, same spatial size as ``alpha``.
            crop_h: requested crop height in pixels.
            crop_w: requested crop width in pixels.

        Returns:
            Tuple ``(img, alpha, fg, bg)`` of the cropped (and possibly
            flipped) arrays. When the image is smaller than the requested
            crop along an axis, the full extent of that axis is returned.
        """
        h, w = alpha.shape

        # Candidate centres: pixels with 0 < alpha < 255.
        unknown_rows, unknown_cols = np.where((alpha > 0) & (alpha < 255))
        delta_h = center_h = crop_h / 2
        delta_w = center_w = crop_w / 2

        if len(unknown_rows) > 0:
            rand_ind = np.random.randint(len(unknown_rows))
            # Use the chosen pixel as the window centre (not its top-left
            # corner), pulled inwards so the window fits in the image.
            center_h = min(max(unknown_rows[rand_ind], delta_h), h - delta_h)
            center_w = min(max(unknown_cols[rand_ind], delta_w), w - delta_w)

        # Clamp the window to the image. Without this, an image smaller than
        # the crop yields a negative start index, which NumPy interprets as
        # "from the end" and silently returns only a sliver of the image.
        start_h = max(int(center_h - delta_h), 0)
        start_w = max(int(center_w - delta_w), 0)
        end_h = min(int(center_h + delta_h), h)
        end_w = min(int(center_w + delta_w), w)

        img = img[start_h:end_h, start_w:end_w]
        fg = fg[start_h:end_h, start_w:end_w]
        bg = bg[start_h:end_h, start_w:end_w]
        alpha = alpha[start_h:end_h, start_w:end_w]

        # Random horizontal flip, applied identically to all four arrays.
        if self.flip and random.random() < 0.5:
            img = cv2.flip(img, 1)
            alpha = cv2.flip(alpha, 1)
            fg = cv2.flip(fg, 1)
            bg = cv2.flip(bg, 1)

        return img, alpha, fg, bg
|
|
|
|
|
|
def get_files(mydir):
    """Recursively list image files below a directory.

    Symbolic links to directories are followed. A file counts as an image
    when its name ends with ``.jpg``, ``.png``, ``.jpeg`` or ``.JPG``
    (the match is case-sensitive).

    Args:
        mydir: root directory to walk.

    Returns:
        List of paths (each the walked directory joined with the file name),
        in the order ``os.walk`` yields them.
    """
    image_suffixes = (".jpg", ".png", ".jpeg", ".JPG")
    found = []
    for folder, _subdirs, names in os.walk(mydir, followlinks=True):
        found.extend(
            os.path.join(folder, name)
            for name in names
            if name.endswith(image_suffixes)
        )
    return found
|
|
|
|
|
|
# Dataset that loads pre-composited images from disk (no online compositing)
|
|
class MatDatasetOffline(torch.utils.data.Dataset):
    """Matting dataset whose composited images are read from disk.

    Foreground files are discovered under ``args.fgDir``. For every
    foreground file, the alpha matte, background and composited image are
    expected at the same relative location under ``args.alphaDir``,
    ``args.bgDir`` and ``args.imgDir`` respectively.
    """

    def __init__(self, args, transform=None, normalize=None):
        """Index every sample reachable from ``args.fgDir``.

        Args:
            args: options object providing ``size_h``, ``size_w``,
                ``crop_h``, ``crop_w`` (equal-length sequences of candidate
                crop sizes), ``fgDir``, ``alphaDir``, ``bgDir``, ``imgDir``.
            transform: optional callable invoked as
                ``transform(img, alpha, fg, bg, crop_h, crop_w)`` and
                returning the four transformed arrays.
            normalize: optional callable applied to the RGB image; its
                result is returned from ``__getitem__`` as ``img_norm``.

        Raises:
            AssertionError: if the crop lists differ in length, a companion
                file is missing, or no foreground files are found.
        """
        self.samples = []
        self.transform = transform
        self.normalize = normalize
        self.args = args
        self.size_h = args.size_h
        self.size_w = args.size_w
        self.crop_h = args.crop_h
        self.crop_w = args.crop_w
        self.logger = logging.getLogger("DeepImageMatting")
        assert len(self.crop_h) == len(self.crop_w), \
            "crop_h and crop_w must have the same number of entries"

        fg_dir = self.args.fgDir
        fg_paths = get_files(fg_dir)
        self.cnt = len(fg_paths)

        for fg_path in fg_paths:
            alpha_path = self._sibling_path(fg_path, fg_dir, self.args.alphaDir)
            img_path = self._sibling_path(fg_path, fg_dir, self.args.imgDir)
            bg_path = self._sibling_path(fg_path, fg_dir, self.args.bgDir)
            for required in (alpha_path, fg_path, bg_path, img_path):
                assert os.path.exists(required), \
                    "Missing sample file: {}".format(required)
            self.samples.append((alpha_path, fg_path, bg_path, img_path))

        self.logger.info("MatDatasetOffline Samples: {}".format(self.cnt))
        assert self.cnt > 0, "No foreground images found under {}".format(fg_dir)

    @staticmethod
    def _sibling_path(fg_path, fg_dir, other_dir):
        """Return ``fg_path`` with its leading ``fg_dir`` swapped for ``other_dir``."""
        # Paths produced by os.walk(fg_dir) begin with fg_dir, so limiting the
        # replacement to the first occurrence rewrites only that prefix. An
        # unlimited replace would also rewrite the same text inside
        # sub-directory or file names (e.g. "fg/fg_001.png").
        return fg_path.replace(fg_dir, other_dir, 1)

    @staticmethod
    def _read_image(path):
        """Load an image with cv2.imread, raising IOError if it cannot be read."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("Failed to read image: {}".format(path))
        return image

    @staticmethod
    def _resize_all(arrays, width, height):
        """Resize every array to (width, height) with bilinear interpolation."""
        return [
            cv2.resize(array, (width, height), interpolation=cv2.INTER_LINEAR)
            for array in arrays
        ]

    def __getitem__(self, index):
        """Load, augment and tensorise one sample.

        Args:
            index: position in ``self.samples``.

        Returns:
            Tuple ``(img, alpha, fg, bg, trimap, grad, img_norm, img_info)``:
            ``img``/``fg``/``bg`` are float32 tensors in channel-first
            layout; ``alpha``/``trimap``/``grad`` are float32 tensors with a
            leading singleton channel; ``img_norm`` is the output of the
            ``normalize`` callable or None; ``img_info`` is
            ``[fg_path, alpha_path, bg_path, img_path, original_fg_shape]``.

        Raises:
            IOError: if any of the four image files cannot be read.
            AssertionError: if fg, bg and img do not share one shape.
        """
        alpha_path, fg_path, bg_path, img_path = self.samples[index]
        img_info = [fg_path, alpha_path, bg_path, img_path]

        # Read colour planes of fg/bg/img and the first channel of alpha.
        fg = self._read_image(fg_path)[:, :, :3]
        bg = self._read_image(bg_path)[:, :, :3]
        img = self._read_image(img_path)[:, :, :3]
        alpha = self._read_image(alpha_path)[:, :, 0]

        assert bg.shape == fg.shape and bg.shape == img.shape, \
            "fg/bg/img shapes differ for sample {}".format(fg_path)
        img_info.append(fg.shape)
        bh, bw = fg.shape[:2]

        # Pick one of the configured crop sizes at random.
        rand_ind = random.randint(0, len(self.crop_h) - 1)
        cur_crop_h = self.crop_h[rand_ind]
        cur_crop_w = self.crop_w[rand_ind]

        # Upscale (keeping aspect ratio) when the image is smaller than the
        # crop, so both sides end up at least as large as the crop.
        ratio = max(float(cur_crop_w) / bw, float(cur_crop_h) / bh)
        if ratio > 1:
            nbw = int(bw * ratio + 1.0)
            nbh = int(bh * ratio + 1.0)
            fg, bg, img, alpha = self._resize_all((fg, bg, img, alpha), nbw, nbh)

        # Random crop to (cur_crop_h, cur_crop_w) and optional flip.
        if self.transform:
            img, alpha, fg, bg = self.transform(
                img, alpha, fg, bg, cur_crop_h, cur_crop_w)

        # Bring the sample to the network input size (size_h, size_w).
        if self.size_h != img.shape[0] or self.size_w != img.shape[1]:
            img, fg, bg, alpha = self._resize_all(
                (img, fg, bg, alpha), self.size_w, self.size_h)

        trimap = gen_trimap(alpha)
        grad = compute_gradient(img)

        if self.normalize:
            # The normalize callable receives the image in RGB channel order.
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_norm = self.normalize(img_rgb)
        else:
            # NOTE(review): None is returned here; confirm the consuming
            # collate function accepts None entries.
            img_norm = None

        # Single-channel maps get a leading channel axis; colour images go
        # from HWC to CHW.
        alpha = torch.from_numpy(alpha.astype(np.float32)[np.newaxis, :, :])
        trimap = torch.from_numpy(trimap.astype(np.float32)[np.newaxis, :, :])
        grad = torch.from_numpy(grad.astype(np.float32)[np.newaxis, :, :])
        img = torch.from_numpy(img.astype(np.float32)).permute(2, 0, 1)
        fg = torch.from_numpy(fg.astype(np.float32)).permute(2, 0, 1)
        bg = torch.from_numpy(bg.astype(np.float32)).permute(2, 0, 1)

        return img, alpha, fg, bg, trimap, grad, img_norm, img_info

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)
|