Mirror of https://github.com/kritiksoman/GIMP-ML
Synced 2024-11-06 03:20:34 +00:00
130 lines · 4.0 KiB · Python
|
#!/usr/bin/python
|
||
|
# -*- encoding: utf-8 -*-
|
||
|
|
||
|
|
||
|
from PIL import Image
|
||
|
import PIL.ImageEnhance as ImageEnhance
|
||
|
import random
|
||
|
import numpy as np
|
||
|
|
||
|
class RandomCrop(object):
    """Randomly crop an image/label pair to a fixed ``(W, H)`` size.

    If the input is smaller than the target in either dimension, both images
    are first upscaled (image bilinearly, label nearest-neighbour so label ids
    are never interpolated) just enough for both dimensions to cover the target.
    """

    def __init__(self, size, *args, **kwargs):
        # Target crop size as a (W, H) tuple, the same order as PIL's
        # ``Image.size``. Extra positional/keyword arguments are ignored.
        self.size = size

    @staticmethod
    def _cover_size(w, h, W, H):
        """Return ``(w, h)`` scaled, aspect ratio kept, to at least ``W x H``.

        The larger of the two per-axis ratios is used so that *both*
        dimensions reach the target (picking the ratio by ``w < h`` alone can
        shrink the image for non-square targets). The ``+ 1`` guards against
        float rounding leaving a dimension one pixel short.
        """
        scale = max(float(W) / w, float(H) / h)
        return int(scale * w + 1), int(scale * h + 1)

    def __call__(self, im_lb):
        """Crop the pair stored under keys ``'im'`` and ``'lb'`` of ``im_lb``.

        Returns a new dict with the same two keys holding the cropped images.
        """
        im = im_lb['im']
        lb = im_lb['lb']
        assert im.size == lb.size
        W, H = self.size
        w, h = im.size

        if (W, H) == (w, h):
            return dict(im=im, lb=lb)

        if w < W or h < H:
            w, h = self._cover_size(w, h, W, H)
            im = im.resize((w, h), Image.BILINEAR)
            lb = lb.resize((w, h), Image.NEAREST)

        # randint is inclusive, so every valid offset (including the
        # right/bottom-most one) can be chosen.
        sw = random.randint(0, w - W)
        sh = random.randint(0, h - H)
        crop = sw, sh, sw + W, sh + H
        return dict(
            im=im.crop(crop),
            lb=lb.crop(crop),
        )
|
||
|
|
||
|
|
||
|
class HorizontalFlip(object):
    """Mirror an image/label pair left-to-right with probability ``p``.

    Mirroring turns a left-side part into a right-side one, so the ids of
    paired left/right classes in the label map are exchanged as well.
    """

    # (left, right) label-id pairs exchanged on flip, following:
    # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r',
    #         10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
    LR_PAIRS = ((2, 3), (4, 5), (7, 8))

    def __init__(self, p=0.5, *args, **kwargs):
        # Probability of flipping; extra arguments are ignored.
        self.p = p

    @classmethod
    def _swap_lr_labels(cls, lb_arr):
        """Return a copy of ``lb_arr`` with every ``LR_PAIRS`` id pair swapped.

        Masks are computed on the untouched input, so a pixel swapped from
        ``left`` to ``right`` is not swapped back by the second assignment.
        """
        lb_arr = np.asarray(lb_arr)
        swapped = lb_arr.copy()
        for left, right in cls.LR_PAIRS:
            swapped[lb_arr == left] = right
            swapped[lb_arr == right] = left
        return swapped

    def __call__(self, im_lb):
        """Return ``im_lb`` unchanged, or a new flipped ``{'im', 'lb'}`` dict."""
        if random.random() > self.p:
            return im_lb

        im = im_lb['im']
        lb = im_lb['lb']
        # Compare on a numpy array: ``lb == 2`` on a PIL Image evaluates to
        # plain False, which silently skipped every swap.
        flip_lb = Image.fromarray(self._swap_lr_labels(np.array(lb)))
        return dict(
            im=im.transpose(Image.FLIP_LEFT_RIGHT),
            lb=flip_lb.transpose(Image.FLIP_LEFT_RIGHT),
        )
