mirror of https://github.com/kritiksoman/GIMP-ML
augustUpdate
parent
9ca028cf19
commit
06956277a7
Binary file not shown.
@ -0,0 +1,109 @@
|
||||
import os
|
||||
baseLoc = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||
|
||||
|
||||
from gimpfu import *
|
||||
import sys
|
||||
|
||||
sys.path.extend([baseLoc+'gimpenv/lib/python2.7',baseLoc+'gimpenv/lib/python2.7/site-packages',baseLoc+'gimpenv/lib/python2.7/site-packages/setuptools',baseLoc+'pytorch-deep-image-matting'])
|
||||
|
||||
|
||||
import torch
|
||||
from argparse import Namespace
|
||||
import net
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from deploy import inference_img_whole
|
||||
|
||||
|
||||
|
||||
def channelData(layer):
    """Convert a GIMP layer into a numpy array.

    Reads the whole pixel region of `layer` and returns it as a uint8
    array reshaped to (layer.height, layer.width, bytes-per-pixel).
    """
    width, height = layer.width, layer.height
    pixel_region = layer.get_pixel_rgn(0, 0, width, height)
    raw_bytes = pixel_region[:, :]  # the full layer as one byte string
    flat = np.frombuffer(raw_bytes, dtype=np.uint8)
    return flat.reshape(height, width, pixel_region.bpp)
|
||||
|
||||
|
||||
def createResultLayer(image,name,result):
    """Insert `result` into `image` as a new layer named `name`.

    `result` is converted to uint8 and written into a freshly created
    layer, which is added at position 0 of the image's layer stack before
    the displays are flushed.

    The layer is sized from the pixel data rather than from the canvas:
    the data normally comes from a layer whose size may differ from the
    image size, and writing a buffer of the wrong length into the pixel
    region fails.
    """
    pixels = np.uint8(result)
    if pixels.ndim >= 2:
        height, width = pixels.shape[:2]
    else:
        # Flat data carries no geometry; fall back to the canvas size.
        width, height = image.width, image.height

    # Layer type 1: RGB with alpha in GIMP's image-type enum -- TODO confirm
    # callers always pass 4-channel data.
    rl = gimp.Layer(image, name, width, height, 1, 100, NORMAL_MODE)
    region = rl.get_pixel_rgn(0, 0, rl.width, rl.height, True)
    region[:, :] = pixels.tobytes()
    image.add_layer(rl, 0)
    gimp.displays_flush()
|
||||
|
||||
def getnewalpha(image,mask):
    """Predict an alpha matte for `image` guided by the trimap in `mask`.

    Both inputs are (height, width, channels) arrays. A fourth (alpha)
    channel is dropped from either one; channel 0 of the mask is used as
    the trimap. The stage-1 matting checkpoint is loaded on every call and
    run on the GPU when CUDA is available, otherwise on the CPU.

    Returns the colour channels stacked with the predicted matte as an
    extra last channel. Pixels whose trimap value is 255 or 0 are forced
    to 255 or 0 in the matte.
    """
    # Only the colour channels take part in the prediction.
    rgb = image[:, :, 0:3] if image.shape[2] == 4 else image
    mask_rgb = mask[:, :, 0:3] if mask.shape[2] == 4 else mask

    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    trimap = mask_rgb[:, :, 0]

    use_cuda = bool(torch.cuda.is_available())
    args = Namespace(
        crop_or_resize='whole',
        cuda=use_cuda,
        max_size=1600,
        resume=baseLoc + 'pytorch-deep-image-matting/model/stage1_sad_57.1.pth',
        stage=1,
    )
    model = net.VGG16(args)

    # Without CUDA the checkpoint tensors have to be remapped to the CPU.
    load_options = {} if use_cuda else {"map_location": torch.device("cpu")}
    checkpoint = torch.load(args.resume, **load_options)
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    if use_cuda:
        model = model.cuda()

    torch.cuda.empty_cache()
    with torch.no_grad():
        matte = inference_img_whole(args, model, bgr, trimap)

    matte = (matte * 255).astype(np.uint8)
    # Known foreground / background of the trimap override the prediction.
    matte[trimap == 255] = 255
    matte[trimap == 0] = 0

    rgb_out = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return np.dstack((rgb_out, matte))
|
||||
|
||||
|
||||
def deepmatting(imggimp, curlayer,layeri,layerm) :
    """Plugin entry point: matte `layeri` using `layerm` as the trimap.

    Shows a progress message naming the device in use, converts both
    layers to arrays, predicts the alpha matte and adds the result to
    `imggimp` as a layer called 'new_output'. `curlayer` is not used.
    """
    if torch.cuda.is_available():
        message_head = "(Using GPU) Running deep-matting for "
    else:
        message_head = "(Using CPU) Running deep-matting for "
    gimp.progress_init(message_head + layeri.name + "...")

    matted = getnewalpha(channelData(layeri), channelData(layerm))
    createResultLayer(imggimp, 'new_output', matted)
|
||||
|
||||
|
||||
|
||||
# Register the deep-matting procedure with GIMP and hand control to gimpfu.
register(
    "deep-matting",            # procedure name
    "deep-matting",            # blurb
    "Running image matting.",  # help text
    "Kritik Soman",            # author
    "Your",                    # copyright
    "2020",                    # date
    "deepmatting...",          # menu label
    "*",  # Alternately use RGB, RGB*, GRAY*, INDEXED etc.
    [
        (PF_IMAGE, "image", "Input image", None),
        (PF_DRAWABLE, "drawable", "Input drawable", None),
        (PF_LAYER, "drawinglayer", "Original Image:", None),
        (PF_LAYER, "drawinglayer", "Trimap Mask:", None),
    ],
    [],  # no return values
    deepmatting,
    menu="<Image>/Layer/GIML-ML",
)

main()
|
@ -0,0 +1,76 @@
|
||||
import os
|
||||
baseLoc = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||
|
||||
|
||||
from gimpfu import *
|
||||
import sys
|
||||
|
||||
sys.path.extend([baseLoc+'gimpenv/lib/python2.7',baseLoc+'gimpenv/lib/python2.7/site-packages',baseLoc+'gimpenv/lib/python2.7/site-packages/setuptools'])
|
||||
|
||||
import numpy as np
|
||||
from scipy.cluster.vq import kmeans2
|
||||
|
||||
|
||||
def channelData(layer):
    """Return the pixels of a GIMP layer as a uint8 numpy array.

    The array has shape (layer.height, layer.width, bpp), where bpp is the
    bytes-per-pixel value reported by the layer's pixel region.
    """
    n_cols = layer.width
    n_rows = layer.height
    rgn = layer.get_pixel_rgn(0, 0, n_cols, n_rows)
    buf = rgn[:, :]  # whole layer
    return np.frombuffer(buf, dtype=np.uint8).reshape(n_rows, n_cols, rgn.bpp)
