added inpainting

pull/33/head
Kritik Soman 3 years ago
parent 1ba66386d7
commit dc0205c5a4

DFNet_core.py
@@ -0,0 +1,266 @@
import torch
from torch import nn
import torch.nn.functional as F
from builtins import *

def resize_like(x, target, mode='bilinear'):
    return F.interpolate(x, target.shape[-2:], mode=mode, align_corners=False)

def get_norm(name, out_channels):
    if name == 'batch':
        norm = nn.BatchNorm2d(out_channels)
    elif name == 'instance':
        norm = nn.InstanceNorm2d(out_channels)
    else:
        norm = None
    return norm

def get_activation(name):
    if name == 'relu':
        activation = nn.ReLU()
    elif name == 'elu':
        activation = nn.ELU()
    elif name == 'leaky_relu':
        activation = nn.LeakyReLU(negative_slope=0.2)
    elif name == 'tanh':
        activation = nn.Tanh()
    elif name == 'sigmoid':
        activation = nn.Sigmoid()
    else:
        activation = None
    return activation
class Conv2dSame(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        padding = self.conv_same_pad(kernel_size, stride)
        if type(padding) is not tuple:
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding)
        else:
            self.conv = nn.Sequential(
                nn.ConstantPad2d(padding * 2, 0),
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, 0)
            )

    def conv_same_pad(self, ksize, stride):
        if (ksize - stride) % 2 == 0:
            return (ksize - stride) // 2
        else:
            left = (ksize - stride) // 2
            right = left + 1
            return left, right

    def forward(self, x):
        return self.conv(x)
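# Worked example of the 'same' padding above: with kernel_size=5, stride=2,
# (5 - 2) is odd, so conv_same_pad returns (left, right) = (1, 2) and the
# tuple is doubled to (1, 2, 1, 2) to match ConstantPad2d's
# (left, right, top, bottom) layout; a 512x512 input then maps to 256x256.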
class ConvTranspose2dSame(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        padding, output_padding = self.deconv_same_pad(kernel_size, stride)
        self.trans_conv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride,
            padding, output_padding)

    def deconv_same_pad(self, ksize, stride):
        pad = (ksize - stride + 1) // 2
        outpad = 2 * pad + stride - ksize
        return pad, outpad

    def forward(self, x):
        return self.trans_conv(x)
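# Sanity check for the transposed-conv padding: out = (in - 1)*stride
# - 2*pad + ksize + outpad. With ksize=4, stride=2: pad=1, outpad=0, so a
# 256-wide input yields (256-1)*2 - 2 + 4 + 0 = 512, an exact 2x upsample.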
class UpBlock(nn.Module):
    def __init__(self, mode='nearest', scale=2, channel=None, kernel_size=4):
        super().__init__()
        self.mode = mode
        if mode == 'deconv':
            self.up = ConvTranspose2dSame(
                channel, channel, kernel_size, stride=scale)
        else:
            def upsample(x):
                return F.interpolate(x, scale_factor=scale, mode=mode)
            self.up = upsample

    def forward(self, x):
        return self.up(x)
class EncodeBlock(nn.Module):
    def __init__(
            self, in_channels, out_channels, kernel_size, stride,
            normalization=None, activation=None):
        super().__init__()
        self.c_in = in_channels
        self.c_out = out_channels

        layers = []
        layers.append(
            Conv2dSame(self.c_in, self.c_out, kernel_size, stride))
        if normalization:
            layers.append(get_norm(normalization, self.c_out))
        if activation:
            layers.append(get_activation(activation))
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        return self.encode(x)
class DecodeBlock(nn.Module):
    def __init__(
            self, c_from_up, c_from_down, c_out, mode='nearest',
            kernel_size=4, scale=2, normalization='batch', activation='relu'):
        super().__init__()
        self.c_from_up = c_from_up
        self.c_from_down = c_from_down
        self.c_in = c_from_up + c_from_down
        self.c_out = c_out

        self.up = UpBlock(mode, scale, c_from_up, kernel_size=scale)

        layers = []
        layers.append(
            Conv2dSame(self.c_in, self.c_out, kernel_size, stride=1))
        if normalization:
            layers.append(get_norm(normalization, self.c_out))
        if activation:
            layers.append(get_activation(activation))
        self.decode = nn.Sequential(*layers)

    def forward(self, x, concat=None):
        out = self.up(x)
        if self.c_from_down > 0:
            out = torch.cat([out, concat], dim=1)
        out = self.decode(out)
        return out
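# DecodeBlock upsamples the deeper feature map, concatenates the matching
# encoder feature (a U-Net style skip connection) when c_from_down > 0, and
# fuses both with a stride-1 'same' convolution.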
class BlendBlock(nn.Module):
    def __init__(
            self, c_in, c_out, ksize_mid=3, norm='batch', act='leaky_relu'):
        super().__init__()
        c_mid = max(c_in // 2, 32)
        self.blend = nn.Sequential(
            Conv2dSame(c_in, c_mid, 1, 1),
            get_norm(norm, c_mid),
            get_activation(act),
            Conv2dSame(c_mid, c_out, ksize_mid, 1),
            get_norm(norm, c_out),
            get_activation(act),
            Conv2dSame(c_out, c_out, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.blend(x)
class FusionBlock(nn.Module):
    def __init__(self, c_feat, c_alpha=1):
        super().__init__()
        c_img = 3
        self.map2img = nn.Sequential(
            Conv2dSame(c_feat, c_img, 1, 1),
            nn.Sigmoid())
        self.blend = BlendBlock(c_img * 2, c_alpha)

    def forward(self, img_miss, feat_de):
        img_miss = resize_like(img_miss, feat_de)
        raw = self.map2img(feat_de)
        alpha = self.blend(torch.cat([img_miss, raw], dim=1))
        result = alpha * raw + (1 - alpha) * img_miss
        return result, alpha, raw
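# FusionBlock projects decoder features to an RGB 'raw' guess, predicts a
# per-pixel alpha map from (resized input, raw), and composites:
#     result = alpha * raw + (1 - alpha) * img_miss
# so known pixels can be carried through while holes take the prediction.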
class DFNet(nn.Module):
    def __init__(
            self, c_img=3, c_mask=1, c_alpha=3,
            mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
            en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3] * 8,
            blend_layers=[0, 1, 2, 3, 4, 5]):
        super().__init__()

        c_init = c_img + c_mask

        self.n_en = len(en_ksize)
        self.n_de = len(de_ksize)
        assert self.n_en == self.n_de, (
            'The number of encoder and decoder layers must be equal.')
        assert self.n_en >= 1, (
            'There must be at least one encoder and one decoder layer.')
        assert 0 in blend_layers, 'Layer 0 must be blended.'

        # encoder: first block without norm/activation, then channel-doubling
        # blocks capped at 512 channels
        self.en = []
        c_in = c_init
        self.en.append(
            EncodeBlock(c_in, 64, en_ksize[0], 2, None, None))
        for k_en in en_ksize[1:]:
            c_in = self.en[-1].c_out
            c_out = min(c_in * 2, 512)
            self.en.append(EncodeBlock(
                c_in, c_out, k_en, stride=2,
                normalization=norm, activation=act_en))

        # register parameters
        for i, en in enumerate(self.en):
            self.__setattr__('en_{}'.format(i), en)

        self.de = []
        self.fuse = []
        for i, k_de in enumerate(de_ksize):
            c_from_up = self.en[-1].c_out if i == 0 else self.de[-1].c_out
            c_out = c_from_down = self.en[-i - 1].c_in
            layer_idx = self.n_de - i - 1

            self.de.append(DecodeBlock(
                c_from_up, c_from_down, c_out, mode, k_de, scale=2,
                normalization=norm, activation=act_de))
            if layer_idx in blend_layers:
                self.fuse.append(FusionBlock(c_out, c_alpha))
            else:
                self.fuse.append(None)

        # register parameters
        for i, de in enumerate(self.de[::-1]):
            self.__setattr__('de_{}'.format(i), de)
        for i, fuse in enumerate(self.fuse[::-1]):
            if fuse:
                self.__setattr__('fuse_{}'.format(i), fuse)

    def forward(self, img_miss, mask):
        out = torch.cat([img_miss, mask], dim=1)

        out_en = [out]
        for encode in self.en:
            out = encode(out)
            out_en.append(out)

        results = []
        for i, (decode, fuse) in enumerate(zip(self.de, self.fuse)):
            out = decode(out, out_en[-i - 2])
            if fuse:
                result, alpha, raw = fuse(img_miss, out)
                results.append(result)

        return results[::-1]
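# Minimal smoke test (a sketch, not part of the upstream plugin flow;
# assumes CPU and the default constructor arguments above):
if __name__ == '__main__':
    net = DFNet().eval()
    img = torch.rand(1, 3, 512, 512)    # RGB image in [0, 1]
    mask = torch.ones(1, 1, 512, 512)   # 1 = known pixel, 0 = hole
    with torch.no_grad():
        results = net(img * mask, mask)
    # one blended prediction per fusion layer, finest resolution first
    print([tuple(r.shape) for r in results])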

RefinementNet_core.py
@@ -0,0 +1,71 @@
import torch
from torch import nn
from DFNet_core import get_norm, get_activation, Conv2dSame, ConvTranspose2dSame, UpBlock, EncodeBlock, DecodeBlock
from builtins import *

class RefinementNet(nn.Module):
    def __init__(
            self, c_img=19, c_mask=1,
            mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
            en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3] * 8):
        super(RefinementNet, self).__init__()
        c_in = c_img + c_mask

        self.en1 = EncodeBlock(c_in, 96, en_ksize[0], 2, None, None)
        self.en2 = EncodeBlock(96, 192, en_ksize[1], stride=2, normalization=norm, activation=act_en)
        self.en3 = EncodeBlock(192, 384, en_ksize[2], stride=2, normalization=norm, activation=act_en)
        self.en4 = EncodeBlock(384, 512, en_ksize[3], stride=2, normalization=norm, activation=act_en)
        self.en5 = EncodeBlock(512, 512, en_ksize[4], stride=2, normalization=norm, activation=act_en)
        self.en6 = EncodeBlock(512, 512, en_ksize[5], stride=2, normalization=norm, activation=act_en)
        self.en7 = EncodeBlock(512, 512, en_ksize[6], stride=2, normalization=norm, activation=act_en)
        self.en8 = EncodeBlock(512, 512, en_ksize[7], stride=2, normalization=norm, activation=act_en)

        self.de1 = DecodeBlock(512, 512, 512, mode, 3, scale=2, normalization=norm, activation=act_de)
        self.de2 = DecodeBlock(512, 512, 512, mode, 3, scale=2, normalization=norm, activation=act_de)
        self.de3 = DecodeBlock(512, 512, 512, mode, 3, scale=2, normalization=norm, activation=act_de)
        self.de4 = DecodeBlock(512, 512, 512, mode, 3, scale=2, normalization=norm, activation=act_de)
        self.de5 = DecodeBlock(512, 384, 384, mode, 3, scale=2, normalization=norm, activation=act_de)
        self.de6 = DecodeBlock(384, 192, 192, mode, 3, scale=2, normalization=norm, activation=act_de)
        self.de7 = DecodeBlock(192, 96, 96, mode, 3, scale=2, normalization=norm, activation=act_de)
        self.de8 = DecodeBlock(96, 20, 20, mode, 3, scale=2, normalization=norm, activation=act_de)
        self.last_conv = nn.Sequential(Conv2dSame(c_in, 3, 1, 1), nn.Sigmoid())
    def forward(self, img, mask):
        out = torch.cat([mask, img], dim=1)

        # encoder: keep every intermediate feature map for the skip connections
        out_en = [out]
        for encode in (self.en1, self.en2, self.en3, self.en4,
                       self.en5, self.en6, self.en7, self.en8):
            out = encode(out)
            out_en.append(out)

        # decoder: each stage consumes the matching encoder feature map
        for i, decode in enumerate((self.de1, self.de2, self.de3, self.de4,
                                    self.de5, self.de6, self.de7, self.de8)):
            out = decode(out, out_en[-i - 2])

        output = self.last_conv(out)
        # composite: mask is 1 in the hole, so predictions fill the hole and
        # the first three image channels are kept everywhere else
        output = mask * output + (1 - mask) * img[:, :3]
        return output
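# Input layout expected by forward() (as built by the plugin's
# preprocess_image below): 20 channels total = 1 mask channel + 19 image
# channels, where the 19 are the DFNet-matched image (3) plus four shifted
# copies of the image (4 x 3) and of the mask (4 x 1). Only the first three
# image channels (the matched image) enter the final composite.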

@@ -0,0 +1,191 @@
from __future__ import division
import os
baseLoc = os.path.dirname(os.path.realpath(__file__)) + '/'
from gimpfu import *
import sys
sys.path.extend([baseLoc + 'gimpenv/lib/python2.7',
                 baseLoc + 'gimpenv/lib/python2.7/site-packages',
                 baseLoc + 'gimpenv/lib/python2.7/site-packages/setuptools',
                 baseLoc + 'Inpainting'])
import torch
import numpy as np
from torch import nn
import scipy.ndimage
import cv2
from DFNet_core import DFNet
from RefinementNet_core import RefinementNet

def channelData(layer):  # convert a GIMP layer to a numpy array
    region = layer.get_pixel_rgn(0, 0, layer.width, layer.height)
    pixChars = region[:, :]  # take the whole layer
    bpp = region.bpp
    return np.frombuffer(pixChars, dtype=np.uint8).reshape(layer.height, layer.width, bpp)
def createResultLayer(image, name, result):
    rlBytes = np.uint8(result).tobytes()
    rl = gimp.Layer(image, name, image.width, image.height, 0, 100, NORMAL_MODE)
    region = rl.get_pixel_rgn(0, 0, rl.width, rl.height, True)
    region[:, :] = rlBytes
    image.add_layer(rl, 0)
    gimp.displays_flush()

def to_numpy(tensor):
    tensor = tensor.mul(255).byte().data.cpu().numpy()
    tensor = np.transpose(tensor, [0, 2, 3, 1])
    return tensor
def padding(img, height=512, width=512):
    # aspect-preserving resize to fit inside (height, width), then center-pad
    channels = img.shape[2] if len(img.shape) > 2 else 1
    interpolation = cv2.INTER_NEAREST
    if channels == 1:
        img_padded = np.zeros((height, width), dtype=img.dtype)
    else:
        img_padded = np.zeros((height, width, channels), dtype=img.dtype)
    original_shape = img.shape
    rows_rate = original_shape[0] / height
    cols_rate = original_shape[1] / width
    new_cols = width
    new_rows = height
    if rows_rate > cols_rate:
        new_cols = (original_shape[1] * height) // original_shape[0]
        img = cv2.resize(img, (new_cols, height), interpolation=interpolation)
        if new_cols > width:
            new_cols = width
        img_padded[:, ((img_padded.shape[1] - new_cols) // 2):((img_padded.shape[1] - new_cols) // 2 + new_cols)] = img
    else:
        new_rows = (original_shape[0] * width) // original_shape[1]
        img = cv2.resize(img, (width, new_rows), interpolation=interpolation)
        if new_rows > height:
            new_rows = height
        img_padded[((img_padded.shape[0] - new_rows) // 2):((img_padded.shape[0] - new_rows) // 2 + new_rows), :] = img
    return img_padded, new_cols, new_rows
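# Example: a 768x1024 input has rows_rate 1.5 and cols_rate 2.0, so the else
# branch runs: new_rows = 768*512//1024 = 384, the image is resized to
# 512x384 and centered vertically with 64 blank rows above and below.
# padding() returns the 512x512 canvas plus (new_cols, new_rows) = (512, 384)
# so the caller can crop the valid region back out afterwards.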
def preprocess_image_dfnet(image, mask, model, device):
    image, new_cols, new_rows = padding(image, 512, 512)
    mask, _, _ = padding(mask, 512, 512)
    image = np.ascontiguousarray(image.transpose(2, 0, 1)).astype(np.uint8)
    mask = np.ascontiguousarray(np.expand_dims(mask, 0)).astype(np.uint8)
    image = torch.from_numpy(image).to(device).float().div(255)
    mask = 1 - torch.from_numpy(mask).to(device).float().div(255)
    image_miss = image * mask
    DFNET_output = model(image_miss.unsqueeze(0), mask.unsqueeze(0))[0]
    DFNET_output = image * mask + DFNET_output * (1 - mask)
    DFNET_output = to_numpy(DFNET_output)[0]
    DFNET_output = cv2.cvtColor(DFNET_output, cv2.COLOR_BGR2RGB)
    DFNET_output = DFNET_output[(DFNET_output.shape[0] - new_rows) // 2: (DFNET_output.shape[0] - new_rows) // 2 + new_rows,
                                (DFNET_output.shape[1] - new_cols) // 2: (DFNET_output.shape[1] - new_cols) // 2 + new_cols, ...]
    return DFNET_output
def preprocess_image(image, mask, image_before_resize, model, device):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    shift_val = (100 / 512) * image.shape[0]
    image_resized = cv2.resize(image_before_resize, (image.shape[1], image.shape[0]))
    mask = mask // 255
    image_matched = image * (1 - mask) + image_resized * mask
    mask = mask * 255
    img_1 = scipy.ndimage.shift(image_matched, (-shift_val, 0, 0), order=0, mode='constant', cval=1)
    mask_1 = scipy.ndimage.shift(mask, (-shift_val, 0, 0), order=0, mode='constant', cval=255)
    img_2 = scipy.ndimage.shift(image_matched, (shift_val, 0, 0), order=0, mode='constant', cval=1)
    mask_2 = scipy.ndimage.shift(mask, (shift_val, 0, 0), order=0, mode='constant', cval=255)
    img_3 = scipy.ndimage.shift(image_matched, (0, shift_val, 0), order=0, mode='constant', cval=1)
    mask_3 = scipy.ndimage.shift(mask, (0, shift_val, 0), order=0, mode='constant', cval=255)
    img_4 = scipy.ndimage.shift(image_matched, (0, -shift_val, 0), order=0, mode='constant', cval=1)
    mask_4 = scipy.ndimage.shift(mask, (0, -shift_val, 0), order=0, mode='constant', cval=255)
    image_cat = np.dstack((mask, image_matched, img_1, mask_1, img_2, mask_2, img_3, mask_3, img_4, mask_4))
    mask_patch = torch.from_numpy(image_cat).to(device).float().div(255).unsqueeze(0)
    mask_patch = mask_patch.permute(0, -1, 1, 2)
    inputs = mask_patch[:, 1:, ...]
    mask = mask_patch[:, 0:1, ...]
    out = model(inputs, mask)
    out = out.mul(255).byte().data.cpu().numpy()
    out = np.transpose(out, [0, 2, 3, 1])[0]
    return out
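# preprocess_image builds the RefinementNet input as an ensemble: the
# DFNet output is pasted into the hole of the full-resolution image, then
# four copies shifted up/down/left/right by roughly 100/512 of the image
# height are stacked alongside their shifted masks, giving the network
# nearby context for every hole pixel in a single forward pass.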
def pad_image(image):
    # round each spatial dimension up to the next multiple of 256
    x = ((image.shape[0] // 256) + (1 if image.shape[0] % 256 != 0 else 0)) * 256
    y = ((image.shape[1] // 256) + (1 if image.shape[1] % 256 != 0 else 0)) * 256
    padded = np.zeros((x, y, image.shape[2]), dtype='uint8')
    padded[:image.shape[0], :image.shape[1], ...] = image
    return padded
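# Example: pad_image on a 600x900x3 array allocates a 768x1024x3 zero canvas
# (the next multiples of 256) and copies the input into the top-left corner,
# so the eight stride-2 encoder stages always divide the input evenly.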
def inpaint(imggimp, curlayer, layeri, layerm, cFlag):
    img = channelData(layeri)[..., :3]
    mask = channelData(layerm)[..., :3]

    if img.shape[0] != imggimp.height or img.shape[1] != imggimp.width or mask.shape[0] != imggimp.height or mask.shape[1] != imggimp.width:
        pdb.gimp_message("Do (Layer -> Layer to Image Size) first and try again.")
    else:
        if torch.cuda.is_available() and not cFlag:
            gimp.progress_init("(Using GPU) Running inpainting for " + layeri.name + "...")
            device = torch.device('cuda')
        else:
            gimp.progress_init("(Using CPU) Running inpainting for " + layeri.name + "...")
            device = torch.device('cpu')

        assert img.shape[:2] == mask.shape[:2]
        mask = mask[..., :1]
        image = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        shape = image.shape
        image = pad_image(image)
        mask = pad_image(mask)

        # stage 1: coarse completion with DFNet at 512x512
        DFNet_model = DFNet().to(device)
        DFNet_model.load_state_dict(torch.load(baseLoc + 'weights/inpainting/model_places2.pth', map_location=device))
        DFNet_model.eval()
        DFNET_output = preprocess_image_dfnet(image, mask, DFNet_model, device)
        del DFNet_model

        # stage 2: refinement at the padded resolution
        Refinement_model = RefinementNet().to(device)
        Refinement_model.load_state_dict(torch.load(baseLoc + 'weights/inpainting/refinement.pth', map_location=device)['state_dict'])
        Refinement_model.eval()
        out = preprocess_image(image, mask, DFNET_output, Refinement_model, device)
        out = out[:shape[0], :shape[1], ...]
        del Refinement_model

        createResultLayer(imggimp, 'output', out)
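# Pipeline summary: the selected layer is padded to multiples of 256, DFNet
# produces a coarse completion at 512x512, RefinementNet sharpens it at the
# padded resolution using the shifted-ensemble input, and the result is
# cropped back to the original layer size and added as a new GIMP layer.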
register(
    "inpainting",
    "inpainting",
    "Running inpainting.",
    "Andrey Moskalenko",
    "Your",
    "2020",
    "inpainting...",
    "*",  # alternately use RGB, RGB*, GRAY*, INDEXED etc.
    [(PF_IMAGE, "image", "Input image", None),
     (PF_DRAWABLE, "drawable", "Input drawable", None),
     (PF_LAYER, "layeri", "Image:", None),
     (PF_LAYER, "layerm", "Mask:", None),
     (PF_BOOL, "fcpu", "Force CPU", False)],
    [],
    inpaint, menu="<Image>/Layer/GIMP-ML")

main()

@@ -241,5 +241,22 @@ def sync(path,flag):
        gimp.progress_init("Downloading " + model + "(~" + str(fileSize) + "MB)...")
        download_file_from_google_drive(file_id, destination, fileSize)

    # inpainting
    model = 'inpainting'
    file_id = '1WmPevEnRcdUynVHL8pZNzHPmLQVCFjuE'
    fileSize = 132  # in MB
    mFName = 'model_places2.pth'
    if not os.path.isdir(path + '/' + model):
        os.mkdir(path + '/' + model)
    destination = path + '/' + model + '/' + mFName
    if not os.path.isfile(destination):
        gimp.progress_init("Downloading " + model + "(~" + str(fileSize) + "MB)...")
        download_file_from_google_drive(file_id, destination, fileSize)

    file_id = '1hIcPqDp8JjzR5kmt275DaVgX2PEtahWS'
    fileSize = 148  # in MB
    mFName = 'refinement.pth'
    destination = path + '/' + model + '/' + mFName
    if not os.path.isfile(destination):
        gimp.progress_init("Downloading " + model + "(~" + str(fileSize) + "MB)...")
        download_file_from_google_drive(file_id, destination, fileSize)