|
||
|
|
||
|
|
||
|
class RandomScale(object):
    """Resize an image/label pair by a factor picked at random from ``scales``."""

    def __init__(self, scales=(1, ), *args, **kwargs):
        # Candidate scale factors; one is chosen with random.choice per call.
        # Extra arguments are ignored.
        self.scales = scales

    def __call__(self, im_lb):
        """Return a new ``{'im', 'lb'}`` dict with both images rescaled.

        The image uses bilinear resampling, the label nearest-neighbour.
        """
        img, label = im_lb['im'], im_lb['lb']
        width, height = img.size
        factor = random.choice(self.scales)
        new_size = (int(width * factor), int(height * factor))
        return dict(
            im=img.resize(new_size, Image.BILINEAR),
            lb=label.resize(new_size, Image.NEAREST),
        )
|
||
|
|
||
|
|
||
|
class ColorJitter(object):
    """Randomly perturb brightness, contrast and saturation of the image.

    Each enabled component draws a factor uniformly from
    ``[max(1 - v, 0), 1 + v]`` and applies it with ``PIL.ImageEnhance``
    (``Brightness``, ``Contrast``, ``Color`` respectively). The label is
    passed through unchanged.
    """

    def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs):
        # A component given as None or <= 0 is disabled: its attribute is set
        # to None and __call__ skips it, so the defaults are safe to use.
        self.brightness = self._factor_range(brightness)
        self.contrast = self._factor_range(contrast)
        self.saturation = self._factor_range(saturation)

    @staticmethod
    def _factor_range(value):
        """Return ``[low, high]`` enhancement-factor bounds, or None if disabled."""
        if value is None or value <= 0:
            return None
        return [max(1 - value, 0), 1 + value]

    def __call__(self, im_lb):
        """Return a new ``{'im', 'lb'}`` dict with the jittered image."""
        im = im_lb['im']
        lb = im_lb['lb']
        # Applied (and random factors drawn) in the order brightness,
        # contrast, saturation.
        if self.brightness is not None:
            r_brightness = random.uniform(self.brightness[0], self.brightness[1])
            im = ImageEnhance.Brightness(im).enhance(r_brightness)
        if self.contrast is not None:
            r_contrast = random.uniform(self.contrast[0], self.contrast[1])
            im = ImageEnhance.Contrast(im).enhance(r_contrast)
        if self.saturation is not None:
            r_saturation = random.uniform(self.saturation[0], self.saturation[1])
            im = ImageEnhance.Color(im).enhance(r_saturation)
        return dict(
            im=im,
            lb=lb,
        )
|
||
|
|
||
|
|
||
|
class MultiScale(object):
    """Produce bilinearly resized copies of an image, one per scale ratio."""

    def __init__(self, scales):
        # Ratios applied to both width and height, in output order.
        self.scales = scales

    def __call__(self, img):
        """Return a list of ``img`` resized by each ratio in ``self.scales``.

        Sizes are truncated with ``int``, matching the other transforms here.
        """
        W, H = img.size
        # Build the list directly rather than via a side-effect-only
        # comprehension that appended to a separate list.
        return [
            img.resize((int(W * ratio), int(H * ratio)), Image.BILINEAR)
            for ratio in self.scales
        ]
|
||
|
|
||
|
|
||
|
class Compose(object):
    """Chain transforms so each one receives the previous one's output."""

    def __init__(self, do_list):
        # Ordered sequence of callables, applied first to last.
        self.do_list = do_list

    def __call__(self, im_lb):
        """Run ``im_lb`` through every transform and return the final result."""
        result = im_lb
        for transform in self.do_list:
            result = transform(result)
        return result
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    # Manual smoke setup: build the individual transforms, then load a sample
    # image/label pair from the local ``data/`` directory.
    rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0))
    crop = RandomCrop((321, 321))
    flip = HorizontalFlip(p=1)
    img = Image.open('data/img.jpg')
    lb = Image.open('data/label.png')
|