|
||||
|
||||
|
||||
def createResultLayer(image,name,result):
    """Insert `result` into `image` as a new layer named `name`.

    `result` is converted to uint8 and written into a freshly created
    layer, which is added at position 0 of the image's layer stack before
    the displays are flushed.

    The layer is sized from the pixel data rather than from the canvas:
    the data normally comes from a layer whose size may differ from the
    image size, and writing a buffer of the wrong length into the pixel
    region fails.
    """
    pixels = np.uint8(result)
    if pixels.ndim >= 2:
        height, width = pixels.shape[:2]
    else:
        # Flat data carries no geometry; fall back to the canvas size.
        width, height = image.width, image.height

    # Layer type 0: RGB without alpha (1 would be RGB with alpha).
    rl = gimp.Layer(image, name, width, height, 0, 100, NORMAL_MODE)
    region = rl.get_pixel_rgn(0, 0, rl.width, rl.height, True)
    region[:, :] = pixels.tobytes()
    image.add_layer(rl, 0)
    gimp.displays_flush()
|
||||
|
||||
def kmeans(imggimp, curlayer,layeri,n_clusters,locflag) :
    """Plugin entry point: colour-quantise `layeri` with k-means.

    The layer's pixels (alpha dropped when there are four channels) are
    clustered into `n_clusters` groups. When `locflag` is true the pixel
    column and row indices are appended as two extra features. Every pixel
    is then replaced by the colour part of its cluster centre and the
    result is added to `imggimp` as a layer called 'new_output'.
    `curlayer` is not used.
    """
    pixels = channelData(layeri)
    if pixels.shape[2] == 4:  # get rid of alpha channel
        pixels = pixels[:, :, 0:3]
    rows, cols, depth = pixels.shape

    # One row per pixel, three colour values per row.
    features = pixels.reshape((-1, 3))
    if locflag:
        col_grid, row_grid = np.meshgrid(range(cols), range(rows))
        features = np.concatenate(
            (features, col_grid.reshape(-1, 1), row_grid.reshape(-1, 1)),
            axis=1,
        )

    centroids, labels = kmeans2(np.float32(features), n_clusters)

    # Keep only the colour part of each centre when positions were added.
    palette = np.uint8(centroids[:, 0:3]) if locflag else np.uint8(centroids)

    quantised = palette[labels.flatten()].reshape((rows, cols, depth))
    createResultLayer(imggimp, 'new_output', quantised)
|
||||
|
||||
|
||||
# Register the kmeans procedure with GIMP and hand control to gimpfu.
register(
    "kmeans",                      # procedure name
    "kmeans clustering",           # blurb
    "Running kmeans clustering.",  # help text
    "Kritik Soman",                # author
    "Your",                        # copyright
    "2020",                        # date
    "kmeans...",                   # menu label
    "*",  # Alternately use RGB, RGB*, GRAY*, INDEXED etc.
    [
        (PF_IMAGE, "image", "Input image", None),
        (PF_DRAWABLE, "drawable", "Input drawable", None),
        (PF_LAYER, "drawinglayer", "Original Image", None),
        (PF_INT, "depth", "Number of clusters", 3),
        (PF_BOOL, "position", "Use position", False),
    ],
    [],  # no return values
    kmeans,
    menu="<Image>/Layer/GIML-ML",
)

main()
|
Binary file not shown.
@ -0,0 +1,194 @@
|
||||
import torch
|
||||
import cv2
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
import logging
|
||||
|
||||
def gen_trimap(alpha):
    """Build a random trimap from an alpha matte.

    The matte is eroded and dilated with an elliptical kernel of random
    size (2-4) for a random number of iterations (5-14). The result is 255
    where the eroded matte is >= 255, 0 where the dilated matte is <= 0
    and 128 everywhere else.
    """
    # Draw order is kept: kernel size from `random`, iterations from numpy.
    kernel_size = random.choice(range(2, 5))
    n_iter = np.random.randint(5, 15)
    kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

    grown = cv2.dilate(alpha, kernel, iterations=n_iter)
    shrunk = cv2.erode(alpha, kernel, iterations=n_iter)

    # Start as "unknown" everywhere, then mark the certain regions.
    trimap = np.full(alpha.shape, 128.0)
    trimap[shrunk >= 255] = 255
    trimap[grown <= 0] = 0
    return trimap
