You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
GIMP-ML/gimp-plugins/ideepcolor/data/colorize_image.py

566 lines
22 KiB
Python

import numpy as np
# import matplotlib.pyplot as plt
# from skimage import color
# from sklearn.cluster import KMeans
import os
import cv2
from scipy.ndimage.interpolation import zoom
def create_temp_directory(path_template, N=1e8):
    ''' Create a new, randomly numbered directory and return its path.

    INPUTS
        path_template   %-style format string with one integer placeholder,
                        e.g. '/tmp/colorize_%08d'
        N               exclusive upper bound of the random number that is
                        substituted into path_template
    OUTPUTS
        returned value is the path of the directory that was created

    Raises OSError if the directory cannot be created for any reason other
    than the name already being taken. '''
    import errno

    print(path_template)
    upper = int(N)  # the default is the float 1e8; hand randint a real integer
    while True:
        cur_path = path_template % np.random.randint(0, upper)
        if os.path.exists(cur_path):
            continue
        print('Creating directory: %s' % cur_path)
        try:
            os.mkdir(cur_path)
        except OSError as err:
            # Someone else may have created the same name between the
            # exists() check and mkdir(); pick another name in that case.
            if err.errno != errno.EEXIST:
                raise
            continue
        return cur_path
def lab2rgb_transpose(img_l, img_ab):
    ''' Stack an L plane and two ab planes and convert them to 8-bit RGB.

    INPUTS
        img_l       1xXxX     [0,100]
        img_ab      2xXxX     [-100,100]
    OUTPUTS
        returned value is XxXx3 (uint8) '''
    lab_chw = np.concatenate((img_l, img_ab), axis=0)
    # channel-first -> channel-last, which is what cv2.cvtColor works on
    lab_hwc = lab_chw.transpose((1, 2, 0)).astype('float32')
    rgb_unit = cv2.cvtColor(lab_hwc, cv2.COLOR_LAB2RGB)
    # clamp to [0, 1] before scaling so the uint8 cast cannot wrap around
    return (np.clip(rgb_unit, 0, 1) * 255).astype('uint8')
def rgb2lab_transpose(img_rgb):
    ''' Convert a channel-last RGB image to channel-first Lab.

    INPUTS
        img_rgb     XxXx3 (values are divided by 255 before conversion)
    OUTPUTS
        returned value is 3xXxX '''
    rgb_unit = img_rgb.astype(np.float32) / 255
    lab_hwc = cv2.cvtColor(rgb_unit, cv2.COLOR_RGB2LAB)
    return np.transpose(lab_hwc, (2, 0, 1))
