# Mirror of https://github.com/kritiksoman/GIMP-ML (synced 2024-11-06 03:20:34 +00:00).
# 37 lines, 959 B, Python, executable file.
import torch
from argparse import Namespace
import net
import cv2
import os
import numpy as np
from deploy import inference_img_whole

# Stand-alone check of the matting model: read one image and its trimap,
# load the stage-1 checkpoint into net.VGG16, run inference_img_whole on the
# pair, and convert the predicted matte to an 8-bit array.  Requires a CUDA
# device (the model is moved to the GPU unconditionally).


def _read_image(path):
    """Read ``path`` with ``cv2.imread`` and raise if it cannot be read.

    ``cv2.imread`` signals a missing or undecodable file by returning
    ``None`` rather than raising, which would otherwise surface much later
    as an unrelated ``TypeError`` when the array is indexed.

    Raises:
        IOError: if OpenCV could not read an image from ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        raise IOError("Could not read image file: {!r}".format(path))
    return img


# input file list
image_path = "boy-1518482_1920_12_img.png"
trimap_path = "boy-1518482_1920_12.png"
image = _read_image(image_path)
trimap = _read_image(trimap_path)
# print(trimap.shape)
# imread's default flag yields a 3-channel array; keep only the first channel
# so the trimap is a 2-D array.
trimap = trimap[:, :, 0]

# init model
args = Namespace(
    crop_or_resize='whole',
    cuda=True,
    max_size=1600,
    resume='model/stage1_sad_57.1.pth',
    stage=1,
)
model = net.VGG16(args)
# NOTE(review): torch.load unpickles the file; only point args.resume at a
# checkpoint from a trusted source.
ckpt = torch.load(args.resume)
model.load_state_dict(ckpt['state_dict'], strict=True)
model = model.cuda()

torch.cuda.empty_cache()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    pred_mattes = inference_img_whole(args, model, image, trimap)

# Scale to 8-bit.
# assumes inference_img_whole returns a float array in [0, 1] — TODO confirm
pred_mattes = (pred_mattes * 255).astype(np.uint8)
# Wherever the trimap is exactly 255 or exactly 0, override the prediction
# with that same value (presumably the known foreground/background regions).
pred_mattes[trimap == 255] = 255
pred_mattes[trimap == 0] = 0
# print(pred_mattes)
# cv2.imwrite('out.png', pred_mattes)

# import matplotlib.pyplot as plt
# plt.imshow(image)
# plt.show()