|
||||
|
||||
|
||||
def compute_gradient(img):
    """Return a single-channel gradient-magnitude approximation of `img`.

    Horizontal and vertical Sobel responses are converted to absolute
    8-bit values, averaged with equal weights and reduced to grey scale.
    """
    abs_dx = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 1, 0))
    abs_dy = cv2.convertScaleAbs(cv2.Sobel(img, cv2.CV_16S, 0, 1))
    blended = cv2.addWeighted(abs_dx, 0.5, abs_dy, 0.5, 0)
    return cv2.cvtColor(blended, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
|
||||
class MatTransform(object):
    """Random crop (centred on an unknown-alpha pixel) and optional flip."""

    def __init__(self, flip=False):
        # When true, samples are mirrored horizontally half of the time.
        self.flip = flip

    def __call__(self, img, alpha, fg, bg, crop_h, crop_w):
        """Crop all four arrays to the same window and maybe flip them.

        The window is crop_h x crop_w. If `alpha` has pixels strictly
        between 0 and 255, one is picked at random as the window centre,
        clamped so the window stays inside the image; otherwise the
        window starts at the top-left corner.

        Returns the tuple (img, alpha, fg, bg).
        """
        height, width = alpha.shape

        # Candidate centres: pixels that are neither pure fg nor pure bg.
        unknown_rows, unknown_cols = np.where((alpha > 0) & (alpha < 255))
        half_h = crop_h / 2
        half_w = crop_w / 2
        mid_h, mid_w = half_h, half_w

        if len(unknown_rows) > 0:
            pick = np.random.randint(len(unknown_rows))
            mid_h = min(max(unknown_rows[pick], half_h), height - half_h)
            mid_w = min(max(unknown_cols[pick], half_w), width - half_w)

        # The chosen point is the centre of the crop, not its corner.
        row_span = slice(int(mid_h - half_h), int(mid_h + half_h))
        col_span = slice(int(mid_w - half_w), int(mid_w + half_w))

        img = img[row_span, col_span]
        fg = fg[row_span, col_span]
        bg = bg[row_span, col_span]
        alpha = alpha[row_span, col_span]

        # random flip
        if self.flip and random.random() < 0.5:
            img, alpha, fg, bg = [
                cv2.flip(plane, 1) for plane in (img, alpha, fg, bg)]

        return img, alpha, fg, bg
|
||||
|
||||
|
||||
def get_files(mydir):
    """Recursively list the image files below `mydir`.

    Symbolic links to directories are followed. A file counts as an image
    when its name ends in .jpg, .jpeg or .png, compared case-insensitively.
    (Previously only the lower-case forms plus ".JPG" matched, so files
    such as "x.PNG" were skipped.)

    Returns a list of paths in the order produced by os.walk; the list is
    empty when nothing matches or the directory does not exist.
    """
    image_exts = (".jpg", ".jpeg", ".png")
    res = []
    for root, _dirs, files in os.walk(mydir, followlinks=True):
        res.extend(
            os.path.join(root, f)
            for f in files
            if f.lower().endswith(image_exts)
        )
    return res
|
||||
|
||||
|
||||
# Dataset not composite online
|
||||
class MatDatasetOffline(torch.utils.data.Dataset):
    """Matting dataset whose composited images already exist on disk.

    Every foreground file found under args.fgDir must have a counterpart
    at the same relative path under args.alphaDir, args.bgDir and
    args.imgDir.
    """

    def __init__(self, args, transform=None, normalize=None):
        """Index the sample files.

        `args` supplies size_h/size_w (output size), crop_h/crop_w
        (equal-length sequences of candidate crop sizes) and the four
        directories. `transform` is called with
        (img, alpha, fg, bg, crop_h, crop_w); `normalize` is applied to
        the RGB image.
        """
        self.samples=[]
        self.transform = transform
        self.normalize = normalize
        self.args = args
        self.size_h, self.size_w = args.size_h, args.size_w
        self.crop_h, self.crop_w = args.crop_h, args.crop_w
        self.logger = logging.getLogger("DeepImageMatting")
        assert(len(self.crop_h) == len(self.crop_w))

        fg_paths = get_files(self.args.fgDir)
        self.cnt = len(fg_paths)

        fg_root = self.args.fgDir
        for fg_path in fg_paths:
            # Sibling files live at the same relative path in other roots.
            alpha_path = fg_path.replace(fg_root, self.args.alphaDir)
            img_path = fg_path.replace(fg_root, self.args.imgDir)
            bg_path = fg_path.replace(fg_root, self.args.bgDir)
            for required in (alpha_path, fg_path, bg_path, img_path):
                assert(os.path.exists(required))
            self.samples.append((alpha_path, fg_path, bg_path, img_path))

        self.logger.info("MatDatasetOffline Samples: {}".format(self.cnt))
        assert(self.cnt > 0)

    def __getitem__(self,index):
        """Load, crop, resize and tensorise one sample.

        Returns (img, alpha, fg, bg, trimap, grad, img_norm, img_info).
        The image-like entries are float32 tensors in channel-first
        layout, img_norm is the normalised RGB image (or None without a
        normaliser) and img_info lists the four paths plus the original
        foreground shape.
        """
        alpha_path, fg_path, bg_path, img_path = self.samples[index]

        # read fg, bg, composite and alpha
        fg = cv2.imread(fg_path)[:, :, :3]
        bg = cv2.imread(bg_path)[:, :, :3]
        img = cv2.imread(img_path)[:, :, :3]
        alpha = cv2.imread(alpha_path)[:, :, 0]

        assert(bg.shape == fg.shape and bg.shape == img.shape)
        img_info = [fg_path, alpha_path, bg_path, img_path, fg.shape]
        src_h, src_w, _channels = fg.shape

        choice = random.randint(0, len(self.crop_h) - 1)
        cur_crop_h = self.crop_h[choice]
        cur_crop_w = self.crop_w[choice]

        # Enlarge images that are smaller than the crop window so that
        # both sides are at least as large as the crop.
        ratio = max(float(cur_crop_w) / src_w, float(cur_crop_h) / src_h)
        if ratio > 1:
            enlarged = (int(src_w * ratio + 1.0), int(src_h * ratio + 1.0))
            fg, bg, img, alpha = [
                cv2.resize(plane, enlarged, interpolation=cv2.INTER_LINEAR)
                for plane in (fg, bg, img, alpha)
            ]

        # random crop(crop_h, crop_w) and flip
        if self.transform:
            img, alpha, fg, bg = self.transform(
                img, alpha, fg, bg, cur_crop_h, cur_crop_w)

        # resize to (size_h, size_w)
        if self.size_h != img.shape[0] or self.size_w != img.shape[1]:
            target = (self.size_w, self.size_h)
            img, fg, bg, alpha = [
                cv2.resize(plane, target, interpolation=cv2.INTER_LINEAR)
                for plane in (img, fg, bg, alpha)
            ]

        trimap = gen_trimap(alpha)
        grad = compute_gradient(img)

        if self.normalize:
            # 0-255 -> 0-1, then (x - mean) / std and HWC -> CHW
            img_norm = self.normalize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        else:
            img_norm = None

        def single_channel(plane):
            # (H, W) array -> float32 tensor of shape (1, H, W)
            return torch.from_numpy(
                plane.astype(np.float32)[np.newaxis, :, :])

        def multi_channel(plane):
            # (H, W, C) array -> float32 tensor of shape (C, H, W)
            return torch.from_numpy(
                plane.astype(np.float32)).permute(2, 0, 1)

        alpha = single_channel(alpha)
        trimap = single_channel(trimap)
        grad = single_channel(grad)
        img = multi_channel(img)
        fg = multi_channel(fg)
        bg = multi_channel(bg)

        return img, alpha, fg, bg, trimap, grad, img_norm, img_info

    def __len__(self):
        """Number of indexed samples."""
        return len(self.samples)
|
@ -0,0 +1,36 @@
|
||||
import torch
|
||||
from argparse import Namespace
|
||||
import net
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
from deploy import inference_img_whole
|
||||
|
||||
# Stand-alone demo: predict the matte of one image/trimap pair.

# input file list
image_path = "boy-1518482_1920_12_img.png"
trimap_path = "boy-1518482_1920_12.png"
image = cv2.imread(image_path)
trimap = cv2.imread(trimap_path)
# cv2.imread signals an unreadable file by returning None, not by raising.
if image is None or trimap is None:
    raise IOError("could not read {} or {}".format(image_path, trimap_path))
trimap = trimap[:, :, 0]

# init model -- use the GPU only when one is actually present, so the demo
# also runs on CPU-only machines.
use_cuda = torch.cuda.is_available()
args = Namespace(crop_or_resize='whole', cuda=use_cuda, max_size=1600, resume='model/stage1_sad_57.1.pth', stage=1)
model = net.VGG16(args)
# Remap the checkpoint tensors to the CPU when CUDA is unavailable.
ckpt = torch.load(args.resume, map_location=None if use_cuda else torch.device("cpu"))
model.load_state_dict(ckpt['state_dict'], strict=True)
if use_cuda:
    model = model.cuda()

torch.cuda.empty_cache()
with torch.no_grad():
    pred_mattes = inference_img_whole(args, model, image, trimap)

pred_mattes = (pred_mattes * 255).astype(np.uint8)
# Known foreground / background of the trimap override the prediction.
pred_mattes[trimap == 255] = 255
pred_mattes[trimap == 0] = 0
|
||||
# print(pred_mattes)
|
||||
# cv2.imwrite('out.png', pred_mattes)
|
||||
|
||||
# import matplotlib.pyplot as plt
|
||||
# plt.imshow(image)
|
||||
# plt.show()
|
||||
|
||||
|
@ -0,0 +1,281 @@
|
||||
import torch
|
||||
import argparse
|
||||
import torch.nn as nn
|
||||
import net
|
||||
import cv2
|
||||
import os
|
||||
from torchvision import transforms
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
def get_args():
    """Parse, print and return the command-line options of this script."""
    option_specs = [
        ('--size_h', dict(type=int, default=320, help="height size of input image")),
        ('--size_w', dict(type=int, default=320, help="width size of input image")),
        ('--imgDir', dict(type=str, required=True, help="directory of image")),
        ('--trimapDir', dict(type=str, required=True, help="directory of trimap")),
        ('--cuda', dict(action='store_true', help='use cuda?')),
        ('--resume', dict(type=str, required=True, help="checkpoint that model resume from")),
        ('--saveDir', dict(type=str, required=True, help="where prediction result save to")),
        ('--alphaDir', dict(type=str, default='', help="directory of gt")),
        ('--stage', dict(type=int, required=True, choices=[0,1,2,3], help="backbone stage")),
        ('--not_strict', dict(action='store_true', help='not copy ckpt strict?')),
        ('--crop_or_resize', dict(type=str, default="whole", choices=["resize", "crop", "whole"], help="how manipulate image before test")),
        ('--max_size', dict(type=int, default=1600, help="max size of test image")),
    ]

    parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    print(args)
    return args
|
||||
|
||||
def gen_dataset(imgdir, trimapdir):
    """Pair every file in `imgdir` with the same-named file in `trimapdir`.

    Returns a list of (image_path, trimap_path) tuples sorted by file
    name. Raises AssertionError when either file of a pair is missing.
    """
    sample_set = []
    for img_id in sorted(os.listdir(imgdir)):
        img_name = os.path.join(imgdir, img_id)
        trimap_name = os.path.join(trimapdir, img_id)

        assert os.path.exists(img_name), "missing image: " + img_name
        assert os.path.exists(trimap_name), "missing trimap: " + trimap_name

        sample_set.append((img_name, trimap_name))

    return sample_set
|
||||
|
||||
def compute_gradient(img):
    """Return a single-channel gradient-magnitude approximation of `img`.

    Horizontal and vertical Sobel responses are converted to absolute
    8-bit values, averaged with equal weights and reduced to grey scale.
    """
    sobel_x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    sobel_y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    mixed = cv2.addWeighted(
        cv2.convertScaleAbs(sobel_x), 0.5,
        cv2.convertScaleAbs(sobel_y), 0.5,
        0,
    )
    return cv2.cvtColor(mixed, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
|
||||
# inference once for image, return numpy
|
||||
def inference_once(args, model, scale_img, scale_trimap, aligned=True):
    """Run one forward pass and return the predicted matte as a 2-D array.

    `scale_img` is a BGR image and `scale_trimap` its single-channel
    trimap (0-255, divided by 255 before being fed to the network). With
    `aligned` set, the image must be exactly args.size_h x args.size_w,
    otherwise AssertionError is raised. For args.stage <= 1 the first
    model output is used, for later stages the second one.

    The gradient map the original code computed here was never passed to
    the model, so it is no longer computed.
    """
    if aligned:
        assert(scale_img.shape[0] == args.size_h)
        assert(scale_img.shape[1] == args.size_w)

    normalize = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])
    ])

    scale_img_rgb = cv2.cvtColor(scale_img, cv2.COLOR_BGR2RGB)
    # first, 0-255 to 0-1
    # second, x-mean/std and HWC to CHW
    tensor_img = normalize(scale_img_rgb).unsqueeze(0)
    tensor_trimap = torch.from_numpy(scale_trimap.astype(np.float32)[np.newaxis, np.newaxis, :, :])

    if args.cuda:
        tensor_img = tensor_img.cuda()
        tensor_trimap = tensor_trimap.cuda()

    # Network input: normalised RGB channels plus the trimap scaled to 0-1.
    input_t = torch.cat((tensor_img, tensor_trimap / 255.), 1)

    # forward
    if args.stage <= 1:
        # stage 1
        pred_mattes, _ = model(input_t)
    else:
        # stage 2, 3
        _, pred_mattes = model(input_t)

    pred_mattes = pred_mattes.data
    if args.cuda:
        pred_mattes = pred_mattes.cpu()
    # Drop the batch and channel axes.
    return pred_mattes.numpy()[0, 0, :, :]