class ColorizeImageBase():
    ''' Backend-independent image state and rendering helpers for colorization.

    Keeps the working image as RGB and Lab arrays, both at working size and
    at "full" resolution (capped at Xfullres_max), stores the user hints
    (input_ab / input_mask) and renders all of these as RGB images.

    Subclasses are expected to
      * define l_norm, ab_norm, l_mean, ab_mean and mask_mult (they are read
        here but never assigned by this class),
      * implement prep_net() and set self.net_set = True,
      * extend net_forward() so that output_rgb / output_ab get filled. '''

    def __init__(self, Xd=256, Xfullres_max=10000):
        ''' INPUTS
            Xd              side length that load_image() resizes to
            Xfullres_max    maximum size of maximum dimension (full-res copy) '''
        self.Xd = Xd
        self.img_l_set = False
        self.net_set = False
        self.Xfullres_max = Xfullres_max  # maximum size of maximum dimension
        # Meant to be true whenever an image was just loaded; nothing in this
        # class sets it to True. net_forward can set this to False if wanted.
        self.img_just_set = False

    def prep_net(self):
        ''' Load the network. Must be overridden by every backend subclass. '''
        raise NotImplementedError("Should be implemented by subclass")

    # ***** Image prepping *****
    def load_image(self, im):
        ''' Store an RGB image: full-res copy plus an Xd x Xd working copy.

        INPUTS
            im  RGB image array (presumably HxWx3 -- TODO confirm with callers) '''
        self.img_rgb_fullres = im.copy()
        self._set_img_lab_fullres_()
        self.img_rgb = cv2.resize(im, (self.Xd, self.Xd)).copy()
        self.img_l_set = True
        # convert into lab space
        self._set_img_lab_()
        self._set_img_lab_mc_()

    def set_image(self, input_image):
        ''' Like load_image(), but the working copy is NOT resized to Xd x Xd;
        input_image itself (not a copy) becomes self.img_rgb. '''
        self.img_rgb_fullres = input_image.copy()
        self._set_img_lab_fullres_()
        self.img_l_set = True
        self.img_rgb = input_image
        # convert into lab space
        self._set_img_lab_()
        self._set_img_lab_mc_()

    def net_forward(self, input_ab, input_mask):
        ''' Validate state and store the (normalized) user hints.

        INPUTS
            input_ab    2xXxX    input color patches (non-normalized)
            input_mask  1xXxX    input mask, indicating which points have been provided
        OUTPUTS
            returned value is -1 if no image or no net is set, otherwise 0
        assumes self.img_l_mc has been set '''
        if not self.img_l_set:
            print('I need to have an image!')
            return -1
        if not self.net_set:
            print('I need to have a net!')
            return -1

        self.input_ab = input_ab
        self.input_ab_mc = (input_ab - self.ab_mean) / self.ab_norm
        self.input_mask = input_mask
        self.input_mask_mult = input_mask * self.mask_mult
        return 0

    def get_result_PSNR(self, result=-1, return_SE_map=False):
        ''' PSNR (dB, peak value 255) between self.img_rgb and a result image.

        INPUTS
            result          image to score; -1 (default) means "use get_img_forward()"
            return_SE_map   also return the per-element squared error
        OUTPUTS
            returned value is PSNR, or (PSNR, SE_map) if return_SE_map '''
        # -1 is the "not given" sentinel; it is detected on the first element
        if np.array((result)).flatten()[0] == -1:
            cur_result = self.get_img_forward()
        else:
            cur_result = result.copy()

        SE_map = (1. * self.img_rgb - cur_result) ** 2
        cur_MSE = np.mean(SE_map)
        cur_PSNR = 20 * np.log10(255. / np.sqrt(cur_MSE))
        if return_SE_map:
            return (cur_PSNR, SE_map)
        return cur_PSNR

    def get_img_forward(self):
        # get image with point estimate
        return self.output_rgb

    def get_img_gray(self):
        # Get black and white image: L plane with zero chroma.
        # The ab planes take their size from img_l rather than from Xd,
        # because set_image() keeps whatever size the caller passed in.
        zero_ab = np.zeros((2,) + tuple(self.img_l.shape[1:]))
        return lab2rgb_transpose(self.img_l, zero_ab)

    def get_img_gray_fullres(self):
        # Get black and white image at full resolution
        zero_ab = np.zeros((2, self.img_l_fullres.shape[1], self.img_l_fullres.shape[2]))
        return lab2rgb_transpose(self.img_l_fullres, zero_ab)

    def get_img_fullres(self):
        # This assumes self.img_l_fullres, self.output_ab are set.
        # Typically, this means that set_image() and net_forward()
        # have been called.
        # bilinear upsample
        zoom_factor = self._zoom_factor_to_fullres_(self.output_ab)
        output_ab_fullres = zoom(self.output_ab, zoom_factor, order=1)
        return lab2rgb_transpose(self.img_l_fullres, output_ab_fullres)

    def get_input_img_fullres(self):
        # full-res L combined with the bilinearly upsampled input hints
        zoom_factor = self._zoom_factor_to_fullres_(self.input_ab)
        input_ab_fullres = zoom(self.input_ab, zoom_factor, order=1)
        return lab2rgb_transpose(self.img_l_fullres, input_ab_fullres)

    def get_input_img(self):
        # working-size L combined with the raw input hints
        return lab2rgb_transpose(self.img_l, self.input_ab)

    def get_img_mask(self):
        # Render the mask as a gray image: L = 100 * (1 - mask), zero chroma.
        # Chroma size follows the mask (see get_img_gray for the reason).
        zero_ab = np.zeros((2,) + tuple(np.shape(self.input_mask)[1:]))
        return lab2rgb_transpose(100. * (1 - self.input_mask), zero_ab)

    def get_img_mask_fullres(self):
        # Same as get_img_mask, nearest-neighbour upsampled to full resolution
        zoom_factor = self._zoom_factor_to_fullres_(self.input_ab)
        input_mask_fullres = zoom(self.input_mask, zoom_factor, order=0)
        zero_ab = np.zeros((2, input_mask_fullres.shape[1], input_mask_fullres.shape[2]))
        return lab2rgb_transpose(100. * (1 - input_mask_fullres), zero_ab)

    def get_sup_img(self):
        # hints only: L = 50 where the mask is set, chroma from input_ab
        return lab2rgb_transpose(50 * self.input_mask, self.input_ab)

    def get_sup_fullres(self):
        # Full-res version of get_sup_img (nearest-neighbour upsampling).
        # The factor is derived from input_ab -- the array that is actually
        # upsampled -- so this works before any output has been produced.
        zoom_factor = self._zoom_factor_to_fullres_(self.input_ab)
        input_mask_fullres = zoom(self.input_mask, zoom_factor, order=0)
        input_ab_fullres = zoom(self.input_ab, zoom_factor, order=0)
        return lab2rgb_transpose(50 * input_mask_fullres, input_ab_fullres)

    # ***** Private functions *****
    def _zoom_factor_to_fullres_(self, planes):
        # per-axis zoom factors taking a CxHxW array to the full-res HxW
        # (channel axis untouched); "1. *" forces float division on Python 2
        return (1,
                1. * self.img_l_fullres.shape[1] / planes.shape[1],
                1. * self.img_l_fullres.shape[2] / planes.shape[2])

    def _set_img_lab_fullres_(self):
        # shrink the full resolution image so that its maximum dimension is
        # within Xfullres_max, then derive its Lab planes
        longest = max(self.img_rgb_fullres.shape[0], self.img_rgb_fullres.shape[1])
        if longest > self.Xfullres_max:
            zoom_factor = 1. * self.Xfullres_max / longest
            self.img_rgb_fullres = zoom(self.img_rgb_fullres, (zoom_factor, zoom_factor, 1), order=1)

        self.img_lab_fullres = cv2.cvtColor(
            self.img_rgb_fullres.astype(np.float32) / 255, cv2.COLOR_RGB2LAB).transpose((2, 0, 1))
        self.img_l_fullres = self.img_lab_fullres[[0], :, :]
        self.img_ab_fullres = self.img_lab_fullres[1:, :, :]

    def _set_img_lab_(self):
        # set self.img_lab (3xHxW) and its L / ab slices from self.img_rgb
        self.img_lab = cv2.cvtColor(
            self.img_rgb.astype(np.float32) / 255, cv2.COLOR_RGB2LAB).transpose((2, 0, 1))
        self.img_l = self.img_lab[[0], :, :]
        self.img_ab = self.img_lab[1:, :, :]

    def _set_img_lab_mc_(self):
        # set self.img_lab_mc from self.img_lab
        # lab image, mean centered: x / norm - mean / norm, per channel
        scale = np.array((self.l_norm, self.ab_norm, self.ab_norm))[:, np.newaxis, np.newaxis]
        offset = np.array((self.l_mean / self.l_norm,
                           self.ab_mean / self.ab_norm,
                           self.ab_mean / self.ab_norm))[:, np.newaxis, np.newaxis]
        self.img_lab_mc = self.img_lab / scale - offset
        self._set_img_l_()

    def _set_img_l_(self):
        # mean-centered L plane, kept 1xHxW
        self.img_l_mc = self.img_lab_mc[[0], :, :]
        self.img_l_set = True

    def _set_img_ab_(self):
        # mean-centered ab planes, 2xHxW
        self.img_ab_mc = self.img_lab_mc[[1, 2], :, :]

    def _set_out_ab_(self):
        # derive output_lab / output_ab from the current output_rgb
        self.output_lab = rgb2lab_transpose(self.output_rgb)
        self.output_ab = self.output_lab[1:, :, :]
