augustUpdate

pull/30/head
Kritik Soman 4 years ago
parent 9ca028cf19
commit 06956277a7

Binary file not shown.

@@ -0,0 +1,109 @@
import os
baseLoc = os.path.dirname(os.path.realpath(__file__))+'/'
from gimpfu import *
import sys
sys.path.extend([baseLoc+'gimpenv/lib/python2.7',baseLoc+'gimpenv/lib/python2.7/site-packages',baseLoc+'gimpenv/lib/python2.7/site-packages/setuptools',baseLoc+'pytorch-deep-image-matting'])
import torch
from argparse import Namespace
import net
import cv2
import os
import numpy as np
from deploy import inference_img_whole
def channelData(layer):# convert a GIMP layer to a numpy array of shape (height, width, bpp)
region=layer.get_pixel_rgn(0, 0, layer.width,layer.height)
pixChars=region[:,:] # Take whole layer
bpp=region.bpp
# return np.frombuffer(pixChars,dtype=np.uint8).reshape(len(pixChars)/bpp,bpp)
return np.frombuffer(pixChars,dtype=np.uint8).reshape(layer.height,layer.width,bpp)
def createResultLayer(image,name,result):
rlBytes=np.uint8(result).tobytes()
rl=gimp.Layer(image,name,image.width,image.height,1,100,NORMAL_MODE)# layer type 1 = RGBA_IMAGE (alternatives: image.active_layer.type or RGB_IMAGE)
region=rl.get_pixel_rgn(0, 0, rl.width,rl.height,True)
region[:,:]=rlBytes
image.add_layer(rl,0)
gimp.displays_flush()
def getnewalpha(image,mask):# run deep image matting with the trimap mask; returns RGBA with the predicted alpha
if image.shape[2] == 4: # get rid of alpha channel
image = image[:,:,0:3]
if mask.shape[2] == 4: # get rid of alpha channel
mask = mask[:,:,0:3]
image = cv2.cvtColor(image,cv2.COLOR_RGB2BGR)
trimap = mask[:, :, 0]
cudaFlag = False
if torch.cuda.is_available():
cudaFlag = True
args = Namespace(crop_or_resize='whole', cuda=cudaFlag, max_size=1600, resume=baseLoc+'pytorch-deep-image-matting/model/stage1_sad_57.1.pth', stage=1)
model = net.VGG16(args)
if cudaFlag:
ckpt = torch.load(args.resume)
else:
ckpt = torch.load(args.resume,map_location=torch.device("cpu"))
model.load_state_dict(ckpt['state_dict'], strict=True)
if cudaFlag:
model = model.cuda()
# ckpt = torch.load(args.resume)
# model.load_state_dict(ckpt['state_dict'], strict=True)
# model = model.cuda()
torch.cuda.empty_cache()
with torch.no_grad():
pred_mattes = inference_img_whole(args, model, image, trimap)
pred_mattes = (pred_mattes * 255).astype(np.uint8)
pred_mattes[trimap == 255] = 255
pred_mattes[trimap == 0] = 0
# pred_mattes = np.repeat(pred_mattes[:, :, np.newaxis], 3, axis=2)
image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
pred_mattes = np.dstack((image,pred_mattes))
return pred_mattes
def deepmatting(imggimp, curlayer, layeri, layerm):
if torch.cuda.is_available():
gimp.progress_init("(Using GPU) Running deep-matting for " + layeri.name + "...")
else:
gimp.progress_init("(Using CPU) Running deep-matting for " + layeri.name + "...")
img = channelData(layeri)
mask = channelData(layerm)
cpy=getnewalpha(img,mask)
createResultLayer(imggimp,'new_output',cpy)
register(
"deep-matting",
"deep-matting",
"Running image matting.",
"Kritik Soman",
"Your",
"2020",
"deepmatting...",
"*", # Alternately use RGB, RGB*, GRAY*, INDEXED etc.
[ (PF_IMAGE, "image", "Input image", None),
(PF_DRAWABLE, "drawable", "Input drawable", None),
(PF_LAYER, "drawinglayer", "Original Image:", None),
(PF_LAYER, "drawinglayer", "Trimap Mask:", None)
],
[],
deepmatting, menu="<Image>/Layer/GIML-ML")
main()
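For reference, the Trimap Mask layer this plugin expects encodes background as 0, unknown as 128, and foreground as 255 (the deploy code below pins exactly those values). A minimal sketch, mirroring gen_trimap from this commit, of building such a trimap from a binary mask outside GIMP — the function name and radius are illustrative, not part of the plugin:

import cv2
import numpy as np

def make_trimap(binary_mask, radius=10):
    # binary_mask: uint8, foreground=255, background=0
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (radius, radius))
    dilated = cv2.dilate(binary_mask, kernel)
    eroded = cv2.erode(binary_mask, kernel)
    trimap = np.full(binary_mask.shape, 128, dtype=np.uint8)
    trimap[eroded >= 255] = 255  # confident foreground
    trimap[dilated <= 0] = 0     # confident background
    return trimap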

@@ -0,0 +1,76 @@
import os
baseLoc = os.path.dirname(os.path.realpath(__file__))+'/'
from gimpfu import *
import sys
sys.path.extend([baseLoc+'gimpenv/lib/python2.7',baseLoc+'gimpenv/lib/python2.7/site-packages',baseLoc+'gimpenv/lib/python2.7/site-packages/setuptools'])
import numpy as np
from scipy.cluster.vq import kmeans2
def channelData(layer):# convert a GIMP layer to a numpy array of shape (height, width, bpp)
region=layer.get_pixel_rgn(0, 0, layer.width,layer.height)
pixChars=region[:,:] # Take whole layer
bpp=region.bpp
# return np.frombuffer(pixChars,dtype=np.uint8).reshape(len(pixChars)/bpp,bpp)
return np.frombuffer(pixChars,dtype=np.uint8).reshape(layer.height,layer.width,bpp)
def createResultLayer(image,name,result):
rlBytes=np.uint8(result).tobytes()
rl=gimp.Layer(image,name,image.width,image.height,0,100,NORMAL_MODE)# layer type 0 = RGB_IMAGE (1 would be RGB with alpha)
region=rl.get_pixel_rgn(0, 0, rl.width,rl.height,True)
region[:,:]=rlBytes
image.add_layer(rl,0)
gimp.displays_flush()
def kmeans(imggimp, curlayer, layeri, n_clusters, locflag):
image = channelData(layeri)
if image.shape[2] == 4: # get rid of alpha channel
image = image[:,:,0:3]
h,w,d = image.shape
# reshape the image to a 2D array of pixels and 3 color values (RGB)
pixel_values = image.reshape((-1, 3))
if locflag:
xx,yy = np.meshgrid(range(w),range(h))
x = xx.reshape(-1,1)
y = yy.reshape(-1,1)
pixel_values = np.concatenate((pixel_values,x,y),axis=1)
pixel_values = np.float32(pixel_values)
c,out = kmeans2(pixel_values,n_clusters)
if locflag:
c = np.uint8(c[:,0:3])
else:
c = np.uint8(c)
segmented_image = c[out.flatten()]
segmented_image = segmented_image.reshape((h,w,d))
createResultLayer(imggimp,'new_output',segmented_image)
register(
"kmeans",
"kmeans clustering",
"Running kmeans clustering.",
"Kritik Soman",
"Your",
"2020",
"kmeans...",
"*", # Alternately use RGB, RGB*, GRAY*, INDEXED etc.
[ (PF_IMAGE, "image", "Input image", None),
(PF_DRAWABLE, "drawable", "Input drawable", None),
(PF_LAYER, "drawinglayer", "Original Image", None),
(PF_INT, "depth", "Number of clusters", 3),
(PF_BOOL, "position", "Use position", False)
],
[],
kmeans, menu="<Image>/Layer/GIML-ML")
main()
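Outside GIMP, the clustering step above reduces to the following standalone sketch (synthetic data; the extra (x, y) columns are what the "Use position" flag appends so clusters prefer spatially coherent regions):

import numpy as np
from scipy.cluster.vq import kmeans2

h, w = 4, 6
image = np.random.randint(0, 256, (h, w, 3)).astype(np.float32)
pixels = image.reshape(-1, 3)
xx, yy = np.meshgrid(range(w), range(h))
pixels = np.concatenate((pixels, xx.reshape(-1, 1), yy.reshape(-1, 1)), axis=1)
centroids, labels = kmeans2(np.float32(pixels), 3)
# map each pixel to its centroid color, dropping the positional columns
segmented = np.uint8(centroids[:, 0:3])[labels].reshape(h, w, 3)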

@@ -1,8 +1,10 @@
unzip weights.zip
mkdir -p CelebAMask-HQ/MaskGAN_demo/checkpoints/label2face_512p
mkdir -p pytorch-SRResNet/model
mkdir -p pytorch-deep-image-matting/model
mkdir -p deeplabv3
mv weights/deepmatting/* pytorch-deep-image-matting/model
mv weights/colorize/* ideepcolor/models/pytorch/
mv weights/deblur/* DeblurGANv2/
mv weights/deeplabv3/* deeplabv3

@@ -0,0 +1,194 @@
import torch
import cv2
import os
import random
import numpy as np
from torchvision import transforms
import logging
def gen_trimap(alpha):# generate a 0/128/255 trimap from an alpha matte via random dilation/erosion
k_size = random.choice(range(2, 5))
iterations = np.random.randint(5, 15)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k_size, k_size))
dilated = cv2.dilate(alpha, kernel, iterations=iterations)
eroded = cv2.erode(alpha, kernel, iterations=iterations)
trimap = np.zeros(alpha.shape)
trimap.fill(128)
#trimap[alpha >= 255] = 255
trimap[eroded >= 255] = 255
trimap[dilated <= 0] = 0
'''
alpha_unknown = alpha[trimap == 128]
num_all = alpha_unknown.size
num_0 = (alpha_unknown == 0).sum()
num_1 = (alpha_unknown == 255).sum()
print("Debug: 0 : {}/{} {:.3f}".format(num_0, num_all, float(num_0)/num_all))
print("Debug: 255: {}/{} {:.3f}".format(num_1, num_all, float(num_1)/num_all))
'''
return trimap
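# A hypothetical check of gen_trimap on a synthetic alpha matte:
#   alpha = np.zeros((64, 64), dtype=np.uint8); alpha[16:48, 16:48] = 255
#   trimap = gen_trimap(alpha)
#   np.unique(trimap) -> [0., 128., 255.]  (band width varies with the random kernel/iterations)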
def compute_gradient(img):
x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
absX = cv2.convertScaleAbs(x)
absY = cv2.convertScaleAbs(y)
grad = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
grad=cv2.cvtColor(grad, cv2.COLOR_BGR2GRAY)
return grad
class MatTransform(object):
def __init__(self, flip=False):
self.flip = flip
def __call__(self, img, alpha, fg, bg, crop_h, crop_w):
h, w = alpha.shape
# the trimap is dilated, so the crop may still include some background (0) region
# random crop centered on a pixel from the unknown region
target = np.where((alpha > 0) & (alpha < 255))
delta_h = center_h = crop_h / 2
delta_w = center_w = crop_w / 2
if len(target[0]) > 0:
rand_ind = np.random.randint(len(target[0]))
center_h = min(max(target[0][rand_ind], delta_h), h - delta_h)
center_w = min(max(target[1][rand_ind], delta_w), w - delta_w)
# use the sampled unknown pixel as the crop center rather than its top-left corner
start_h = int(center_h - delta_h)
start_w = int(center_w - delta_w)
end_h = int(center_h + delta_h)
end_w = int(center_w + delta_w)
#print("Debug: center({},{}) start({},{}) end({},{}) alpha:{} alpha-len:{} unknown-len:{}".format(center_h, center_w, start_h, start_w, end_h, end_w, alpha[int(center_h), int(center_w)], alpha.size, len(target[0])))
img = img [start_h : end_h, start_w : end_w]
fg = fg [start_h : end_h, start_w : end_w]
bg = bg [start_h : end_h, start_w : end_w]
alpha = alpha [start_h : end_h, start_w : end_w]
# random flip
if self.flip and random.random() < 0.5:
img = cv2.flip(img, 1)
alpha = cv2.flip(alpha, 1)
fg = cv2.flip(fg, 1)
bg = cv2.flip(bg, 1)
return img, alpha, fg, bg
def get_files(mydir):
res = []
for root, dirs, files in os.walk(mydir, followlinks=True):
for f in files:
if f.endswith(".jpg") or f.endswith(".png") or f.endswith(".jpeg") or f.endswith(".JPG"):
res.append(os.path.join(root, f))
return res
# dataset that reads pre-composited images from disk (no online compositing)
class MatDatasetOffline(torch.utils.data.Dataset):
def __init__(self, args, transform=None, normalize=None):
self.samples=[]
self.transform = transform
self.normalize = normalize
self.args = args
self.size_h = args.size_h
self.size_w = args.size_w
self.crop_h = args.crop_h
self.crop_w = args.crop_w
self.logger = logging.getLogger("DeepImageMatting")
assert(len(self.crop_h) == len(self.crop_w))
fg_paths = get_files(self.args.fgDir)
self.cnt = len(fg_paths)
for fg_path in fg_paths:
alpha_path = fg_path.replace(self.args.fgDir, self.args.alphaDir)
img_path = fg_path.replace(self.args.fgDir, self.args.imgDir)
bg_path = fg_path.replace(self.args.fgDir, self.args.bgDir)
assert(os.path.exists(alpha_path))
assert(os.path.exists(fg_path))
assert(os.path.exists(bg_path))
assert(os.path.exists(img_path))
self.samples.append((alpha_path, fg_path, bg_path, img_path))
self.logger.info("MatDatasetOffline Samples: {}".format(self.cnt))
assert(self.cnt > 0)
def __getitem__(self,index):
alpha_path, fg_path, bg_path, img_path = self.samples[index]
img_info = [fg_path, alpha_path, bg_path, img_path]
# read fg, bg, composited img and alpha
fg = cv2.imread(fg_path)[:, :, :3]
bg = cv2.imread(bg_path)[:, :, :3]
img = cv2.imread(img_path)[:, :, :3]
alpha = cv2.imread(alpha_path)[:, :, 0]
assert(bg.shape == fg.shape and bg.shape == img.shape)
img_info.append(fg.shape)
bh, bw, bc = fg.shape
rand_ind = random.randint(0, len(self.crop_h) - 1)
cur_crop_h = self.crop_h[rand_ind]
cur_crop_w = self.crop_w[rand_ind]
# if the image is smaller than the crop, upscale so that h >= crop_h and w >= crop_w
wratio = float(cur_crop_w) / bw
hratio = float(cur_crop_h) / bh
ratio = wratio if wratio > hratio else hratio
if ratio > 1:
nbw = int(bw * ratio + 1.0)
nbh = int(bh * ratio + 1.0)
fg = cv2.resize(fg, (nbw, nbh), interpolation=cv2.INTER_LINEAR)
bg = cv2.resize(bg, (nbw, nbh), interpolation=cv2.INTER_LINEAR)
img = cv2.resize(img, (nbw, nbh), interpolation=cv2.INTER_LINEAR)
alpha = cv2.resize(alpha, (nbw, nbh), interpolation=cv2.INTER_LINEAR)
# random crop(crop_h, crop_w) and flip
if self.transform:
img, alpha, fg, bg = self.transform(img, alpha, fg, bg, cur_crop_h, cur_crop_w)
# resize to (size_h, size_w)
if self.size_h != img.shape[0] or self.size_w != img.shape[1]:
# resize
img =cv2.resize(img, (self.size_w, self.size_h), interpolation=cv2.INTER_LINEAR)
fg =cv2.resize(fg, (self.size_w, self.size_h), interpolation=cv2.INTER_LINEAR)
bg =cv2.resize(bg, (self.size_w, self.size_h), interpolation=cv2.INTER_LINEAR)
alpha =cv2.resize(alpha, (self.size_w, self.size_h), interpolation=cv2.INTER_LINEAR)
trimap = gen_trimap(alpha)
grad = compute_gradient(img)
if self.normalize:
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# first, scale 0-255 to 0-1
# second, normalize (x - mean) / std and convert HWC to CHW
img_norm = self.normalize(img_rgb)
else:
img_norm = None
#img_id = img_info[0].split('/')[-1]
#cv2.imwrite("result/debug/{}_img.png".format(img_id), img)
#cv2.imwrite("result/debug/{}_alpha.png".format(img_id), alpha)
#cv2.imwrite("result/debug/{}_fg.png".format(img_id), fg)
#cv2.imwrite("result/debug/{}_bg.png".format(img_id), bg)
#cv2.imwrite("result/debug/{}_trimap.png".format(img_id), trimap)
#cv2.imwrite("result/debug/{}_grad.png".format(img_id), grad)
alpha = torch.from_numpy(alpha.astype(np.float32)[np.newaxis, :, :])
trimap = torch.from_numpy(trimap.astype(np.float32)[np.newaxis, :, :])
grad = torch.from_numpy(grad.astype(np.float32)[np.newaxis, :, :])
img = torch.from_numpy(img.astype(np.float32)).permute(2, 0, 1)
fg = torch.from_numpy(fg.astype(np.float32)).permute(2, 0, 1)
bg = torch.from_numpy(bg.astype(np.float32)).permute(2, 0, 1)
return img, alpha, fg, bg, trimap, grad, img_norm, img_info
def __len__(self):
return len(self.samples)
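A hedged sketch of wiring MatDatasetOffline into a DataLoader; the Namespace fields mirror exactly what __init__ reads, and the directory names are placeholders:

from argparse import Namespace
import torch
from torchvision import transforms

args = Namespace(size_h=320, size_w=320,
                 crop_h=[320, 480, 640], crop_w=[320, 480, 640],
                 fgDir='comp/fg', alphaDir='comp/alpha',
                 bgDir='comp/bg', imgDir='comp/image')
normalize = transforms.Compose([
    transforms.ToTensor(),  # 0-255 HWC -> 0-1 CHW
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = MatDatasetOffline(args, transform=MatTransform(flip=True), normalize=normalize)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)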

@@ -0,0 +1,36 @@
import torch
from argparse import Namespace
import net
import cv2
import os
import numpy as np
from deploy import inference_img_whole
# input image and trimap paths
image_path = "boy-1518482_1920_12_img.png"
trimap_path = "boy-1518482_1920_12.png"
image = cv2.imread(image_path)
trimap = cv2.imread(trimap_path)
# print(trimap.shape)
trimap = trimap[:, :, 0]
# init model
args = Namespace(crop_or_resize='whole', cuda=True, max_size=1600, resume='model/stage1_sad_57.1.pth', stage=1)
model = net.VGG16(args)
ckpt = torch.load(args.resume)
model.load_state_dict(ckpt['state_dict'], strict=True)
model = model.cuda()
torch.cuda.empty_cache()
with torch.no_grad():
pred_mattes = inference_img_whole(args, model, image, trimap)
pred_mattes = (pred_mattes * 255).astype(np.uint8)
pred_mattes[trimap == 255] = 255
pred_mattes[trimap == 0] = 0
# print(pred_mattes)
# cv2.imwrite('out.png', pred_mattes)
# import matplotlib.pyplot as plt
# plt.imshow(image)
# plt.show()
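On a CPU-only machine this demo's cuda=True and model.cuda() lines fail; the GIMP plugin above already handles that case, and the equivalent change here would be (a sketch):

# args = Namespace(crop_or_resize='whole', cuda=False, max_size=1600,
#                  resume='model/stage1_sad_57.1.pth', stage=1)
# ckpt = torch.load(args.resume, map_location=torch.device('cpu'))
# model.load_state_dict(ckpt['state_dict'], strict=True)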

@@ -0,0 +1,281 @@
import torch
import argparse
import torch.nn as nn
import net
import cv2
import os
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
import time
def get_args():
# Testing settings
parser = argparse.ArgumentParser(description='Deep Image Matting - deployment')
parser.add_argument('--size_h', type=int, default=320, help="height size of input image")
parser.add_argument('--size_w', type=int, default=320, help="width size of input image")
parser.add_argument('--imgDir', type=str, required=True, help="directory of image")
parser.add_argument('--trimapDir', type=str, required=True, help="directory of trimap")
parser.add_argument('--cuda', action='store_true', help='use cuda?')
parser.add_argument('--resume', type=str, required=True, help="checkpoint to resume the model from")
parser.add_argument('--saveDir', type=str, required=True, help="directory where prediction results are saved")
parser.add_argument('--alphaDir', type=str, default='', help="directory of ground-truth alpha mattes")
parser.add_argument('--stage', type=int, required=True, choices=[0,1,2,3], help="backbone stage")
parser.add_argument('--not_strict', action='store_true', help='load the checkpoint with strict=False')
parser.add_argument('--crop_or_resize', type=str, default="whole", choices=["resize", "crop", "whole"], help="how to preprocess the image before testing")
parser.add_argument('--max_size', type=int, default=1600, help="max size of test image")
args = parser.parse_args()
print(args)
return args
def gen_dataset(imgdir, trimapdir):
sample_set = []
img_ids = os.listdir(imgdir)
img_ids.sort()
cnt = len(img_ids)
cur = 1
for img_id in img_ids:
img_name = os.path.join(imgdir, img_id)
trimap_name = os.path.join(trimapdir, img_id)
assert(os.path.exists(img_name))
assert(os.path.exists(trimap_name))
sample_set.append((img_name, trimap_name))
return sample_set
def compute_gradient(img):
x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
absX = cv2.convertScaleAbs(x)
absY = cv2.convertScaleAbs(y)
grad = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
grad=cv2.cvtColor(grad, cv2.COLOR_BGR2GRAY)
return grad
# run the network once on an already-scaled image and trimap; returns a numpy alpha matte
def inference_once(args, model, scale_img, scale_trimap, aligned=True):
if aligned:
assert(scale_img.shape[0] == args.size_h)
assert(scale_img.shape[1] == args.size_w)
normalize = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
])
scale_img_rgb = cv2.cvtColor(scale_img, cv2.COLOR_BGR2RGB)
# first, 0-255 to 0-1
# second, x-mean/std and HWC to CHW
tensor_img = normalize(scale_img_rgb).unsqueeze(0)
scale_grad = compute_gradient(scale_img)
#tensor_img = torch.from_numpy(scale_img.astype(np.float32)[np.newaxis, :, :, :]).permute(0, 3, 1, 2)
tensor_trimap = torch.from_numpy(scale_trimap.astype(np.float32)[np.newaxis, np.newaxis, :, :])
tensor_grad = torch.from_numpy(scale_grad.astype(np.float32)[np.newaxis, np.newaxis, :, :])
if args.cuda:
tensor_img = tensor_img.cuda()
tensor_trimap = tensor_trimap.cuda()
tensor_grad = tensor_grad.cuda()
#print('Img Shape:{} Trimap Shape:{}'.format(img.shape, trimap.shape))
input_t = torch.cat((tensor_img, tensor_trimap / 255.), 1)
# forward
if args.stage <= 1:
# stage 1
pred_mattes, _ = model(input_t)
else:
# stage 2, 3
_, pred_mattes = model(input_t)
pred_mattes = pred_mattes.data
if args.cuda:
pred_mattes = pred_mattes.cpu()
pred_mattes = pred_mattes.numpy()[0, 0, :, :]
return pred_mattes
# forward a full image by tiling it into size_h x size_w patches
def inference_img_by_crop(args, model, img, trimap):
# run the patches one by one and average the overlapping predictions
h, w, c = img.shape
origin_pred_mattes = np.zeros((h, w), dtype=np.float32)
marks = np.zeros((h, w), dtype=np.float32)
for start_h in range(0, h, args.size_h):
end_h = start_h + args.size_h
for start_w in range(0, w, args.size_w):
end_w = start_w + args.size_w
crop_img = img[start_h: end_h, start_w: end_w, :]
crop_trimap = trimap[start_h: end_h, start_w: end_w]
crop_origin_h = crop_img.shape[0]
crop_origin_w = crop_img.shape[1]
#print("startH:{} startW:{} H:{} W:{}".format(start_h, start_w, crop_origin_h, crop_origin_w))
if len(np.where(crop_trimap == 128)[0]) <= 0: # skip patches with no unknown region
continue
# edge patches at the right/bottom border may be smaller; resize them to the network size
if crop_origin_h != args.size_h or crop_origin_w != args.size_w:
crop_img = cv2.resize(crop_img, (args.size_w, args.size_h), interpolation=cv2.INTER_LINEAR)
crop_trimap = cv2.resize(crop_trimap, (args.size_w, args.size_h), interpolation=cv2.INTER_LINEAR)
# inference for each crop image patch
pred_mattes = inference_once(args, model, crop_img, crop_trimap)
if crop_origin_h != args.size_h or crop_origin_w != args.size_w:
pred_mattes = cv2.resize(pred_mattes, (crop_origin_w, crop_origin_h), interpolation=cv2.INTER_LINEAR)
origin_pred_mattes[start_h: end_h, start_w: end_w] += pred_mattes
marks[start_h: end_h, start_w: end_w] += 1
# smooth for overlap part
marks[marks <= 0] = 1.
origin_pred_mattes /= marks
return origin_pred_mattes
# forward for a full image by resize method
def inference_img_by_resize(args, model, img, trimap):
h, w, c = img.shape
# resize for network input, to Tensor
scale_img = cv2.resize(img, (args.size_w, args.size_h), interpolation=cv2.INTER_LINEAR)
scale_trimap = cv2.resize(trimap, (args.size_w, args.size_h), interpolation=cv2.INTER_LINEAR)
pred_mattes = inference_once(args, model, scale_img, scale_trimap)
# resize to origin size
origin_pred_mattes = cv2.resize(pred_mattes, (w, h), interpolation = cv2.INTER_LINEAR)
assert(origin_pred_mattes.shape == trimap.shape)
return origin_pred_mattes
# forward a whole image
def inference_img_whole(args, model, img, trimap):
h, w, c = img.shape
new_h = min(args.max_size, h - (h % 32))
new_w = min(args.max_size, w - (w % 32))
# resize for network input, to Tensor
scale_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
scale_trimap = cv2.resize(trimap, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
pred_mattes = inference_once(args, model, scale_img, scale_trimap, aligned=False)
# resize to origin size
origin_pred_mattes = cv2.resize(pred_mattes, (w, h), interpolation = cv2.INTER_LINEAR)
assert(origin_pred_mattes.shape == trimap.shape)
return origin_pred_mattes
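# Worked example of the 32-multiple alignment above (hypothetical sizes):
# an 800x1200 (h x w) image gives new_h = min(1600, 800 - 800 % 32) = 800 and
# new_w = min(1600, 1200 - 1200 % 32) = 1184; the network sees 800x1184 and the
# prediction is resized back to the original (w, h) = (1200, 800).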
def main():
print("===> Loading args")
args = get_args()
print("===> Environment init")
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if args.cuda and not torch.cuda.is_available():
raise Exception("No GPU found, please run without --cuda")
model = net.VGG16(args)
ckpt = torch.load(args.resume)
if args.not_strict:
model.load_state_dict(ckpt['state_dict'], strict=False)
else:
model.load_state_dict(ckpt['state_dict'], strict=True)
if args.cuda:
model = model.cuda()
print("===> Load dataset")
dataset = gen_dataset(args.imgDir, args.trimapDir)
mse_diffs = 0.
sad_diffs = 0.
cnt = len(dataset)
cur = 0
t0 = time.time()
for img_path, trimap_path in dataset:
img = cv2.imread(img_path)
trimap = cv2.imread(trimap_path)[:, :, 0]
assert(img.shape[:2] == trimap.shape[:2])
img_info = (img_path.split('/')[-1], img.shape[0], img.shape[1])
cur += 1
print('[{}/{}] {}'.format(cur, cnt, img_info[0]))
with torch.no_grad():
torch.cuda.empty_cache()
if args.crop_or_resize == "whole":
origin_pred_mattes = inference_img_whole(args, model, img, trimap)
elif args.crop_or_resize == "crop":
origin_pred_mattes = inference_img_by_crop(args, model, img, trimap)
else:
origin_pred_mattes = inference_img_by_resize(args, model, img, trimap)
# only the unknown region matters; pin known foreground/background values
origin_pred_mattes[trimap == 255] = 1.
origin_pred_mattes[trimap == 0 ] = 0.
# number of unknown pixels in the original trimap (denominator for MSE)
pixel = float((trimap == 128).sum())
# eval if gt alpha is given
if args.alphaDir != '':
alpha_name = os.path.join(args.alphaDir, img_info[0])
assert(os.path.exists(alpha_name))
alpha = cv2.imread(alpha_name)[:, :, 0] / 255.
assert(alpha.shape == origin_pred_mattes.shape)
#x1 = (alpha[trimap == 255] == 1.0).sum() # x3
#x2 = (alpha[trimap == 0] == 0.0).sum() # x5
#x3 = (trimap == 255).sum()
#x4 = (trimap == 128).sum()
#x5 = (trimap == 0).sum()
#x6 = trimap.size # sum(x3,x4,x5)
#x7 = (alpha[trimap == 255] < 1.0).sum() # 0
#x8 = (alpha[trimap == 0] > 0).sum() #
#print(x1, x2, x3, x4, x5, x6, x7, x8)
#assert(x1 == x3)
#assert(x2 == x5)
#assert(x6 == x3 + x4 + x5)
#assert(x7 == 0)
#assert(x8 == 0)
mse_diff = ((origin_pred_mattes - alpha) ** 2).sum() / pixel
sad_diff = np.abs(origin_pred_mattes - alpha).sum()
mse_diffs += mse_diff
sad_diffs += sad_diff
print("sad:{} mse:{}".format(sad_diff, mse_diff))
origin_pred_mattes = (origin_pred_mattes * 255).astype(np.uint8)
res = origin_pred_mattes.copy()
# only the unknown region matters; pin known foreground/background values
res[trimap == 255] = 255
res[trimap == 0 ] = 0
if not os.path.exists(args.saveDir):
os.makedirs(args.saveDir)
cv2.imwrite(os.path.join(args.saveDir, img_info[0]), res)
print("Avg-Cost: {} s/image".format((time.time() - t0) / cnt))
if args.alphaDir != '':
print("Eval-MSE: {}".format(mse_diffs / cur))
print("Eval-SAD: {}".format(sad_diffs / cur))
if __name__ == "__main__":
main()
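A hedged example invocation of this script (directory paths are placeholders, not from this commit):

# python deploy.py --imgDir data/image --trimapDir data/trimap \
#        --resume model/stage1_sad_57.1.pth --saveDir result/out \
#        --stage 1 --crop_or_resize whole --cuda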

@@ -0,0 +1,126 @@
import torch
import torch.nn as nn
import math
import cv2
import torch.nn.functional as F
class VGG16(nn.Module):
def __init__(self, args):
super(VGG16, self).__init__()
self.stage = args.stage
self.conv1_1 = nn.Conv2d(4, 64, kernel_size=3,stride = 1, padding=1,bias=True)
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3,stride = 1, padding=1,bias=True)
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1,bias=True)
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1,bias=True)
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1,bias=True)
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1,bias=True)
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1,bias=True)
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1,bias=True)
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
# models released before 2019-09-09 used kernel_size=1 & padding=0
#self.conv6_1 = nn.Conv2d(512, 512, kernel_size=1, padding=0,bias=True)
self.conv6_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
self.deconv6_1 = nn.Conv2d(512, 512, kernel_size=1,bias=True)
self.deconv5_1 = nn.Conv2d(512, 512, kernel_size=5, padding=2,bias=True)
self.deconv4_1 = nn.Conv2d(512, 256, kernel_size=5, padding=2,bias=True)
self.deconv3_1 = nn.Conv2d(256, 128, kernel_size=5, padding=2,bias=True)
self.deconv2_1 = nn.Conv2d(128, 64, kernel_size=5, padding=2,bias=True)
self.deconv1_1 = nn.Conv2d(64, 64, kernel_size=5, padding=2,bias=True)
self.deconv1 = nn.Conv2d(64, 1, kernel_size=5, padding=2,bias=True)
if args.stage == 2:
# for stage2 training
for p in self.parameters():
p.requires_grad=False
if self.stage == 2 or self.stage == 3:
self.refine_conv1 = nn.Conv2d(4, 64, kernel_size=3, padding=1, bias=True)
self.refine_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
self.refine_conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
self.refine_pred = nn.Conv2d(64, 1, kernel_size=3, padding=1, bias=True)
def forward(self, x):
# Stage 1
x11 = F.relu(self.conv1_1(x))
x12 = F.relu(self.conv1_2(x11))
x1p, id1 = F.max_pool2d(x12,kernel_size=(2,2), stride=(2,2),return_indices=True)
# Stage 2
x21 = F.relu(self.conv2_1(x1p))
x22 = F.relu(self.conv2_2(x21))
x2p, id2 = F.max_pool2d(x22,kernel_size=(2,2), stride=(2,2),return_indices=True)
# Stage 3
x31 = F.relu(self.conv3_1(x2p))
x32 = F.relu(self.conv3_2(x31))
x33 = F.relu(self.conv3_3(x32))
x3p, id3 = F.max_pool2d(x33,kernel_size=(2,2), stride=(2,2),return_indices=True)
# Stage 4
x41 = F.relu(self.conv4_1(x3p))
x42 = F.relu(self.conv4_2(x41))
x43 = F.relu(self.conv4_3(x42))
x4p, id4 = F.max_pool2d(x43,kernel_size=(2,2), stride=(2,2),return_indices=True)
# Stage 5
x51 = F.relu(self.conv5_1(x4p))
x52 = F.relu(self.conv5_2(x51))
x53 = F.relu(self.conv5_3(x52))
x5p, id5 = F.max_pool2d(x53,kernel_size=(2,2), stride=(2,2),return_indices=True)
# Stage 6
x61 = F.relu(self.conv6_1(x5p))
# Stage 6d
x61d = F.relu(self.deconv6_1(x61))
# Stage 5d
x5d = F.max_unpool2d(x61d,id5, kernel_size=2, stride=2)
x51d = F.relu(self.deconv5_1(x5d))
# Stage 4d
x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
x41d = F.relu(self.deconv4_1(x4d))
# Stage 3d
x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
x31d = F.relu(self.deconv3_1(x3d))
# Stage 2d
x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
x21d = F.relu(self.deconv2_1(x2d))
# Stage 1d
x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
x12d = F.relu(self.deconv1_1(x1d))
# should a sigmoid be applied here? the reference GitHub repo does so
raw_alpha = self.deconv1(x12d)
pred_mattes = torch.sigmoid(raw_alpha)
if self.stage <= 1:
return pred_mattes, 0
# Stage2 refine conv1
refine0 = torch.cat((x[:, :3, :, :], pred_mattes), 1)
refine1 = F.relu(self.refine_conv1(refine0))
refine2 = F.relu(self.refine_conv2(refine1))
refine3 = F.relu(self.refine_conv3(refine2))
# should a sigmoid be applied to the refinement too?
# a sigmoid here made the refine output collapse to 0...
#pred_refine = F.sigmoid(self.refine_pred(refine3))
pred_refine = self.refine_pred(refine3)
pred_alpha = torch.sigmoid(raw_alpha + pred_refine)
#print(pred_mattes.mean(), pred_alpha.mean(), pred_refine.sum())
return pred_mattes, pred_alpha
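A quick smoke test for the stage-1 graph (a sketch; the 4th input channel is the trimap, and the spatial size must be a multiple of 32 because of the five pooling stages):

from argparse import Namespace
import torch

model = VGG16(Namespace(stage=1))
x = torch.randn(1, 4, 320, 320)  # RGB + trimap
with torch.no_grad():
    pred_mattes, _ = model(x)
print(pred_mattes.shape)  # torch.Size([1, 1, 320, 320])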

@@ -0,0 +1,63 @@
import torch
import torchvision
import collections
import os
HOME = os.environ['HOME']
model_path = "{}/.torch/models/vgg16-397923af.pth".format(HOME)
#model_path = "/data/liuliang/deep_image_matting/train/vgg16-397923af.pth"
if not os.path.exists(model_path):
model = torchvision.models.vgg16(pretrained=True)
assert(os.path.exists(model_path))
x = torch.load(model_path)
val = collections.OrderedDict()
val['conv1_1.weight'] = torch.cat((x['features.0.weight'], torch.zeros(64, 1, 3, 3)), 1)
replace = { u'features.0.bias' : 'conv1_1.bias',
u'features.2.weight' : 'conv1_2.weight',
u'features.2.bias' : 'conv1_2.bias',
u'features.5.weight' : 'conv2_1.weight',
u'features.5.bias' : 'conv2_1.bias',
u'features.7.weight' : 'conv2_2.weight',
u'features.7.bias' : 'conv2_2.bias',
u'features.10.weight': 'conv3_1.weight',
u'features.10.bias' : 'conv3_1.bias',
u'features.12.weight': 'conv3_2.weight',
u'features.12.bias' : 'conv3_2.bias',
u'features.14.weight': 'conv3_3.weight',
u'features.14.bias' : 'conv3_3.bias',
u'features.17.weight': 'conv4_1.weight',
u'features.17.bias' : 'conv4_1.bias',
u'features.19.weight': 'conv4_2.weight',
u'features.19.bias' : 'conv4_2.bias',
u'features.21.weight': 'conv4_3.weight',
u'features.21.bias' : 'conv4_3.bias',
u'features.24.weight': 'conv5_1.weight',
u'features.24.bias' : 'conv5_1.bias',
u'features.26.weight': 'conv5_2.weight',
u'features.26.bias' : 'conv5_2.bias',
u'features.28.weight': 'conv5_3.weight',
u'features.28.bias' : 'conv5_3.bias'
}
#print(x['classifier.0.weight'].shape)
#print(x['classifier.0.bias'].shape)
#tmp1 = x['classifier.0.weight'].reshape(4096, 512, 7, 7)
#print(tmp1.shape)
#val['conv6_1.weight'] = tmp1[:512, :, :, :]
#val['conv6_1.bias'] = x['classifier.0.bias']
for key in replace.keys():
print(key, replace[key])
val[replace[key]] = x[key]
y = {}
y['state_dict'] = val
y['epoch'] = 0
if not os.path.exists('./model'):
os.makedirs('./model')
torch.save(y, './model/vgg_state_dict.pth')
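The converted checkpoint covers only the encoder convolutions, so loading it into VGG16 requires strict=False (the decoder starts from random init); a sketch, assuming the net module from this commit is importable:

from argparse import Namespace
import torch
import net

ckpt = torch.load('./model/vgg_state_dict.pth')
model = net.VGG16(Namespace(stage=1))
model.load_state_dict(ckpt['state_dict'], strict=False)  # encoder weights only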

@@ -0,0 +1,137 @@
# composite image with dataset from "deep image matting"
import os
import cv2
import math
import time
import shutil
root_dir = "/home/liuliang/Downloads/Combined_Dataset"
test_bg_dir = '/home/liuliang/Desktop/dataset/matting/VOCdevkit/VOC2012/JPEGImages'
train_bg_dir = '/home/liuliang/Desktop/dataset/matting/mscoco/train2017'
def my_composite(fg_names, bg_names, fg_dir, alpha_dir, bg_dir, num_bg, comp_dir):
fg_ids = open(fg_names).readlines()
bg_ids = open(bg_names).readlines()
fg_cnt = len(fg_ids)
bg_cnt = len(bg_ids)
print(fg_cnt, bg_cnt)
assert(fg_cnt * num_bg == bg_cnt)
for i in range(fg_cnt):
im_name = fg_ids[i].strip("\n").strip("\r")
fg_path = os.path.join(fg_dir, im_name)
alpha_path = os.path.join(alpha_dir, im_name)
#print(fg_path, alpha_path)
assert(os.path.exists(fg_path))
assert(os.path.exists(alpha_path))
fg = cv2.imread(fg_path)
alpha = cv2.imread(alpha_path)
#print("alpha shape:", alpha.shape, "image shape:", fg.shape)
assert(alpha.shape == fg.shape)
h, w ,c = fg.shape
base = i * num_bg
for bcount in range(num_bg):
bg_path = os.path.join(bg_dir, bg_ids[base + bcount].strip("\n").strip("\r"))
print(base + bcount, fg_path, bg_path)
assert(os.path.exists(bg_path))
bg = cv2.imread(bg_path)
bh, bw, bc = bg.shape
wratio = float(w) / bw
hratio = float(h) / bh
ratio = wratio if wratio > hratio else hratio
if ratio > 1:
new_bw = int(bw * ratio + 1.0)
new_bh = int(bh * ratio + 1.0)
bg = cv2.resize(bg, (new_bw, new_bh), interpolation=cv2.INTER_LINEAR)
bg = bg[0 : h, 0 : w, :]
#print(bg.shape)
assert(bg.shape == fg.shape)
alpha_f = alpha / 255.
comp = fg * alpha_f + bg * (1. - alpha_f)
img_save_id = im_name[:len(im_name)-4] + '_' + str(bcount) + '.png'
comp_save_path = os.path.join(comp_dir, "image/" + img_save_id)
fg_save_path = os.path.join(comp_dir, "fg/" + img_save_id)
bg_save_path = os.path.join(comp_dir, "bg/" + img_save_id)
alpha_save_path = os.path.join(comp_dir, "alpha/" + img_save_id)
cv2.imwrite(comp_save_path, comp)
cv2.imwrite(fg_save_path, fg)
cv2.imwrite(bg_save_path, bg)
cv2.imwrite(alpha_save_path, alpha)
def copy_dir2dir(src_dir, des_dir):
for img_id in os.listdir(src_dir):
shutil.copyfile(os.path.join(src_dir, img_id), os.path.join(des_dir, img_id))
def main():
test_num_bg = 20
test_fg_names = os.path.join(root_dir, "Test_set/test_fg_names.txt")
test_bg_names = os.path.join(root_dir, "Test_set/test_bg_names.txt")
test_fg_dir = os.path.join(root_dir, "Test_set/Adobe-licensed images/fg")
test_alpha_dir = os.path.join(root_dir, "Test_set/Adobe-licensed images/alpha")
test_trimap_dir = os.path.join(root_dir, "Test_set/Adobe-licensed images/trimaps")
test_comp_dir = os.path.join(root_dir, "Test_set/comp")
train_num_bg = 100
train_fg_names = os.path.join(root_dir, "Training_set/training_fg_names.txt")
train_bg_names_coco2014 = os.path.join(root_dir, "Training_set/training_bg_names.txt")
train_bg_names_coco2017 = os.path.join(root_dir, "Training_set/training_bg_names_coco2017.txt")
train_fg_dir = os.path.join(root_dir, "Training_set/all/fg")
train_alpha_dir = os.path.join(root_dir, "Training_set/all/alpha")
train_comp_dir = os.path.join(root_dir, "Training_set/comp")
# rewrite the bg name list to COCO 2017 format by stripping the 15-char "COCO_train2014_" prefix
fin = open(train_bg_names_coco2014, 'r')
fout = open(train_bg_names_coco2017, 'w')
lls = fin.readlines()
for l in lls:
fout.write(l[15:])
fin.close()
fout.close()
if not os.path.exists(test_comp_dir):
os.makedirs(test_comp_dir + '/image')
os.makedirs(test_comp_dir + '/fg')
os.makedirs(test_comp_dir + '/bg')
os.makedirs(test_comp_dir + '/alpha')
os.makedirs(test_comp_dir + '/trimap')
if not os.path.exists(train_comp_dir):
os.makedirs(train_comp_dir + '/image')
os.makedirs(train_comp_dir + '/fg')
os.makedirs(train_comp_dir + '/bg')
os.makedirs(train_comp_dir + '/alpha')
if not os.path.exists(train_alpha_dir):
os.makedirs(train_alpha_dir)
if not os.path.exists(train_fg_dir):
os.makedirs(train_fg_dir)
# copy test trimaps
copy_dir2dir(test_trimap_dir, test_comp_dir + '/trimap')
# copy train images together
copy_dir2dir(os.path.join(root_dir, "Training_set/Adobe-licensed images/alpha"), train_alpha_dir)
copy_dir2dir(os.path.join(root_dir, "Training_set/Adobe-licensed images/fg"), train_fg_dir)
copy_dir2dir(os.path.join(root_dir, "Training_set/Other/alpha"), train_alpha_dir)
copy_dir2dir(os.path.join(root_dir, "Training_set/Other/fg"), train_fg_dir)
# composite test image
my_composite(test_fg_names, test_bg_names, test_fg_dir, test_alpha_dir, test_bg_dir, test_num_bg, test_comp_dir)
# composite train image
my_composite(train_fg_names, train_bg_names_coco2017, train_fg_dir, train_alpha_dir, train_bg_dir, train_num_bg, train_comp_dir)
if __name__ == "__main__":
main()
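The compositing rule above is comp = fg*alpha + bg*(1 - alpha) per pixel; a one-pixel sanity check (sketch):

import numpy as np

fg, bg = np.float32([200.0]), np.float32([50.0])
alpha_f = 128 / 255.0
print(fg * alpha_f + bg * (1.0 - alpha_f))  # ~[125.3]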

@@ -0,0 +1,50 @@
import re
import numpy as np
import matplotlib.pyplot as plt
from pylab import *
# args: log_name, match_rule, self_log_interval, smooth_log_iterations
loss_file_name = "simple_loss"
title = "{}_Loss".format(loss_file_name)
f = open("../log/{}.log".format(loss_file_name))
pattern = re.compile(r'Loss:[ ]*\d+\.\d+')
self_inter = 10
smooth = 20
# read log file
lines = f.readlines()
print("Line: {}".format(len(lines)))
ys = []
k = 0
cnt = 0
sum_y = 0.
# read one by one
for line in lines:
obj = re.search(pattern, line)
if obj:
val = float(obj.group().split(':')[-1])
sum_y += val
k += 1
if k >= smooth:
ys.append(sum_y / k)
sum_y = 0.
k = 0
cnt += 1
if cnt % 10 == 0:
print("ys cnt: {}".format(cnt))
if k > 0:
ys.append(sum_y / k)
ys = np.array(ys)
xs = np.arange(len(ys)) * self_inter * smooth
print(xs)
print(ys)
plt.plot(xs, ys)
plt.title(title)
plt.xlabel("Iter")
plt.ylabel("Loss")
plt.savefig("../log/{}.png".format(title))
plt.show()
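For reference, a hypothetical log line the regex above matches, and how the x-axis is scaled:

import re

line = "Epoch[1/20] Iter[100/4000] Loss: 0.1234"  # hypothetical format
print(re.search(r'Loss:[ ]*\d+\.\d+', line).group())  # -> 'Loss: 0.1234'
# each plotted point averages `smooth` (20) matched values, logged every
# `self_inter` (10) iterations, hence xs = arange(len(ys)) * 10 * 20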