|
||||
|
||||
|
||||
|
||||
# forward for a full image by crop method
|
||||
def inference_img_by_crop(args, model, img, trimap):
    """Predict the matte tile by tile and stitch the tiles together.

    The image is cut into args.size_h x args.size_w tiles. Tiles without
    any trimap pixel equal to 128 are skipped and stay 0 in the result.
    Smaller tiles on the right or bottom edge are resized to the network
    size for the forward pass and the prediction is resized back.
    """
    h, w, c = img.shape
    tile_h, tile_w = args.size_h, args.size_w
    accum = np.zeros((h, w), dtype=np.float32)
    hits = np.zeros((h, w), dtype=np.float32)

    for top in range(0, h, tile_h):
        bottom = top + tile_h
        for left in range(0, w, tile_w):
            right = left + tile_w
            tile_img = img[top:bottom, left:right, :]
            tile_trimap = trimap[top:bottom, left:right]
            real_h, real_w = tile_img.shape[0], tile_img.shape[1]

            # Nothing to predict without unknown pixels in this tile.
            if not (tile_trimap == 128).any():
                continue

            # edge patch in the right or bottom
            is_partial = real_h != tile_h or real_w != tile_w
            if is_partial:
                tile_img = cv2.resize(tile_img, (tile_w, tile_h), interpolation=cv2.INTER_LINEAR)
                tile_trimap = cv2.resize(tile_trimap, (tile_w, tile_h), interpolation=cv2.INTER_LINEAR)

            # inference for each crop image patch
            tile_pred = inference_once(args, model, tile_img, tile_trimap)

            if is_partial:
                tile_pred = cv2.resize(tile_pred, (real_w, real_h), interpolation=cv2.INTER_LINEAR)

            accum[top:bottom, left:right] += tile_pred
            hits[top:bottom, left:right] += 1

    # Average by the number of predictions per pixel; skipped pixels get a
    # divisor of 1 so they stay 0.
    hits[hits <= 0] = 1.
    accum /= hits
    return accum