class ColorizeImageTorch(ColorizeImageBase):
    ''' PyTorch backend: runs a SIGGRAPHGenerator model on the stored image
    and user hints and keeps the result as output_rgb / output_ab. '''

    def __init__(self, Xd=256, maskcent=False):
        ''' INPUTS
            Xd          working image side length (passed to ColorizeImageBase)
            maskcent    if true, mask_cent is .5, otherwise 0 '''
        print('ColorizeImageTorch instantiated')
        ColorizeImageBase.__init__(self, Xd)
        self.l_norm = 1.
        self.ab_norm = 1.
        self.l_mean = 50.
        self.ab_mean = 0.
        self.mask_mult = 1.
        self.mask_cent = .5 if maskcent else 0

        # Load grid properties: 23x23 grid of ab values from -110 to 110 in
        # steps of 10, flattened to 529x2
        self.pts_in_hull = np.array(np.meshgrid(np.arange(-110, 120, 10), np.arange(-110, 120, 10))).reshape((2, 529)).T

    # ***** Net preparation *****
    def prep_net(self, gpu_id=None, path='', dist=False):
        ''' Build the model, load its weights from path and switch to eval mode.

        INPUTS
            gpu_id  None keeps the model on the CPU; any other value moves it
                    to CUDA (the value itself is not used to pick a device)
            path    checkpoint file understood by torch.load
            dist    forwarded to SIGGRAPHGenerator '''
        import torch
        import models.pytorch.model as model

        print('path = %s' % path)
        print('Model set! dist mode? ', dist)
        self.net = model.SIGGRAPHGenerator(dist=dist)

        # Without map_location a checkpoint saved from CUDA tensors cannot be
        # loaded on a machine without CUDA. Remap to CPU only when no GPU was
        # requested; otherwise keep torch.load's default behaviour.
        map_location = 'cpu' if gpu_id is None else None
        state_dict = torch.load(path, map_location=map_location)
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata

        # patch InstanceNorm checkpoints prior to 0.4
        for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
            self.__patch_instance_norm_state_dict(state_dict, self.net, key.split('.'))
        self.net.load_state_dict(state_dict)

        if gpu_id is not None:
            self.net.cuda()
        self.net.eval()
        self.net_set = True

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        ''' Walk module along the dotted key path and drop InstanceNorm entries
        from state_dict: running_mean / running_var when the module holds None
        for them, and num_batches_tracked always. '''
        key = keys[i]
        if i + 1 < len(keys):
            # not at the parameter/buffer yet: descend into the child module
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
            return

        # at the end, pointing to a parameter/buffer
        if not module.__class__.__name__.startswith('InstanceNorm'):
            return
        full_key = '.'.join(keys)
        if key in ('running_mean', 'running_var') and getattr(module, key) is None:
            state_dict.pop(full_key)
        if key == 'num_batches_tracked':
            state_dict.pop(full_key)

    # ***** Call forward *****
    def net_forward(self, input_ab, input_mask):
        ''' Run the network on the stored image and the given hints.

        INPUTS
            input_ab    2xXxX    input color patches (non-normalized)
            input_mask  1xXxX    input mask, indicating which points have been provided
        OUTPUTS
            returned value is -1 if image or net is missing, otherwise
            self.output_rgb
        assumes self.img_l_mc has been set '''
        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        prediction = self.net.forward(self.img_l_mc, self.input_ab_mc, self.input_mask_mult, self.mask_cent)
        # first batch element, moved to host memory as a numpy array
        output_ab = prediction[0, :, :, :].cpu().data.numpy()
        self.output_rgb = lab2rgb_transpose(self.img_l, output_ab)
        self._set_out_ab_()
        return self.output_rgb

    def get_img_forward(self):
        # get image with point estimate
        return ColorizeImageBase.get_img_forward(self)

    def get_img_gray(self):
        # Get black and white image
        return ColorizeImageBase.get_img_gray(self)
