mirror of
https://github.com/kritiksoman/GIMP-ML
synced 2024-11-02 03:40:29 +00:00
107 lines
2.9 KiB
Python
Executable File
107 lines
2.9 KiB
Python
Executable File
#!/usr/bin/python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
import torchvision.transforms as transforms
|
|
|
|
import os.path as osp
|
|
import os
|
|
from PIL import Image
|
|
import numpy as np
|
|
import json
|
|
import cv2
|
|
|
|
from transform import *
|
|
|
|
|
|
|
|
class FaceMask(Dataset):
|
|
def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs):
|
|
super(FaceMask, self).__init__(*args, **kwargs)
|
|
assert mode in ('train', 'val', 'test')
|
|
self.mode = mode
|
|
self.ignore_lb = 255
|
|
self.rootpth = rootpth
|
|
|
|
self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img'))
|
|
|
|
# pre-processing
|
|
self.to_tensor = transforms.Compose([
|
|
transforms.ToTensor(),
|
|
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
|
])
|
|
self.trans_train = Compose([
|
|
ColorJitter(
|
|
brightness=0.5,
|
|
contrast=0.5,
|
|
saturation=0.5),
|
|
HorizontalFlip(),
|
|
RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
|
|
RandomCrop(cropsize)
|
|
])
|
|
|
|
def __getitem__(self, idx):
|
|
impth = self.imgs[idx]
|
|
img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth))
|
|
img = img.resize((512, 512), Image.BILINEAR)
|
|
label = Image.open(osp.join(self.rootpth, 'mask', impth[:-3]+'png')).convert('P')
|
|
# print(np.unique(np.array(label)))
|
|
if self.mode == 'train':
|
|
im_lb = dict(im=img, lb=label)
|
|
im_lb = self.trans_train(im_lb)
|
|
img, label = im_lb['im'], im_lb['lb']
|
|
img = self.to_tensor(img)
|
|
label = np.array(label).astype(np.int64)[np.newaxis, :]
|
|
return img, label
|
|
|
|
def __len__(self):
|
|
return len(self.imgs)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
|
|
face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
|
|
mask_path = '/home/zll/data/CelebAMask-HQ/mask'
|
|
counter = 0
|
|
total = 0
|
|
for i in range(15):
|
|
# files = os.listdir(osp.join(face_sep_mask, str(i)))
|
|
|
|
atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
|
|
'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
|
|
|
|
for j in range(i*2000, (i+1)*2000):
|
|
|
|
mask = np.zeros((512, 512))
|
|
|
|
for l, att in enumerate(atts, 1):
|
|
total += 1
|
|
file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
|
|
path = osp.join(face_sep_mask, str(i), file_name)
|
|
|
|
if os.path.exists(path):
|
|
counter += 1
|
|
sep_mask = np.array(Image.open(path).convert('P'))
|
|
# print(np.unique(sep_mask))
|
|
|
|
mask[sep_mask == 225] = l
|
|
cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
|
|
print(j)
|
|
|
|
print(counter, total)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|