|
||||
|
||||
|
||||
# forward for a full image by resize method
|
||||
def inference_img_by_resize(args, model, img, trimap):
    """Predict the matte on a copy resized to args.size_h x args.size_w.

    The prediction is resized back to the original image size; it must
    then have the same shape as `trimap`, otherwise AssertionError is
    raised.
    """
    h, w, c = img.shape
    net_size = (args.size_w, args.size_h)

    # resize for network input
    small_img = cv2.resize(img, net_size, interpolation=cv2.INTER_LINEAR)
    small_trimap = cv2.resize(trimap, net_size, interpolation=cv2.INTER_LINEAR)

    small_pred = inference_once(args, model, small_img, small_trimap)

    # resize to origin size
    full_pred = cv2.resize(small_pred, (w, h), interpolation = cv2.INTER_LINEAR)
    assert(full_pred.shape == trimap.shape)
    return full_pred
|
||||
|
||||
|
||||
# forward a whole image
|
||||
def inference_img_whole(args, model, img, trimap):
    """Predict the matte for the whole image in a single forward pass.

    Each side is capped at args.max_size and rounded down to a multiple
    of 32 for the forward pass, but never below 32: the original formula
    produced a target size of 0 for sides shorter than 32 pixels, and an
    unaligned size when max_size itself was not a multiple of 32. The
    prediction is resized back to the input size and must then match
    `trimap` in shape, otherwise AssertionError is raised.
    """
    h, w, c = img.shape
    stride = 32

    def aligned(side):
        capped = min(args.max_size, side)
        return max(stride, capped - (capped % stride))

    new_h = aligned(h)
    new_w = aligned(w)

    # resize for network input
    scale_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    scale_trimap = cv2.resize(trimap, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    pred_mattes = inference_once(args, model, scale_img, scale_trimap, aligned=False)

    # resize to origin size
    origin_pred_mattes = cv2.resize(pred_mattes, (w, h), interpolation = cv2.INTER_LINEAR)
    assert(origin_pred_mattes.shape == trimap.shape)
    return origin_pred_mattes
|
||||
|
||||
|
||||
def main():
    """Run matting inference over a directory and optionally evaluate it.

    For every image/trimap pair the matte is predicted with the method
    chosen by --crop_or_resize and written to --saveDir under the image's
    file name. When --alphaDir is given, per-image and average SAD / MSE
    against the ground truth are printed.

    Raises Exception when --cuda is requested but no GPU is available.
    """
    print("===> Loading args")
    args = get_args()

    print("===> Environment init")
    #os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    if args.cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    model = net.VGG16(args)
    # Remap the checkpoint tensors to the CPU for runs without --cuda.
    map_location = None if args.cuda else torch.device("cpu")
    ckpt = torch.load(args.resume, map_location=map_location)
    model.load_state_dict(ckpt['state_dict'], strict=not args.not_strict)

    if args.cuda:
        model = model.cuda()

    print("===> Load dataset")
    dataset = gen_dataset(args.imgDir, args.trimapDir)
    cnt = len(dataset)
    if cnt == 0:
        # The averages below would divide by zero.
        print("No images found in {}".format(args.imgDir))
        return

    mse_diffs = 0.
    sad_diffs = 0.
    cur = 0
    t0 = time.time()
    for img_path, trimap_path in dataset:
        img = cv2.imread(img_path)
        trimap = cv2.imread(trimap_path)[:, :, 0]

        assert(img.shape[:2] == trimap.shape[:2])

        # basename also handles the backslash-separated paths that
        # os.path.join produces on Windows.
        img_name = os.path.basename(img_path)

        cur += 1
        print('[{}/{}] {}'.format(cur, cnt, img_name))

        with torch.no_grad():
            torch.cuda.empty_cache()

            if args.crop_or_resize == "whole":
                origin_pred_mattes = inference_img_whole(args, model, img, trimap)
            elif args.crop_or_resize == "crop":
                origin_pred_mattes = inference_img_by_crop(args, model, img, trimap)
            else:
                origin_pred_mattes = inference_img_by_resize(args, model, img, trimap)

        # only attention unknown region
        origin_pred_mattes[trimap == 255] = 1.
        origin_pred_mattes[trimap == 0 ] = 0.

        # number of unknown pixels in the origin trimap
        pixel = float((trimap == 128).sum())

        # eval if gt alpha is given
        if args.alphaDir != '':
            alpha_name = os.path.join(args.alphaDir, img_name)
            assert(os.path.exists(alpha_name))
            alpha = cv2.imread(alpha_name)[:, :, 0] / 255.
            assert(alpha.shape == origin_pred_mattes.shape)

            # NOTE(review): `pixel` is 0 for a trimap without unknown
            # pixels, which makes this MSE inf/nan -- confirm the intended
            # value for that case.
            mse_diff = ((origin_pred_mattes - alpha) ** 2).sum() / pixel
            sad_diff = np.abs(origin_pred_mattes - alpha).sum()
            mse_diffs += mse_diff
            sad_diffs += sad_diff
            print("sad:{} mse:{}".format(sad_diff, mse_diff))

        origin_pred_mattes = (origin_pred_mattes * 255).astype(np.uint8)
        res = origin_pred_mattes.copy()

        # only attention unknown region
        res[trimap == 255] = 255
        res[trimap == 0 ] = 0

        if not os.path.exists(args.saveDir):
            os.makedirs(args.saveDir)
        cv2.imwrite(os.path.join(args.saveDir, img_name), res)

    print("Avg-Cost: {} s/image".format((time.time() - t0) / cnt))
    if args.alphaDir != '':
        print("Eval-MSE: {}".format(mse_diffs / cur))
        print("Eval-SAD: {}".format(sad_diffs / cur))


if __name__ == "__main__":
    main()
|
Binary file not shown.
@ -0,0 +1,126 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
import cv2
|
||||
import torch.nn.functional as F
|
||||
|
||||
class VGG16(nn.Module):
    """Encoder-decoder matting network with an optional refinement head.

    The encoder applies thirteen 3x3 convolutions in five groups to a
    4-channel input, each group followed by 2x2 max pooling whose indices
    are kept. The decoder unpools with those indices and applies 5x5
    convolutions, ending in a 1-channel map passed through a sigmoid.
    For stage 2 and 3 a small refinement network is added; for stage 2
    all layers created before it are frozen.
    """

    def __init__(self, args):
        """Build the layers; only `args.stage` is read."""
        super(VGG16, self).__init__()
        self.stage = args.stage

        self.conv1_1 = nn.Conv2d(4, 64, kernel_size=3,stride = 1, padding=1,bias=True)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3,stride = 1, padding=1,bias=True)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1,bias=True)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1,bias=True)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1,bias=True)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1,bias=True)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1,bias=True)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1,bias=True)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)

        # model released before 2019.09.09 should use kernel_size=1 & padding=0
        #self.conv6_1 = nn.Conv2d(512, 512, kernel_size=1, padding=0,bias=True)
        self.conv6_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)

        self.deconv6_1 = nn.Conv2d(512, 512, kernel_size=1,bias=True)
        self.deconv5_1 = nn.Conv2d(512, 512, kernel_size=5, padding=2,bias=True)
        self.deconv4_1 = nn.Conv2d(512, 256, kernel_size=5, padding=2,bias=True)
        self.deconv3_1 = nn.Conv2d(256, 128, kernel_size=5, padding=2,bias=True)
        self.deconv2_1 = nn.Conv2d(128, 64, kernel_size=5, padding=2,bias=True)
        self.deconv1_1 = nn.Conv2d(64, 64, kernel_size=5, padding=2,bias=True)

        self.deconv1 = nn.Conv2d(64, 1, kernel_size=5, padding=2,bias=True)

        if args.stage == 2:
            # for stage2 training: freeze everything defined so far. The
            # refinement layers are created afterwards and stay trainable.
            for p in self.parameters():
                p.requires_grad=False

        if self.stage == 2 or self.stage == 3:
            self.refine_conv1 = nn.Conv2d(4, 64, kernel_size=3, padding=1, bias=True)
            self.refine_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
            self.refine_conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
            self.refine_pred = nn.Conv2d(64, 1, kernel_size=3, padding=1, bias=True)

    def forward(self, x):
        """Return (pred_mattes, pred_alpha) for the 4-channel input `x`.

        For stage <= 1 the second element is the integer 0; otherwise it
        is the sigmoid of the raw matte plus the refinement output.
        """
        # Stage 1
        x11 = F.relu(self.conv1_1(x))
        x12 = F.relu(self.conv1_2(x11))
        x1p, id1 = F.max_pool2d(x12,kernel_size=(2,2), stride=(2,2),return_indices=True)

        # Stage 2
        x21 = F.relu(self.conv2_1(x1p))
        x22 = F.relu(self.conv2_2(x21))
        x2p, id2 = F.max_pool2d(x22,kernel_size=(2,2), stride=(2,2),return_indices=True)

        # Stage 3
        x31 = F.relu(self.conv3_1(x2p))
        x32 = F.relu(self.conv3_2(x31))
        x33 = F.relu(self.conv3_3(x32))
        x3p, id3 = F.max_pool2d(x33,kernel_size=(2,2), stride=(2,2),return_indices=True)

        # Stage 4
        x41 = F.relu(self.conv4_1(x3p))
        x42 = F.relu(self.conv4_2(x41))
        x43 = F.relu(self.conv4_3(x42))
        x4p, id4 = F.max_pool2d(x43,kernel_size=(2,2), stride=(2,2),return_indices=True)

        # Stage 5
        x51 = F.relu(self.conv5_1(x4p))
        x52 = F.relu(self.conv5_2(x51))
        x53 = F.relu(self.conv5_3(x52))
        x5p, id5 = F.max_pool2d(x53,kernel_size=(2,2), stride=(2,2),return_indices=True)

        # Stage 6
        x61 = F.relu(self.conv6_1(x5p))

        # Stage 6d
        x61d = F.relu(self.deconv6_1(x61))

        # Stage 5d
        x5d = F.max_unpool2d(x61d,id5, kernel_size=2, stride=2)
        x51d = F.relu(self.deconv5_1(x5d))

        # Stage 4d
        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
        x41d = F.relu(self.deconv4_1(x4d))

        # Stage 3d
        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
        x31d = F.relu(self.deconv3_1(x3d))

        # Stage 2d
        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
        x21d = F.relu(self.deconv2_1(x2d))

        # Stage 1d
        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
        x12d = F.relu(self.deconv1_1(x1d))

        # Should add sigmoid? github repo add so.
        raw_alpha = self.deconv1(x12d)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        pred_mattes = torch.sigmoid(raw_alpha)

        if self.stage <= 1:
            return pred_mattes, 0

        # Stage2 refine conv1: first three input channels + coarse matte
        refine0 = torch.cat((x[:, :3, :, :], pred_mattes), 1)
        refine1 = F.relu(self.refine_conv1(refine0))
        refine2 = F.relu(self.refine_conv2(refine1))
        refine3 = F.relu(self.refine_conv3(refine2))
        # Should add sigmoid?
        # sigmoid lead to refine result all converge to 0...
        pred_refine = self.refine_pred(refine3)

        pred_alpha = torch.sigmoid(raw_alpha + pred_refine)

        return pred_mattes, pred_alpha