class ColorizeImageTorchDist(ColorizeImageTorch):
    ''' PyTorch backend that also keeps the second network output, a
    per-pixel distribution over the 529 bins of the 23x23 ab grid. '''

    def __init__(self, Xd=256, maskcent=False):
        # maskcent is forwarded so the parent sets mask_cent (.5 or 0)
        ColorizeImageTorch.__init__(self, Xd, maskcent=maskcent)
        self.dist_ab_set = False
        self.pts_grid = np.array(np.meshgrid(np.arange(-110, 120, 10), np.arange(-110, 120, 10))).reshape((2, 529)).T
        self.in_hull = np.ones(529, dtype=bool)  # every grid point is kept
        self.AB = self.pts_grid.shape[0]  # 529
        self.A = int(np.sqrt(self.AB))  # 23
        self.B = int(np.sqrt(self.AB))  # 23
        self.dist_ab_full = np.zeros((self.AB, self.Xd, self.Xd))
        self.dist_ab_grid = np.zeros((self.A, self.B, self.Xd, self.Xd))
        self.dist_entropy = np.zeros((self.Xd, self.Xd))

    def prep_net(self, gpu_id=None, path='', dist=True, S=.2):
        # NOTE: S mirrors the parameter of ColorizeImageCaffeDist.prep_net
        # but is not applied to anything here yet ("set S somehow").
        ColorizeImageTorch.prep_net(self, gpu_id=gpu_id, path=path, dist=dist)

    def net_forward(self, input_ab, input_mask):
        ''' Run the network and store the predicted ab distribution.

        INPUTS
            input_ab    2xXxX    input color patches (non-normalized)
            input_mask  1xXxX    input mask, indicating which points have been provided
        OUTPUTS
            returned value is -1 if image or net is missing, otherwise the
            first network output (batch element 0) as a numpy array.
            Unlike ColorizeImageTorch.net_forward this does not update
            output_rgb / output_ab.
        assumes self.img_l_mc has been set '''
        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        # set distribution
        (function_return, dist_ab) = self.net.forward(self.img_l_mc, self.input_ab_mc, self.input_mask_mult, self.mask_cent)
        function_return = function_return[0, :, :, :].cpu().data.numpy()
        self.dist_ab = dist_ab[0, :, :, :].cpu().data.numpy()
        self.dist_ab_set = True

        # full grid, ABxXxX, AB = 529
        self.dist_ab_full[self.in_hull, :, :] = self.dist_ab

        # gridded, AxBxXxX, A = 23
        self.dist_ab_grid = self.dist_ab_full.reshape((self.A, self.B, self.Xd, self.Xd))

        # return
        return function_return

    # def get_ab_reccs(self, h, w, K=5, N=25000, return_conf=False):
    #     ''' Recommended colors at point (h,w)
    #     Call this after calling net_forward
    #     '''
    #     if not self.dist_ab_set:
    #         print('Need to set prediction first')
    #         return 0
    #
    #     # randomly sample from pdf
    #     cmf = np.cumsum(self.dist_ab[:, h, w])  # CMF
    #     cmf = cmf / cmf[-1]
    #     cmf_bins = cmf
    #
    #     # randomly sample N points
    #     rnd_pts = np.random.uniform(low=0, high=1.0, size=N)
    #     inds = np.digitize(rnd_pts, bins=cmf_bins)
    #     rnd_pts_ab = self.pts_in_hull[inds, :]
    #
    #     # run k-means
    #     kmeans = KMeans(n_clusters=K).fit(rnd_pts_ab)
    #
    #     # sort by cluster occupancy
    #     k_label_cnt = np.histogram(kmeans.labels_, np.arange(0, K + 1))[0]
    #     k_inds = np.argsort(k_label_cnt, axis=0)[::-1]
    #
    #     cluster_per = 1. * k_label_cnt[k_inds] / N  # percentage of points within cluster
    #     cluster_centers = kmeans.cluster_centers_[k_inds, :]  # cluster centers
    #
    #     # cluster_centers = np.random.uniform(low=-100,high=100,size=(N,2))
    #     if return_conf:
    #         return cluster_centers, cluster_per
    #     else:
    #         return cluster_centers

    def compute_entropy(self):
        ''' Store sum_c p_c * log(p_c) per pixel in self.dist_entropy.

        This is the negative of the Shannon entropy; the sign is kept as it
        always was. Requires net_forward() to have set self.dist_ab. '''
        probs = np.asarray(self.dist_ab)
        # 0 * log(0) must count as 0. Evaluating it directly yields NaN and
        # poisons the whole pixel, so feed log() a 1 wherever p is not
        # positive (log(1) == 0, the term then vanishes).
        log_arg = np.where(probs > 0, probs, 1.)
        self.dist_entropy = np.sum(probs * np.log(log_arg), axis=0)
