diff --git a/gimp-plugins/Inpainting/DFNet_core.py b/gimp-plugins/Inpainting/DFNet_core.py new file mode 100644 index 0000000..1d1b14d --- /dev/null +++ b/gimp-plugins/Inpainting/DFNet_core.py @@ -0,0 +1,266 @@ +import torch +from torch import nn +import torch.nn.functional as F +from builtins import * + +def resize_like(x, target, mode='bilinear'): + return F.interpolate(x, target.shape[-2:], mode=mode, align_corners=False) + +def get_norm(name, out_channels): + if name == 'batch': + norm = nn.BatchNorm2d(out_channels) + elif name == 'instance': + norm = nn.InstanceNorm2d(out_channels) + else: + norm = None + return norm + + +def get_activation(name): + if name == 'relu': + activation = nn.ReLU() + elif name == 'elu': + activation == nn.ELU() + elif name == 'leaky_relu': + activation = nn.LeakyReLU(negative_slope=0.2) + elif name == 'tanh': + activation = nn.Tanh() + elif name == 'sigmoid': + activation = nn.Sigmoid() + else: + activation = None + return activation + + +class Conv2dSame(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride): + super().__init__() + + padding = self.conv_same_pad(kernel_size, stride) + if type(padding) is not tuple: + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size, stride, padding) + else: + self.conv = nn.Sequential( + nn.ConstantPad2d(padding*2, 0), + nn.Conv2d(in_channels, out_channels, kernel_size, stride, 0) + ) + + def conv_same_pad(self, ksize, stride): + if (ksize - stride) % 2 == 0: + return (ksize - stride) // 2 + else: + left = (ksize - stride) // 2 + right = left + 1 + return left, right + + def forward(self, x): + return self.conv(x) + + +class ConvTranspose2dSame(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, stride): + super().__init__() + + padding, output_padding = self.deconv_same_pad(kernel_size, stride) + self.trans_conv = nn.ConvTranspose2d( + in_channels, out_channels, kernel_size, stride, + padding, output_padding) + + def deconv_same_pad(self, ksize, stride): + pad = (ksize - stride + 1) // 2 + outpad = 2 * pad + stride - ksize + return pad, outpad + + def forward(self, x): + return self.trans_conv(x) + + +class UpBlock(nn.Module): + + def __init__(self, mode='nearest', scale=2, channel=None, kernel_size=4): + super().__init__() + + self.mode = mode + if mode == 'deconv': + self.up = ConvTranspose2dSame( + channel, channel, kernel_size, stride=scale) + else: + def upsample(x): + return F.interpolate(x, scale_factor=scale, mode=mode) + self.up = upsample + + def forward(self, x): + return self.up(x) + + +class EncodeBlock(nn.Module): + + def __init__( + self, in_channels, out_channels, kernel_size, stride, + normalization=None, activation=None): + super().__init__() + + self.c_in = in_channels + self.c_out = out_channels + + layers = [] + layers.append( + Conv2dSame(self.c_in, self.c_out, kernel_size, stride)) + if normalization: + layers.append(get_norm(normalization, self.c_out)) + if activation: + layers.append(get_activation(activation)) + self.encode = nn.Sequential(*layers) + + def forward(self, x): + return self.encode(x) + + +class DecodeBlock(nn.Module): + + def __init__( + self, c_from_up, c_from_down, c_out, mode='nearest', + kernel_size=4, scale=2, normalization='batch', activation='relu'): + super().__init__() + + self.c_from_up = c_from_up + self.c_from_down = c_from_down + self.c_in = c_from_up + c_from_down + self.c_out = c_out + + self.up = UpBlock(mode, scale, c_from_up, kernel_size=scale) + + layers = [] + layers.append( + Conv2dSame(self.c_in, self.c_out, kernel_size, stride=1)) + if normalization: + layers.append(get_norm(normalization, self.c_out)) + if activation: + layers.append(get_activation(activation)) + self.decode = nn.Sequential(*layers) + + def forward(self, x, concat=None): + out = self.up(x) + if self.c_from_down > 0: + out = torch.cat([out, concat], dim=1) + out = self.decode(out) + return out + + +class BlendBlock(nn.Module): + + def __init__( + self, c_in, c_out, ksize_mid=3, norm='batch', act='leaky_relu'): + super().__init__() + c_mid = max(c_in // 2, 32) + self.blend = nn.Sequential( + Conv2dSame(c_in, c_mid, 1, 1), + get_norm(norm, c_mid), + get_activation(act), + Conv2dSame(c_mid, c_out, ksize_mid, 1), + get_norm(norm, c_out), + get_activation(act), + Conv2dSame(c_out, c_out, 1, 1), + nn.Sigmoid() + ) + + def forward(self, x): + return self.blend(x) + + +class FusionBlock(nn.Module): + def __init__(self, c_feat, c_alpha=1): + super().__init__() + c_img = 3 + self.map2img = nn.Sequential( + Conv2dSame(c_feat, c_img, 1, 1), + nn.Sigmoid()) + self.blend = BlendBlock(c_img*2, c_alpha) + + def forward(self, img_miss, feat_de): + img_miss = resize_like(img_miss, feat_de) + raw = self.map2img(feat_de) + alpha = self.blend(torch.cat([img_miss, raw], dim=1)) + result = alpha * raw + (1 - alpha) * img_miss + return result, alpha, raw + + +class DFNet(nn.Module): + def __init__( + self, c_img=3, c_mask=1, c_alpha=3, + mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu', + en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3]*8, + blend_layers=[0, 1, 2, 3, 4, 5]): + super().__init__() + + c_init = c_img + c_mask + + self.n_en = len(en_ksize) + self.n_de = len(de_ksize) + assert self.n_en == self.n_de, ( + 'The number layer of Encoder and Decoder must be equal.') + assert self.n_en >= 1, ( + 'The number layer of Encoder and Decoder must be greater than 1.') + + assert 0 in blend_layers, 'Layer 0 must be blended.' + + self.en = [] + c_in = c_init + self.en.append( + EncodeBlock(c_in, 64, en_ksize[0], 2, None, None)) + for k_en in en_ksize[1:]: + c_in = self.en[-1].c_out + c_out = min(c_in*2, 512) + self.en.append(EncodeBlock( + c_in, c_out, k_en, stride=2, + normalization=norm, activation=act_en)) + + # register parameters + for i, en in enumerate(self.en): + self.__setattr__('en_{}'.format(i), en) + + self.de = [] + self.fuse = [] + for i, k_de in enumerate(de_ksize): + + c_from_up = self.en[-1].c_out if i == 0 else self.de[-1].c_out + c_out = c_from_down = self.en[-i-1].c_in + layer_idx = self.n_de - i - 1 + + self.de.append(DecodeBlock( + c_from_up, c_from_down, c_out, mode, k_de, scale=2, + normalization=norm, activation=act_de)) + if layer_idx in blend_layers: + self.fuse.append(FusionBlock(c_out, c_alpha)) + else: + self.fuse.append(None) + + # register parameters + for i, de in enumerate(self.de[::-1]): + self.__setattr__('de_{}'.format(i), de) + for i, fuse in enumerate(self.fuse[::-1]): + if fuse: + self.__setattr__('fuse_{}'.format(i), fuse) + + def forward(self, img_miss, mask): + + out = torch.cat([img_miss, mask], dim=1) + + out_en = [out] + for encode in self.en: + out = encode(out) + out_en.append(out) + + results = [] + alphas = [] + raws = [] + for i, (decode, fuse) in enumerate(zip(self.de, self.fuse)): + out = decode(out, out_en[-i-2]) + if fuse: + result, alpha, raw = fuse(img_miss, out) + results.append(result) + + return results[::-1] diff --git a/gimp-plugins/Inpainting/RefinementNet_core.py b/gimp-plugins/Inpainting/RefinementNet_core.py new file mode 100644 index 0000000..c3bfb7a --- /dev/null +++ b/gimp-plugins/Inpainting/RefinementNet_core.py @@ -0,0 +1,71 @@ +import torch +from torch import nn +from DFNet_core import get_norm, get_activation, Conv2dSame, ConvTranspose2dSame, UpBlock, EncodeBlock, DecodeBlock +from builtins import * + + +class RefinementNet(nn.Module): + def __init__( + self, c_img=19, c_mask=1, + mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu', + en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3]*8): + super(RefinementNet, self).__init__() + + c_in = c_img + c_mask + + self.en1 = EncodeBlock(c_in, 96, en_ksize[0], 2, None, None) + self.en2 = EncodeBlock(96, 192, en_ksize[1], stride=2, normalization=norm, activation=act_en) + self.en3 = EncodeBlock(192, 384, en_ksize[2], stride=2, normalization=norm, activation=act_en) + self.en4 = EncodeBlock(384, 512, en_ksize[3], stride=2, normalization=norm, activation=act_en) + self.en5 = EncodeBlock(512, 512, en_ksize[4], stride=2, normalization=norm, activation=act_en) + self.en6 = EncodeBlock(512, 512, en_ksize[5], stride=2, normalization=norm, activation=act_en) + self.en7 = EncodeBlock(512, 512, en_ksize[6], stride=2, normalization=norm, activation=act_en) + self.en8 = EncodeBlock(512, 512, en_ksize[7], stride=2, normalization=norm, activation=act_en) + + self.de1 = DecodeBlock(512, 512, 512, mode, 3, scale=2,normalization=norm, activation=act_de) + self.de2 = DecodeBlock(512, 512, 512, mode, 3, scale=2,normalization=norm, activation=act_de) + self.de3 = DecodeBlock(512, 512, 512, mode, 3, scale=2,normalization=norm, activation=act_de) + self.de4 = DecodeBlock(512, 512, 512, mode, 3, scale=2,normalization=norm, activation=act_de) + self.de5 = DecodeBlock(512, 384, 384, mode, 3, scale=2,normalization=norm, activation=act_de) + self.de6 = DecodeBlock(384, 192, 192, mode, 3, scale=2,normalization=norm, activation=act_de) + self.de7 = DecodeBlock(192, 96, 96, mode, 3, scale=2,normalization=norm, activation=act_de) + self.de8 = DecodeBlock(96, 20, 20, mode, 3, scale=2,normalization=norm, activation=act_de) + + self.last_conv = nn.Sequential(Conv2dSame(c_in, 3, 1, 1), nn.Sigmoid()) + + def forward(self, img, mask): + out = torch.cat([mask, img], dim=1) + out_en = [out] + + out = self.en1(out) + out_en.append(out) + out = self.en2(out) + out_en.append(out) + out = self.en3(out) + out_en.append(out) + out = self.en4(out) + out_en.append(out) + out = self.en5(out) + out_en.append(out) + out = self.en6(out) + out_en.append(out) + out = self.en7(out) + out_en.append(out) + out = self.en8(out) + out_en.append(out) + + + out = self.de1(out, out_en[-0-2]) + out = self.de2(out, out_en[-1-2]) + out = self.de3(out, out_en[-2-2]) + out = self.de4(out, out_en[-3-2]) + out = self.de5(out, out_en[-4-2]) + out = self.de6(out, out_en[-5-2]) + out = self.de7(out, out_en[-6-2]) + out = self.de8(out, out_en[-7-2]) + + output = self.last_conv(out) + + output = mask * output + (1 - mask) * img[:, :3] + + return output diff --git a/gimp-plugins/inpainting.py b/gimp-plugins/inpainting.py new file mode 100755 index 0000000..0b815a5 --- /dev/null +++ b/gimp-plugins/inpainting.py @@ -0,0 +1,191 @@ +from __future__ import division +import os +baseLoc = os.path.dirname(os.path.realpath(__file__))+'/' + +from gimpfu import * +import sys + +sys.path.extend([baseLoc+'gimpenv/lib/python2.7',baseLoc+'gimpenv/lib/python2.7/site-packages',baseLoc+'gimpenv/lib/python2.7/site-packages/setuptools',baseLoc+'Inpainting']) + + +import torch +import numpy as np +from torch import nn +import scipy.ndimage +import cv2 +from DFNet_core import DFNet +from RefinementNet_core import RefinementNet + + + +def channelData(layer):#convert gimp image to numpy + region=layer.get_pixel_rgn(0, 0, layer.width,layer.height) + pixChars=region[:,:] # Take whole layer + bpp=region.bpp + # return np.frombuffer(pixChars,dtype=np.uint8).reshape(len(pixChars)/bpp,bpp) + return np.frombuffer(pixChars,dtype=np.uint8).reshape(layer.height,layer.width,bpp) + +def createResultLayer(image,name,result): + rlBytes=np.uint8(result).tobytes(); + rl=gimp.Layer(image,name,image.width,image.height,0,100,NORMAL_MODE) + region=rl.get_pixel_rgn(0, 0, rl.width,rl.height,True) + region[:,:]=rlBytes + image.add_layer(rl,0) + gimp.displays_flush() + + +def to_numpy(tensor): + tensor = tensor.mul(255).byte().data.cpu().numpy() + tensor = np.transpose(tensor, [0, 2, 3, 1]) + return tensor + +def padding(img, height=512, width=512, channels=3): + channels = img.shape[2] if len(img.shape) > 2 else 1 + interpolation=cv2.INTER_NEAREST + + if channels == 1: + img_padded = np.zeros((height, width), dtype=img.dtype) + else: + img_padded = np.zeros((height, width, channels), dtype=img.dtype) + + original_shape = img.shape + rows_rate = original_shape[0] / height + cols_rate = original_shape[1] / width + new_cols = width + new_rows = height + if rows_rate > cols_rate: + new_cols = (original_shape[1] * height) // original_shape[0] + img = cv2.resize(img, (new_cols, height), interpolation=interpolation) + if new_cols > width: + new_cols = width + img_padded[:, ((img_padded.shape[1] - new_cols) // 2):((img_padded.shape[1] - new_cols) // 2 + new_cols)] = img + else: + new_rows = (original_shape[0] * width) // original_shape[1] + img = cv2.resize(img, (width, new_rows), interpolation=interpolation) + if new_rows > height: + new_rows = height + img_padded[((img_padded.shape[0] - new_rows) // 2):((img_padded.shape[0] - new_rows) // 2 + new_rows), :] = img + return img_padded, new_cols, new_rows + + + +def preprocess_image_dfnet(image, mask, model,device): + image, new_cols, new_rows = padding(image, 512, 512) + mask, _, _ = padding(mask, 512, 512) + image = np.ascontiguousarray(image.transpose(2, 0, 1)).astype(np.uint8) + mask = np.ascontiguousarray(np.expand_dims(mask, 0)).astype(np.uint8) + + image = torch.from_numpy(image).to(device).float().div(255) + mask = 1 - torch.from_numpy(mask).to(device).float().div(255) + image_miss = image * mask + DFNET_output = model(image_miss.unsqueeze(0), mask.unsqueeze(0))[0] + DFNET_output = image * mask + DFNET_output * (1 - mask) + DFNET_output = to_numpy(DFNET_output)[0] + DFNET_output = cv2.cvtColor(DFNET_output, cv2.COLOR_BGR2RGB) + DFNET_output = DFNET_output[(DFNET_output.shape[0] - new_rows) // 2: (DFNET_output.shape[0] - new_rows) // 2 + new_rows, + (DFNET_output.shape[1] - new_cols) // 2: (DFNET_output.shape[1] - new_cols) // 2 + new_cols, ...] + + return DFNET_output + + + +def preprocess_image(image, mask, image_before_resize, model,device): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + shift_val = (100 / 512) * image.shape[0] + + image_resized = cv2.resize(image_before_resize, (image.shape[1], image.shape[0])) + + mask = mask // 255 + image_matched = image * (1 - mask) + image_resized * mask + mask = mask * 255 + + img_1 = scipy.ndimage.shift(image_matched, (-shift_val, 0, 0), order=0, mode='constant', cval=1) + mask_1 = scipy.ndimage.shift(mask, (-shift_val, 0, 0), order=0, mode='constant', cval=255) + img_2 = scipy.ndimage.shift(image_matched, (shift_val, 0, 0), order=0, mode='constant', cval=1) + mask_2 = scipy.ndimage.shift(mask, (shift_val, 0, 0), order=0, mode='constant', cval=255) + img_3 = scipy.ndimage.shift(image_matched, (0, shift_val, 0), order=0, mode='constant', cval=1) + mask_3 = scipy.ndimage.shift(mask, (0, shift_val, 0), order=0, mode='constant', cval=255) + img_4 = scipy.ndimage.shift(image_matched, (0, -shift_val, 0), order=0, mode='constant', cval=1) + mask_4 = scipy.ndimage.shift(mask, (0, -shift_val, 0), order=0, mode='constant', cval=255) + image_cat = np.dstack((mask, image_matched, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4)) + + mask_patch = torch.from_numpy(image_cat).to(device).float().div(255).unsqueeze(0) + mask_patch = mask_patch.permute(0, -1, 1, 2) + inputs = mask_patch[:, 1:, ...] + mask = mask_patch[:, 0:1, ...] + out = model(inputs, mask) + out = out.mul(255).byte().data.cpu().numpy() + out = np.transpose(out, [0, 2, 3, 1])[0] + + return out + + +def pad_image(image): + x = ((image.shape[0] // 256) + (1 if image.shape[0] % 256 != 0 else 0)) * 256 + y = ((image.shape[1] // 256) + (1 if image.shape[1] % 256 != 0 else 0)) * 256 + padded = np.zeros((x, y, image.shape[2]), dtype='uint8') + padded[:image.shape[0], :image.shape[1], ...] = image + return padded + + +def inpaint(imggimp, curlayer,layeri,layerm,cFlag) : + + img = channelData(layeri)[..., :3] + mask = channelData(layerm)[..., :3] + + if img.shape[0] != imggimp.height or img.shape[1] != imggimp.width or mask.shape[0] != imggimp.height or mask.shape[1] != imggimp.width: + pdb.gimp_message(" Do (Layer -> Layer to Image Size) first and try again.") + else: + if torch.cuda.is_available() and not cFlag: + gimp.progress_init("(Using GPU) Running inpainting for " + layeri.name + "...") + device = torch.device('cuda') + else: + gimp.progress_init("(Using CPU) Running inpainting for " + layeri.name + "...") + device = torch.device('cpu') + + assert img.shape[:2] == mask.shape[:2] + + mask = mask[..., :1] + + + image = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + shape = image.shape + + image = pad_image(image) + mask = pad_image(mask) + + DFNet_model = DFNet().to(device) + DFNet_model.load_state_dict(torch.load(baseLoc + '/weights/inpainting/model_places2.pth', map_location=device)) + DFNet_model.eval() + DFNET_output = preprocess_image_dfnet(image, mask, DFNet_model,device) + del DFNet_model + Refinement_model = RefinementNet().to(device) + Refinement_model.load_state_dict(torch.load(baseLoc+'/weights/inpainting/refinement.pth', map_location=device)['state_dict']) + Refinement_model.eval() + out = preprocess_image(image, mask, DFNET_output, Refinement_model,device) + out = out[:shape[0], :shape[1], ...] + del Refinement_model + createResultLayer(imggimp,'output',out) + + + +register( + "inpainting", + "inpainting", + "Running inpainting.", + "Andrey Moskalenko", + "Your", + "2020", + "inpainting...", + "*", # Alternately use RGB, RGB*, GRAY*, INDEXED etc. + [ (PF_IMAGE, "image", "Input image", None), + (PF_DRAWABLE, "drawable", "Input drawable", None), + (PF_LAYER, "drawinglayer", "Image:", None), + (PF_LAYER, "drawinglayer", "Mask:", None), + (PF_BOOL, "fcpu", "Force CPU", False) + ], + [], + inpaint, menu="/Layer/GIML-ML") + +main() diff --git a/gimp-plugins/syncWeights.py b/gimp-plugins/syncWeights.py index 122d8de..c889196 100755 --- a/gimp-plugins/syncWeights.py +++ b/gimp-plugins/syncWeights.py @@ -241,5 +241,22 @@ def sync(path,flag): gimp.progress_init("Downloading " + model +"(~" + str(fileSize) + "MB)...") download_file_from_google_drive(file_id, destination,fileSize) - + #inpainting + model = 'inpainting' + file_id = '1WmPevEnRcdUynVHL8pZNzHPmLQVCFjuE' + fileSize = 132 #in MB + mFName = 'model_places2.pth' + if not os.path.isdir(path + '/' + model): + os.mkdir(path + '/' + model) + destination = path + '/' + model + '/' + mFName + if not os.path.isfile(destination): + gimp.progress_init("Downloading " + model +"(~" + str(fileSize) + "MB)...") + download_file_from_google_drive(file_id, destination,fileSize) + file_id = '1hIcPqDp8JjzR5kmt275DaVgX2PEtahWS' + fileSize = 148 #in MB + mFName = 'refinement.pth' + destination = path + '/' + model + '/' + mFName + if not os.path.isfile(destination): + gimp.progress_init("Downloading " + model +"(~" + str(fileSize) + "MB)...") + download_file_from_google_drive(file_id, destination,fileSize)