|
Binary file not shown.
@ -0,0 +1,63 @@
|
||||
import torch
|
||||
import torchvision
|
||||
import collections
|
||||
import os
|
||||
|
||||
# Convert the torchvision VGG16 checkpoint into a state dict whose keys use
# the conv<stage>_<index> naming of the matting encoder, and save it to
# ./model/vgg_state_dict.pth.

# expanduser() instead of os.environ['HOME'] so the script does not die with
# a KeyError on systems where HOME is not defined.
HOME = os.path.expanduser('~')
# Legacy torch cache location of the ImageNet-pretrained VGG16 weights.
model_path = "{}/.torch/models/vgg16-397923af.pth".format(HOME)

if os.path.exists(model_path):
    x = torch.load(model_path)
else:
    # The checkpoint is not cached at the legacy path.  Instantiating the
    # pretrained model downloads the weights; read them straight from the
    # model instead of assuming where torch stored the downloaded file.
    model = torchvision.models.vgg16(pretrained=True)
    x = model.state_dict()

val = collections.OrderedDict()
# The first convolution gets one extra, zero-initialised input channel
# appended along dim 1 (presumably for the trimap -- confirm against the
# network definition).
val['conv1_1.weight'] = torch.cat((x['features.0.weight'], torch.zeros(64, 1, 3, 3)), 1)