# def plot_dist_grid(self, h, w):
# # Plots distribution at a given point
# plt.figure()
# plt.imshow(self.dist_ab_grid[:, :, h, w], extent=[-110, 110, 110, -110], interpolation='nearest')
# plt.colorbar()
# plt.ylabel('a')
# plt.xlabel('b')
# def plot_dist_entropy(self):
# # Plots distribution at a given point
# plt.figure()
# plt.imshow(-self.dist_entropy, interpolation='nearest')
# plt.colorbar()
class ColorizeImageCaffe(ColorizeImageBase):
    ''' Caffe backend: feeds L, hint ab and hint mask to a caffe.Net and
    reads the predicted ab planes from the 'pred_ab' blob. '''

    def __init__(self, Xd=256):
        print('ColorizeImageCaffe instantiated')
        ColorizeImageBase.__init__(self, Xd)

        # normalisation constants read by the base class
        self.l_norm = 1.
        self.ab_norm = 1.
        self.l_mean = 50.
        self.ab_mean = 0.
        self.mask_mult = 110.

        self.pred_ab_layer = 'pred_ab'  # predicted ab layer

        # Load grid properties (path is relative to the working directory)
        self.pts_in_hull_path = './data/color_bins/pts_in_hull.npy'
        self.pts_in_hull = np.load(self.pts_in_hull_path)  # 313x2, in-gamut

    # ***** Net preparation *****
    def prep_net(self, gpu_id, prototxt_path='', caffemodel_path=''):
        ''' Create the caffe.Net in TEST mode and initialise fixed weights.

        INPUTS
            gpu_id              -1 selects CPU mode, anything else is the GPU device id
            prototxt_path       network definition
            caffemodel_path     trained weights '''
        import caffe

        print('gpu_id = %d, net_path = %s, model_path = %s' % (gpu_id, prototxt_path, caffemodel_path))
        if gpu_id == -1:
            caffe.set_mode_cpu()
        else:
            caffe.set_device(gpu_id)
            caffe.set_mode_gpu()
        self.gpu_id = gpu_id
        self.net = caffe.Net(prototxt_path, caffemodel_path, caffe.TEST)
        self.net_set = True

        # automatically set cluster centers
        pred_shape = self.net.params[self.pred_ab_layer][0].data[...].shape
        if len(pred_shape) == 4 and pred_shape[1] == 313:
            print('Setting ab cluster centers in layer: %s' % self.pred_ab_layer)
            self.net.params[self.pred_ab_layer][0].data[:, :, 0, 0] = self.pts_in_hull.T

        # automatically set upsampling kernel on every layer named '*_us'
        us_kernel = np.array(((.25, .5, .25, 0), (.5, 1., .5, 0), (.25, .5, .25, 0), (0, 0, 0, 0)))[np.newaxis, :, :]
        for layer in self.net._layer_names:
            if layer[-3:] != '_us':
                continue
            print('Setting upsampling layer kernel: %s' % layer)
            self.net.params[layer][0].data[:, 0, :, :] = us_kernel

    # ***** Call forward *****
    def net_forward(self, input_ab, input_mask):
        ''' INPUTS
            input_ab    2xXxX    input color patches (non-normalized)
            input_mask  1xXxX    input mask, indicating which points have been provided
        OUTPUTS
            returned value is -1 if image or net is missing, otherwise
            self.output_rgb
        assumes self.img_l_mc has been set '''
        status = ColorizeImageBase.net_forward(self, input_ab, input_mask)
        if status == -1:
            return -1

        # stack L, hint ab and scaled mask along the channel axis
        stacked_input = np.concatenate((self.img_l_mc, self.input_ab_mc, self.input_mask_mult), axis=0)
        self.net.blobs['data_l_ab_mask'].data[...] = stacked_input
        self.net.forward()

        # return prediction
        predicted_ab = self.net.blobs[self.pred_ab_layer].data[0, :, :, :]
        self.output_rgb = lab2rgb_transpose(self.img_l, predicted_ab)
        self._set_out_ab_()
        return self.output_rgb

    def get_img_forward(self):
        # get image with point estimate
        return ColorizeImageBase.get_img_forward(self)

    def get_img_gray(self):
        # Get black and white image
        return ColorizeImageBase.get_img_gray(self)
