GIMP-ML/gimp-plugins/pytorch-deep-image-matting/demo.py

37 lines
959 B
Python
Raw Normal View History

2020-08-02 11:54:31 +00:00
import torch
from argparse import Namespace
import net
import cv2
import os
import numpy as np
from deploy import inference_img_whole
# input file list
image_path = "boy-1518482_1920_12_img.png"
trimap_path = "boy-1518482_1920_12.png"
image = cv2.imread(image_path)
trimap = cv2.imread(trimap_path)
# print(trimap.shape)
trimap = trimap[:, :, 0]
# init model
args = Namespace(crop_or_resize='whole', cuda=True, max_size=1600, resume='model/stage1_sad_57.1.pth', stage=1)
model = net.VGG16(args)
ckpt = torch.load(args.resume)
model.load_state_dict(ckpt['state_dict'], strict=True)
model = model.cuda()
torch.cuda.empty_cache()
with torch.no_grad():
pred_mattes = inference_img_whole(args, model, image, trimap)
pred_mattes = (pred_mattes * 255).astype(np.uint8)
pred_mattes[trimap == 255] = 255
pred_mattes[trimap == 0] = 0
# print(pred_mattes)
# cv2.imwrite('out.png', pred_mattes)
# import matplotlib.pyplot as plt
# plt.imshow(image)
# plt.show()