GIMP-ML/gimp-plugins/DeblurGANv2/aug.py
2020-04-27 10:02:33 +05:30

94 lines
3.3 KiB
Python

from typing import List
import albumentations as albu
def get_transforms(size, scope='geometric', crop='random'):
    """Build a paired augmentation function for an (input, target) image pair.

    Args:
        size: Side length of the square crop the pipeline produces.
        scope: Name of the augmentation set to use: 'strong', 'weak' or
            'geometric'.
        crop: Cropping strategy: 'random' or 'center'.

    Returns:
        A callable ``process(a, b)`` that feeds ``a`` as ``image`` and ``b``
        as ``target`` through one pipeline call and returns the transformed
        ``(image, target)`` tuple.

    Raises:
        KeyError: If ``scope`` or ``crop`` is not one of the supported names.
    """
    augs = {
        'strong': albu.Compose([
            albu.HorizontalFlip(),
            albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4),
            albu.ElasticTransform(),
            albu.OpticalDistortion(),
            albu.OneOf([
                albu.CLAHE(clip_limit=2),
                albu.IAASharpen(),
                albu.IAAEmboss(),
                albu.RandomBrightnessContrast(),
                albu.RandomGamma()
            ], p=0.5),
            albu.OneOf([
                albu.RGBShift(),
                albu.HueSaturationValue(),
            ], p=0.5),
        ]),
        'weak': albu.Compose([
            albu.HorizontalFlip(),
        ]),
        'geometric': albu.OneOf([
            albu.HorizontalFlip(always_apply=True),
            albu.ShiftScaleRotate(always_apply=True),
            albu.Transpose(always_apply=True),
            albu.OpticalDistortion(always_apply=True),
            albu.ElasticTransform(always_apply=True),
        ]),
    }
    aug_fn = augs[scope]
    crop_fn = {
        'random': albu.RandomCrop(size, size, always_apply=True),
        'center': albu.CenterCrop(size, size, always_apply=True),
    }[crop]
    pad = albu.PadIfNeeded(size, size)

    # Pad BEFORE cropping: with the pad placed after the crop it could never
    # take effect, since cropping an image smaller than `size` fails first and
    # a successful crop is already `size` x `size`.
    pipeline = albu.Compose([aug_fn, pad, crop_fn], additional_targets={'target': 'image'})

    def process(a, b):
        r = pipeline(image=a, target=b)
        return r['image'], r['target']

    return process
def get_normalize():
    """Create a function that normalizes an (input, target) image pair.

    Both images go through ``albu.Normalize`` configured with a per-channel
    mean of 0.5 and std of 0.5; the second image is registered as the
    additional ``target`` image of the same ``Compose``.

    Returns:
        A callable ``process(a, b)`` returning the normalized
        ``(image, target)`` tuple.
    """
    paired_norm = albu.Compose(
        [albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])],
        additional_targets={'target': 'image'},
    )

    def process(a, b):
        out = paired_norm(image=a, target=b)
        return out['image'], out['target']

    return process
def _resolve_aug_fn(name):
    """Return the albumentations transform class registered under ``name``.

    Args:
        name: Short config name of the corruption (e.g. 'jpeg', 'fog').

    Returns:
        The transform class (not an instance) mapped to ``name``.

    Raises:
        KeyError: If ``name`` is not present in the lookup table.
    """
    registry = dict(
        cutout=albu.Cutout,
        rgb_shift=albu.RGBShift,
        hsv_shift=albu.HueSaturationValue,
        motion_blur=albu.MotionBlur,
        median_blur=albu.MedianBlur,
        snow=albu.RandomSnow,
        shadow=albu.RandomShadow,
        fog=albu.RandomFog,
        brightness_contrast=albu.RandomBrightnessContrast,
        gamma=albu.RandomGamma,
        sun_flare=albu.RandomSunFlare,
        sharpen=albu.IAASharpen,
        jpeg=albu.JpegCompression,
        gray=albu.ToGray,
    )
    # ToDo: pixelize
    # ToDo: partial gray
    return registry[name]
def get_corrupt_function(config):
    """Build a function that applies one of the configured corruptions.

    Args:
        config: Iterable of mappings, one per corruption. Each must contain a
            'name' key understood by ``_resolve_aug_fn``; an optional 'prob'
            key sets the transform's ``p`` (default 0.5). All remaining keys
            are forwarded to the transform constructor as keyword arguments.

    Returns:
        A callable ``process(x)`` that passes ``x`` as ``image`` through an
        ``albu.OneOf`` over the configured transforms and returns the
        resulting image.

    Raises:
        KeyError: If an entry has no 'name' key or the name is unknown.
    """
    augs = []
    for aug_params in config:
        # Work on a shallow copy: popping from the caller's dict would strip
        # 'name'/'prob' from the config, so reusing the same config (e.g. for
        # both train and val datasets) would fail with KeyError.
        params = dict(aug_params)
        name = params.pop('name')
        cls = _resolve_aug_fn(name)
        prob = params.pop('prob', .5)
        augs.append(cls(p=prob, **params))
    augs = albu.OneOf(augs)

    def process(x):
        return augs(image=x)['image']

    return process