class ColorizeImageCaffeGlobDist(ColorizeImageCaffe):
    # Caffe colorization, with additional global histogram as input

    def __init__(self, Xd=256):
        ColorizeImageCaffe.__init__(self, Xd)
        self.glob_mask_mult = 1.
        self.glob_layer = 'glob_ab_313_mask'  # blob receiving histogram + flag

    def net_forward(self, input_ab, input_mask, glob_dist=-1):
        ''' Run the network, optionally conditioned on a global histogram.

        INPUTS
            input_ab    2xXxX    input color patches (non-normalized)
            input_mask  1xXxX    input mask, indicating which points have been provided
            glob_dist   313 array, or -1 to run without a global histogram
        OUTPUTS
            returned value is -1 if image or net is missing, otherwise
            self.output_rgb '''
        if not (self.img_l_set and self.net_set):
            # self.net may not exist yet, so do not touch its blobs. The
            # parent prints the matching message and returns -1 without
            # changing any state.
            return ColorizeImageCaffe.net_forward(self, input_ab, input_mask)

        glob_blob = self.net.blobs[self.glob_layer]
        if np.array(glob_dist).flatten()[0] == -1:  # run without this, zero it out
            glob_blob.data[0, :-1, 0, 0] = 0.
            glob_blob.data[0, -1, 0, 0] = 0.
        else:  # run conditioned on global histogram
            glob_blob.data[0, :-1, 0, 0] = glob_dist
            glob_blob.data[0, -1, 0, 0] = self.glob_mask_mult  # last channel is the "histogram given" flag

        # The parent stores output_rgb and refreshes output_ab itself, so
        # its return value can be handed back unchanged.
        return ColorizeImageCaffe.net_forward(self, input_ab, input_mask)