# torchvision parameter name -> encoder parameter name.  'features.0.weight'
# is intentionally absent: it is handled above.
replace = {
    'features.0.bias': 'conv1_1.bias',
    'features.2.weight': 'conv1_2.weight',
    'features.2.bias': 'conv1_2.bias',
    'features.5.weight': 'conv2_1.weight',
    'features.5.bias': 'conv2_1.bias',
    'features.7.weight': 'conv2_2.weight',
    'features.7.bias': 'conv2_2.bias',
    'features.10.weight': 'conv3_1.weight',
    'features.10.bias': 'conv3_1.bias',
    'features.12.weight': 'conv3_2.weight',
    'features.12.bias': 'conv3_2.bias',
    'features.14.weight': 'conv3_3.weight',
    'features.14.bias': 'conv3_3.bias',
    'features.17.weight': 'conv4_1.weight',
    'features.17.bias': 'conv4_1.bias',
    'features.19.weight': 'conv4_2.weight',
    'features.19.bias': 'conv4_2.bias',
    'features.21.weight': 'conv4_3.weight',
    'features.21.bias': 'conv4_3.bias',
    'features.24.weight': 'conv5_1.weight',
    'features.24.bias': 'conv5_1.bias',
    'features.26.weight': 'conv5_2.weight',
    'features.26.bias': 'conv5_2.bias',
    'features.28.weight': 'conv5_3.weight',
    'features.28.bias': 'conv5_3.bias',
}

# NOTE: conv6_1 is not initialised here; an earlier attempt to derive it from
# the 'classifier.0' weights was disabled by the original author.

for key, new_key in replace.items():
    print(key, new_key)
    val[new_key] = x[key]

# Wrap the weights in a checkpoint dict with 'state_dict' and 'epoch' entries.
y = {}
y['state_dict'] = val
y['epoch'] = 0
if not os.path.exists('./model'):
    os.makedirs('./model')
torch.save(y, './model/vgg_state_dict.pth')
|
@ -0,0 +1,137 @@
|
||||
# composite image with dataset from "deep image matting"
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import math
|
||||
import time
|
||||
import shutil
|
||||
|
||||
# Dataset locations.  Each path can be overridden through an environment
# variable so the script is usable outside the original author's machine;
# when the variable is unset the original hard-coded path is used.
root_dir = os.environ.get(
    "DIM_ROOT_DIR", "/home/liuliang/Downloads/Combined_Dataset")
# Directory holding the background images used for the test composites.
test_bg_dir = os.environ.get(
    "DIM_TEST_BG_DIR",
    '/home/liuliang/Desktop/dataset/matting/VOCdevkit/VOC2012/JPEGImages')
# Directory holding the background images used for the training composites.
train_bg_dir = os.environ.get(
    "DIM_TRAIN_BG_DIR",
    '/home/liuliang/Desktop/dataset/matting/mscoco/train2017')
|
||||
|
||||
def my_composite(fg_names, bg_names, fg_dir, alpha_dir, bg_dir, num_bg, comp_dir):
    """Composite every foreground onto ``num_bg`` backgrounds and save the results.

    Args:
        fg_names: Path of a text file listing foreground file names, one per
            line.  The alpha matte of a foreground has the same file name
            inside ``alpha_dir``.
        bg_names: Path of a text file listing background file names, one per
            line.  It must contain exactly ``num_bg`` entries per foreground;
            foreground ``i`` uses entries ``i * num_bg`` to
            ``(i + 1) * num_bg - 1``.
        fg_dir: Directory containing the foreground images.
        alpha_dir: Directory containing the alpha mattes.
        bg_dir: Directory containing the background images.
        num_bg: Number of backgrounds composited with each foreground.
        comp_dir: Output root.  Results are written to its ``image``, ``fg``,
            ``bg`` and ``alpha`` sub-directories, which this function does
            not create.

    Raises:
        AssertionError: If the list lengths are inconsistent, an input file
            is missing, or image shapes do not match.
    """
    # Context managers so the list files are closed even if reading fails.
    with open(fg_names) as fg_file:
        fg_ids = fg_file.readlines()
    with open(bg_names) as bg_file:
        bg_ids = bg_file.readlines()

    fg_cnt = len(fg_ids)
    bg_cnt = len(bg_ids)
    print(fg_cnt, bg_cnt)
    assert fg_cnt * num_bg == bg_cnt, \
        "expected {} backgrounds per foreground".format(num_bg)

    for i, fg_line in enumerate(fg_ids):
        im_name = fg_line.strip("\r\n")
        fg_path = os.path.join(fg_dir, im_name)
        alpha_path = os.path.join(alpha_dir, im_name)

        assert os.path.exists(fg_path), fg_path
        assert os.path.exists(alpha_path), alpha_path

        fg = cv2.imread(fg_path)
        alpha = cv2.imread(alpha_path)
        assert alpha.shape == fg.shape, \
            "alpha/foreground shape mismatch for {}".format(im_name)

        h, w, c = fg.shape
        # The blending weights depend only on the foreground, so compute
        # them once instead of once per background.
        alpha_f = alpha / 255.
        inv_alpha_f = 1. - alpha_f
        # splitext() instead of a fixed 4-character slice so extensions of
        # any length are removed correctly.
        stem = os.path.splitext(im_name)[0]

        base = i * num_bg
        for bcount in range(num_bg):
            bg_path = os.path.join(bg_dir, bg_ids[base + bcount].strip("\r\n"))
            print(base + bcount, fg_path, bg_path)
            assert os.path.exists(bg_path), bg_path
            bg = cv2.imread(bg_path)
            bh, bw, bc = bg.shape

            # Upscale the background when it is smaller than the foreground
            # in either dimension; the +1.0 keeps the integer size from
            # falling short of the foreground after truncation.
            wratio = float(w) / bw
            hratio = float(h) / bh
            ratio = max(wratio, hratio)
            if ratio > 1:
                new_bw = int(bw * ratio + 1.0)
                new_bh = int(bh * ratio + 1.0)
                bg = cv2.resize(bg, (new_bw, new_bh), interpolation=cv2.INTER_LINEAR)
            # Crop the top-left region so the background matches the foreground.
            bg = bg[0:h, 0:w, :]
            assert bg.shape == fg.shape, \
                "background/foreground shape mismatch for {}".format(bg_path)

            comp = fg * alpha_f + bg * inv_alpha_f

            img_save_id = stem + '_' + str(bcount) + '.png'
            comp_save_path = os.path.join(comp_dir, "image", img_save_id)
            fg_save_path = os.path.join(comp_dir, "fg", img_save_id)
            bg_save_path = os.path.join(comp_dir, "bg", img_save_id)
            alpha_save_path = os.path.join(comp_dir, "alpha", img_save_id)

            cv2.imwrite(comp_save_path, comp)
            cv2.imwrite(fg_save_path, fg)
            cv2.imwrite(bg_save_path, bg)
            cv2.imwrite(alpha_save_path, alpha)