class ColorizeImageCaffeDist(ColorizeImageCaffe):
    # caffe model which includes distribution prediction

    def __init__(self, Xd=256):
        ColorizeImageCaffe.__init__(self, Xd)
        self.dist_ab_set = False
        self.scale_S_layer = 'scale_S'
        self.dist_ab_S_layer = 'dist_ab_S'  # softened distribution layer
        # grid files are resolved relative to the working directory
        self.pts_grid = np.load('./data/color_bins/pts_grid.npy')  # 529x2, all points
        self.in_hull = np.load('./data/color_bins/in_hull.npy')  # 529 bool
        self.AB = self.pts_grid.shape[0]  # 529
        self.A = int(np.sqrt(self.AB))  # 23
        self.B = int(np.sqrt(self.AB))  # 23
        self.dist_ab_full = np.zeros((self.AB, self.Xd, self.Xd))
        self.dist_ab_grid = np.zeros((self.A, self.B, self.Xd, self.Xd))
        self.dist_entropy = np.zeros((self.Xd, self.Xd))

    def prep_net(self, gpu_id, prototxt_path='', caffemodel_path='', S=.2):
        ''' Prepare the net as ColorizeImageCaffe does, then write S into the
        parameters of the 'scale_S' layer (and remember it as self.S). '''
        ColorizeImageCaffe.prep_net(self, gpu_id, prototxt_path=prototxt_path, caffemodel_path=caffemodel_path)
        self.S = S
        self.net.params[self.scale_S_layer][0].data[...] = S

    def net_forward(self, input_ab, input_mask):
        ''' Run the network and additionally store the ab distribution.

        INPUTS
            input_ab    2xXxX    input color patches (non-normalized)
            input_mask  1xXxX    input mask, indicating which points have been provided
        OUTPUTS
            returned value is -1 on error, otherwise the parent's result
            (self.output_rgb)
        assumes self.img_l_mc has been set '''
        function_return = ColorizeImageCaffe.net_forward(self, input_ab, input_mask)
        if np.array(function_return).flatten()[0] == -1:  # errored out
            return -1

        # set distribution
        # in-gamut, CxXxX, C = 313
        self.dist_ab = self.net.blobs[self.dist_ab_S_layer].data[0, :, :, :]
        self.dist_ab_set = True

        # full grid, ABxXxX, AB = 529 (only the in-hull rows are written)
        self.dist_ab_full[self.in_hull, :, :] = self.dist_ab

        # gridded, AxBxXxX, A = 23
        self.dist_ab_grid = self.dist_ab_full.reshape((self.A, self.B, self.Xd, self.Xd))

        # return
        return function_return

    # def get_ab_reccs(self, h, w, K=5, N=25000, return_conf=False):
    #     ''' Recommended colors at point (h,w)
    #     Call this after calling net_forward
    #     '''
    #     if not self.dist_ab_set:
    #         print('Need to set prediction first')
    #         return 0
    #
    #     # randomly sample from pdf
    #     cmf = np.cumsum(self.dist_ab[:, h, w])  # CMF
    #     cmf = cmf / cmf[-1]
    #     cmf_bins = cmf
    #
    #     # randomly sample N points
    #     rnd_pts = np.random.uniform(low=0, high=1.0, size=N)
    #     inds = np.digitize(rnd_pts, bins=cmf_bins)
    #     rnd_pts_ab = self.pts_in_hull[inds, :]
    #
    #     # run k-means
    #     kmeans = KMeans(n_clusters=K).fit(rnd_pts_ab)
    #
    #     # sort by cluster occupancy
    #     k_label_cnt = np.histogram(kmeans.labels_, np.arange(0, K + 1))[0]
    #     k_inds = np.argsort(k_label_cnt, axis=0)[::-1]
    #
    #     cluster_per = 1. * k_label_cnt[k_inds] / N  # percentage of points within cluster
    #     cluster_centers = kmeans.cluster_centers_[k_inds, :]  # cluster centers
    #
    #     # cluster_centers = np.random.uniform(low=-100,high=100,size=(N,2))
    #     if return_conf:
    #         return cluster_centers, cluster_per
    #     else:
    #         return cluster_centers

    def compute_entropy(self):
        ''' Store sum_c p_c * log(p_c) per pixel in self.dist_entropy.

        This is the negative of the Shannon entropy; the sign is kept as it
        always was. Requires net_forward() to have set self.dist_ab. '''
        probs = np.asarray(self.dist_ab)
        # 0 * log(0) must count as 0. Evaluating it directly yields NaN and
        # poisons the whole pixel, so feed log() a 1 wherever p is not
        # positive (log(1) == 0, the term then vanishes).
        log_arg = np.where(probs > 0, probs, 1.)
        self.dist_entropy = np.sum(probs * np.log(log_arg), axis=0)
# def plot_dist_grid(self, h, w):
# Plots distribution at a given point
# plt.figure()
# plt.imshow(self.dist_ab_grid[:, :, h, w], extent=[-110, 110, 110, -110], interpolation='nearest')
# plt.colorbar()
# plt.ylabel('a')
# plt.xlabel('b')
# def plot_dist_entropy(self):
# Plots distribution at a given point
# plt.figure()
# plt.imshow(-self.dist_entropy, interpolation='nearest')
# plt.colorbar()