|
||||
|
||||
|
||||
def copy_dir2dir(src_dir, des_dir):
    """Copy every regular file directly inside ``src_dir`` into ``des_dir``.

    The copy is not recursive.  Entries of ``src_dir`` that are not regular
    files (for example sub-directories) are skipped, because
    ``shutil.copyfile`` cannot copy them and would raise.

    Args:
        src_dir: Directory to copy from.
        des_dir: Existing directory to copy into; files with the same name
            are overwritten.
    """
    for img_id in os.listdir(src_dir):
        src_path = os.path.join(src_dir, img_id)
        if not os.path.isfile(src_path):
            continue
        shutil.copyfile(src_path, os.path.join(des_dir, img_id))
|
||||
|
||||
|
||||
def _ensure_dir(path):
    """Create ``path`` (and missing parents) unless it already is a directory."""
    if not os.path.isdir(path):
        os.makedirs(path)


def main():
    """Prepare the dataset folders and build the test and training composites.

    Steps, in order: rewrite the training background name list, create the
    output folders, copy the test trimaps, merge the "Adobe-licensed images"
    and "Other" training foregrounds/alphas into ``Training_set/all``, then
    composite the test set (20 backgrounds per foreground) and the training
    set (100 backgrounds per foreground).
    """
    test_num_bg = 20
    test_fg_names = os.path.join(root_dir, "Test_set/test_fg_names.txt")
    test_bg_names = os.path.join(root_dir, "Test_set/test_bg_names.txt")
    test_fg_dir = os.path.join(root_dir, "Test_set/Adobe-licensed images/fg")
    test_alpha_dir = os.path.join(root_dir, "Test_set/Adobe-licensed images/alpha")
    test_trimap_dir = os.path.join(root_dir, "Test_set/Adobe-licensed images/trimaps")
    test_comp_dir = os.path.join(root_dir, "Test_set/comp")

    train_num_bg = 100
    train_fg_names = os.path.join(root_dir, "Training_set/training_fg_names.txt")
    train_bg_names_coco2014 = os.path.join(root_dir, "Training_set/training_bg_names.txt")
    train_bg_names_coco2017 = os.path.join(root_dir, "Training_set/training_bg_names_coco2017.txt")
    train_fg_dir = os.path.join(root_dir, "Training_set/all/fg")
    train_alpha_dir = os.path.join(root_dir, "Training_set/all/alpha")
    train_comp_dir = os.path.join(root_dir, "Training_set/comp")

    # Change the background name format for COCO 2017: drop the first 15
    # characters of every line (presumably the "COCO_train2014_" prefix,
    # which is 15 characters long -- confirm against the name list).
    with open(train_bg_names_coco2014, 'r') as fin:
        lls = fin.readlines()
    with open(train_bg_names_coco2017, 'w') as fout:
        for l in lls:
            fout.write(l[15:])

    # Create each output folder individually.  Checking only the parent
    # would leave sub-folders missing when an earlier run was interrupted
    # after creating just some of them.
    for sub in ('image', 'fg', 'bg', 'alpha', 'trimap'):
        _ensure_dir(os.path.join(test_comp_dir, sub))
    for sub in ('image', 'fg', 'bg', 'alpha'):
        _ensure_dir(os.path.join(train_comp_dir, sub))
    _ensure_dir(train_alpha_dir)
    _ensure_dir(train_fg_dir)

    # copy test trimaps
    copy_dir2dir(test_trimap_dir, test_comp_dir + '/trimap')
    # copy train images together
    copy_dir2dir(os.path.join(root_dir, "Training_set/Adobe-licensed images/alpha"), train_alpha_dir)
    copy_dir2dir(os.path.join(root_dir, "Training_set/Adobe-licensed images/fg"), train_fg_dir)
    copy_dir2dir(os.path.join(root_dir, "Training_set/Other/alpha"), train_alpha_dir)
    copy_dir2dir(os.path.join(root_dir, "Training_set/Other/fg"), train_fg_dir)

    # composite test image
    my_composite(test_fg_names, test_bg_names, test_fg_dir, test_alpha_dir,
                 test_bg_dir, test_num_bg, test_comp_dir)
    # composite train image
    my_composite(train_fg_names, train_bg_names_coco2017, train_fg_dir,
                 train_alpha_dir, train_bg_dir, train_num_bg, train_comp_dir)


if __name__ == "__main__":
    main()
|
||||
|
@ -0,0 +1,50 @@
|
||||
import re
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pylab import *
|
||||
|
||||
# Plot a smoothed loss curve from a training log.
# args: log_name, match_rule, self_log_interval, smooth_log_iteration
loss_file_name = "simple_loss"
title = "{}_Loss".format(loss_file_name)
# Matches entries such as "Loss: 0.1234" (decimal notation only).
pattern = re.compile(r'Loss:[ ]*\d+\.\d+')
# Spacing, in iterations, between consecutive loss lines in the log; only
# used to scale the x axis.
self_inter = 10
# Number of consecutive loss values averaged into one plotted point.
smooth = 20

# read log file; the context manager closes the handle once it is read
with open("../log/{}.log".format(loss_file_name)) as f:
    lines = f.readlines()
print("Line: {}".format(len(lines)))
ys = []
k = 0        # loss values accumulated in the current window
cnt = 0      # completed windows
sum_y = 0.

# read one by one
for line in lines:
    obj = pattern.search(line)
    if obj:
        val = float(obj.group().split(':')[-1])
        sum_y += val
        k += 1
        if k >= smooth:
            ys.append(sum_y / k)
            sum_y = 0.
            k = 0
            cnt += 1
            if cnt % 10 == 0:
                print("ys cnt: {}".format(cnt))
# Flush the last, partially filled window.
if k > 0:
    ys.append(sum_y / k)

ys = np.array(ys)
# One point per window: `smooth` log lines, each `self_inter` iterations apart.
xs = np.arange(len(ys)) * self_inter * smooth

print(xs)
print(ys)

plt.plot(xs, ys)
plt.title(title)
plt.xlabel("Iter")
plt.ylabel("Loss")
plt.savefig("../log/{}.png".format(title))
plt.show()
|
Loading…
Reference in New Issue