# Mirror of https://github.com/kritiksoman/GIMP-ML
# Original file: 566 lines, 22 KiB, Python.
import numpy as np
|
|
# import matplotlib.pyplot as plt
|
|
# from skimage import color
|
|
# from sklearn.cluster import KMeans
|
|
import os
|
|
import cv2
|
|
from scipy.ndimage.interpolation import zoom
|
|
|
|
def create_temp_directory(path_template, N=1e8):
    """Create a new, uniquely named directory and return its path.

    The name is produced by ``path_template % k`` for a random integer
    ``k`` in ``[0, N)``, so ``path_template`` must contain exactly one
    %-style integer placeholder (e.g. ``'/tmp/run_%08d'``).

    Args:
        path_template: path pattern with a single %-style placeholder.
        N: exclusive upper bound for the random suffix; converted to ``int``
            (the default is the float ``1e8``).

    Returns:
        The path of the directory that was created.

    Raises:
        OSError: (other than ``FileExistsError``) e.g. when the parent
            directory does not exist.
    """
    print(path_template)
    upper = int(N)
    while True:
        cur_path = path_template % np.random.randint(0, upper)
        try:
            # Create directly instead of "check then create": a concurrent
            # caller could grab the same name between the two steps.
            os.mkdir(cur_path)
        except FileExistsError:
            continue  # name already taken; draw another suffix
        print('Creating directory: %s' % cur_path)
        return cur_path
def lab2rgb_transpose(img_l, img_ab):
    """Turn channel-first L and ab planes into a channel-last uint8 RGB image.

    INPUTS
        img_l   1xXxX, L channel [0,100]
        img_ab  2xXxX, ab channels [-100,100]
    OUTPUTS
        returned value is XxXx3, uint8 (float RGB is clipped to [0, 1],
        scaled by 255 and truncated)
    """
    stacked = np.concatenate([img_l, img_ab], axis=0)
    lab_hwc = np.transpose(stacked, (1, 2, 0)).astype('float32')
    rgb_float = cv2.cvtColor(lab_hwc, cv2.COLOR_LAB2RGB)
    return (255 * np.clip(rgb_float, 0, 1)).astype('uint8')
def rgb2lab_transpose(img_rgb):
    """Convert a channel-last RGB image into a channel-first Lab array.

    INPUTS
        img_rgb  XxXx3 (divided by 255 before conversion, so presumably
                 uint8 in [0, 255] -- confirm with callers)
    OUTPUTS
        returned value is 3xXxX (L, a, b planes)
    """
    scaled = img_rgb.astype(np.float32) / 255
    lab_hwc = cv2.cvtColor(scaled, cv2.COLOR_RGB2LAB)
    return np.transpose(lab_hwc, (2, 0, 1))
class ColorizeImageBase():
    """Shared state and Lab-space helpers for the colorization back ends.

    Subclasses are expected to:
      * define the normalisation constants ``l_norm``, ``ab_norm``,
        ``l_mean``, ``ab_mean`` and ``mask_mult`` (read by ``net_forward``
        and ``_set_img_lab_mc_``);
      * implement ``prep_net`` and set ``self.net_set = True`` there;
      * set ``self.output_rgb`` in their own ``net_forward``.
    """

    def __init__(self, Xd=256, Xfullres_max=10000):
        self.Xd = Xd  # side length of the square working-resolution image
        self.img_l_set = False
        self.net_set = False
        self.Xfullres_max = Xfullres_max  # maximum size of maximum dimension
        self.img_just_set = False  # this will be true whenever image is just loaded
        # net_forward can set this to False if they want

    def prep_net(self):
        """Build/load the network; every back end must override this."""
        raise NotImplementedError("Should be implemented by subclass")

    # ***** Image prepping *****
    def load_image(self, im):
        """Load an RGB image, keeping a full-resolution copy and an XdxXd copy.

        ``im`` is an HxWx3 RGB array. Its values are divided by 255 before the
        Lab conversion, so it is presumably uint8 in [0, 255] (confirm with
        callers).
        """
        self.img_rgb_fullres = im.copy()
        self._set_img_lab_fullres_()

        # Working-resolution copy, XdxXdx3 (cv2.resize takes (width, height)).
        im = cv2.resize(im, (self.Xd, self.Xd))
        self.img_rgb = im.copy()
        self.img_l_set = True

        # convert into lab space
        self._set_img_lab_()
        self._set_img_lab_mc_()

    def set_image(self, input_image):
        """Like ``load_image`` but without resizing to XdxXd.

        ``self.img_rgb`` references ``input_image`` directly (no copy), so the
        working image keeps the input's resolution.
        """
        self.img_rgb_fullres = input_image.copy()
        self._set_img_lab_fullres_()
        self.img_l_set = True

        self.img_rgb = input_image
        # convert into lab space
        self._set_img_lab_()
        self._set_img_lab_mc_()

    def net_forward(self, input_ab, input_mask):
        """Validate state and store the normalised user hints.

        INPUTS
            input_ab    2xXxX input color patches (non-normalized)
            input_mask  1xXxX input mask, indicating which points have been provided
        Assumes self.img_l_mc has been set.

        Returns -1 (after printing a message) when no image or no net is set,
        otherwise 0. Subclasses run the actual network afterwards.
        """
        if not self.img_l_set:
            print('I need to have an image!')
            return -1
        if not self.net_set:
            print('I need to have a net!')
            return -1

        self.input_ab = input_ab
        self.input_ab_mc = (input_ab - self.ab_mean) / self.ab_norm
        self.input_mask = input_mask
        self.input_mask_mult = input_mask * self.mask_mult
        return 0

    def get_result_PSNR(self, result=-1, return_SE_map=False):
        """PSNR in dB (peak value 255) of ``result`` against ``self.img_rgb``.

        ``result`` defaults to the sentinel -1, meaning "use the current network
        output" (``get_img_forward()``). Identical images yield ``inf``. When
        ``return_SE_map`` is true, returns ``(psnr, squared_error_map)``.
        """
        if np.array(result).flatten()[0] == -1:
            cur_result = self.get_img_forward()
        else:
            cur_result = result.copy()
        # 1. * promotes to float so uint8 subtraction cannot wrap around.
        SE_map = (1. * self.img_rgb - cur_result) ** 2
        cur_MSE = np.mean(SE_map)
        if cur_MSE == 0:
            cur_PSNR = np.inf  # perfect reconstruction; avoid dividing by zero
        else:
            cur_PSNR = 20 * np.log10(255. / np.sqrt(cur_MSE))
        if return_SE_map:
            return (cur_PSNR, SE_map)
        return cur_PSNR

    def get_img_forward(self):
        # get image with point estimate
        return self.output_rgb

    def get_img_gray(self):
        """Working-resolution grayscale image (L channel with zero ab)."""
        # Size the zero ab planes from img_l itself: set_image() does not
        # resize, so img_l is not necessarily XdxXd.
        return lab2rgb_transpose(self.img_l, np.zeros((2,) + self.img_l.shape[1:]))

    def get_img_gray_fullres(self):
        """Full-resolution grayscale image (L channel with zero ab)."""
        return lab2rgb_transpose(self.img_l_fullres, np.zeros((2,) + self.img_l_fullres.shape[1:]))

    def _fullres_zoom_factor(self, low_res):
        """Per-axis zoom factors mapping a CxHxW array onto img_l_fullres's grid."""
        return (1,
                1. * self.img_l_fullres.shape[1] / low_res.shape[1],
                1. * self.img_l_fullres.shape[2] / low_res.shape[2])

    def get_img_fullres(self):
        """Network result at full resolution.

        This assumes self.img_l_fullres and self.output_ab are set, which
        typically means set_image()/load_image() and net_forward() have been
        called. The predicted ab is upsampled bilinearly (order=1) and combined
        with the full-resolution L channel.
        """
        output_ab_fullres = zoom(self.output_ab, self._fullres_zoom_factor(self.output_ab), order=1)
        return lab2rgb_transpose(self.img_l_fullres, output_ab_fullres)

    def get_input_img_fullres(self):
        """User ab hints upsampled (order=1) onto the full-resolution L channel."""
        input_ab_fullres = zoom(self.input_ab, self._fullres_zoom_factor(self.input_ab), order=1)
        return lab2rgb_transpose(self.img_l_fullres, input_ab_fullres)

    def get_input_img(self):
        """User ab hints on the working-resolution L channel."""
        return lab2rgb_transpose(self.img_l, self.input_ab)

    def get_img_mask(self):
        """Mask visualisation: L = 100 * (1 - mask), so mask==1 is black, mask==0 white."""
        return lab2rgb_transpose(100. * (1 - self.input_mask), np.zeros((2,) + self.input_mask.shape[1:]))

    def get_img_mask_fullres(self):
        """Full-resolution mask visualisation (nearest-neighbour upsampling)."""
        input_mask_fullres = zoom(self.input_mask, self._fullres_zoom_factor(self.input_mask), order=0)
        return lab2rgb_transpose(100. * (1 - input_mask_fullres), np.zeros((2,) + input_mask_fullres.shape[1:]))

    def get_sup_img(self):
        """Hint visualisation: L = 50 where mask==1 (0 elsewhere), ab = user hints."""
        return lab2rgb_transpose(50 * self.input_mask, self.input_ab)

    def get_sup_fullres(self):
        """Full-resolution hint visualisation (nearest-neighbour upsampling)."""
        # The zoom factor comes from the hint arrays being zoomed, so this works
        # without any network output being present.
        zoom_factor = self._fullres_zoom_factor(self.input_ab)
        input_mask_fullres = zoom(self.input_mask, zoom_factor, order=0)
        input_ab_fullres = zoom(self.input_ab, zoom_factor, order=0)
        return lab2rgb_transpose(50 * input_mask_fullres, input_ab_fullres)

    # ***** Private functions *****
    def _set_img_lab_fullres_(self):
        """Cap the full-res image at Xfullres_max and derive its Lab planes."""
        # Shrink (keeping aspect ratio) so the largest of the first two
        # dimensions is at most Xfullres_max.
        max_dim = max(self.img_rgb_fullres.shape[0], self.img_rgb_fullres.shape[1])
        if max_dim > self.Xfullres_max:
            zoom_factor = 1. * self.Xfullres_max / max_dim
            self.img_rgb_fullres = zoom(self.img_rgb_fullres, (zoom_factor, zoom_factor, 1), order=1)

        self.img_lab_fullres = cv2.cvtColor(self.img_rgb_fullres.astype(np.float32) / 255, cv2.COLOR_RGB2LAB).transpose((2, 0, 1))
        self.img_l_fullres = self.img_lab_fullres[[0], :, :]  # 1xHxW
        self.img_ab_fullres = self.img_lab_fullres[1:, :, :]  # 2xHxW

    def _set_img_lab_(self):
        """Set self.img_lab (3xHxW) and its L / ab slices from self.img_rgb."""
        self.img_lab = cv2.cvtColor(self.img_rgb.astype(np.float32) / 255, cv2.COLOR_RGB2LAB).transpose((2, 0, 1))
        self.img_l = self.img_lab[[0], :, :]
        self.img_ab = self.img_lab[1:, :, :]

    def _set_img_lab_mc_(self):
        """Set self.img_lab_mc = (img_lab - mean) / norm per channel, then img_l_mc."""
        self.img_lab_mc = self.img_lab / np.array((self.l_norm, self.ab_norm, self.ab_norm))[:, np.newaxis, np.newaxis] - np.array(
            (self.l_mean / self.l_norm, self.ab_mean / self.ab_norm, self.ab_mean / self.ab_norm))[:, np.newaxis, np.newaxis]
        self._set_img_l_()

    def _set_img_l_(self):
        self.img_l_mc = self.img_lab_mc[[0], :, :]
        self.img_l_set = True

    def _set_img_ab_(self):
        self.img_ab_mc = self.img_lab_mc[[1, 2], :, :]

    def _set_out_ab_(self):
        """Derive output_lab / output_ab from the current output_rgb."""
        self.output_lab = rgb2lab_transpose(self.output_rgb)
        self.output_ab = self.output_lab[1:, :, :]
class ColorizeImageTorch(ColorizeImageBase):
    """Colorizer backed by ``models.pytorch.model.SIGGRAPHGenerator``."""

    def __init__(self, Xd=256, maskcent=False):
        print('ColorizeImageTorch instantiated')
        ColorizeImageBase.__init__(self, Xd)
        self.l_norm = 1.
        self.ab_norm = 1.
        self.l_mean = 50.
        self.ab_mean = 0.
        self.mask_mult = 1.
        # Passed to the net's forward() as its mask_cent argument.
        self.mask_cent = .5 if maskcent else 0

        # Load grid properties: all 23x23 = 529 (a, b) points of a 10-unit grid
        # spanning [-110, 110] on both axes.
        self.pts_in_hull = np.array(np.meshgrid(np.arange(-110, 120, 10), np.arange(-110, 120, 10))).reshape((2, 529)).T

    # ***** Net preparation *****
    def prep_net(self, gpu_id=None, path='', dist=False):
        """Build the generator and load weights from ``path``.

        Args:
            gpu_id: ``None`` keeps the net on the CPU; any other value moves it
                to the (default) CUDA device.
            path: checkpoint file for ``torch.load``.
            dist: forwarded to ``SIGGRAPHGenerator(dist=...)``.
        """
        import torch
        import models.pytorch.model as model
        print('path = %s' % path)
        print('Model set! dist mode? ', dist)
        self.net = model.SIGGRAPHGenerator(dist=dist)
        # Map tensors to the CPU so checkpoints saved from a GPU also load on
        # CPU-only machines; the net is moved to the GPU below if requested.
        state_dict = torch.load(path, map_location='cpu')
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata

        # patch InstanceNorm checkpoints prior to 0.4
        for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
            self.__patch_instance_norm_state_dict(state_dict, self.net, key.split('.'))
        self.net.load_state_dict(state_dict)
        if gpu_id is not None:
            self.net.cuda()
        self.net.eval()
        self.net_set = True

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Drop InstanceNorm entries that the current module cannot accept.

        Walks ``module`` along the dotted ``keys`` path. At the leaf, the
        method removes ``running_mean``/``running_var`` when the module's
        attribute is None, and always removes ``num_batches_tracked``, for
        InstanceNorm modules.
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # ***** Call forward *****
    def net_forward(self, input_ab, input_mask, f):
        """Run the generator and return the RGB result (or -1 on failure).

        INPUTS
            input_ab    2xXxX input color patches (non-normalized)
            input_mask  1xXxX input mask, indicating which points have been provided
            f           forwarded verbatim as the last argument of
                        self.net.forward (meaning defined by the model -- TODO document)
        Assumes self.img_l_mc has been set.
        """
        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        # First batch element of the predicted ab, as a numpy array.
        output_ab = self.net.forward(self.img_l_mc, self.input_ab_mc, self.input_mask_mult, self.mask_cent, f)[0, :, :, :].cpu().data.numpy()
        self.output_rgb = lab2rgb_transpose(self.img_l, output_ab)

        self._set_out_ab_()
        return self.output_rgb
class ColorizeImageTorchDist(ColorizeImageTorch):
    """Torch colorizer that also stores the net's per-pixel ab distribution."""

    def __init__(self, Xd=256, maskcent=False):
        # The parent derives self.mask_cent from maskcent.
        ColorizeImageTorch.__init__(self, Xd, maskcent=maskcent)
        self.dist_ab_set = False
        # 23x23 = 529 (a, b) grid points spanning [-110, 110]; all treated as in-hull.
        self.pts_grid = np.array(np.meshgrid(np.arange(-110, 120, 10), np.arange(-110, 120, 10))).reshape((2, 529)).T
        self.in_hull = np.ones(529, dtype=bool)
        self.AB = self.pts_grid.shape[0]  # 529
        self.A = int(np.sqrt(self.AB))  # 23
        self.B = int(np.sqrt(self.AB))  # 23
        self.dist_ab_full = np.zeros((self.AB, self.Xd, self.Xd))
        self.dist_ab_grid = np.zeros((self.A, self.B, self.Xd, self.Xd))
        self.dist_entropy = np.zeros((self.Xd, self.Xd))

    def prep_net(self, gpu_id=None, path='', dist=True, S=.2):
        """Load the generator in distribution mode (``dist=True`` by default)."""
        ColorizeImageTorch.prep_net(self, gpu_id=gpu_id, path=path, dist=dist)
        # set S somehow  (NOTE: S is currently unused)

    def net_forward(self, input_ab, input_mask):
        """Run the net, store its distribution output, and return the ab result.

        INPUTS
            input_ab    2xXxX input color patches (non-normalized)
            input_mask  1xXxX input mask, indicating which points have been provided
        Assumes self.img_l_mc has been set. Returns -1 on failure; otherwise
        the first batch element of the net's first output, as numpy.
        """
        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        # set distribution
        (function_return, self.dist_ab) = self.net.forward(self.img_l_mc, self.input_ab_mc, self.input_mask_mult, self.mask_cent)
        function_return = function_return[0, :, :, :].cpu().data.numpy()
        self.dist_ab = self.dist_ab[0, :, :, :].cpu().data.numpy()
        self.dist_ab_set = True

        # full grid, ABxXxX, AB = 529
        self.dist_ab_full[self.in_hull, :, :] = self.dist_ab

        # gridded, AxBxXxX, A = 23
        self.dist_ab_grid = self.dist_ab_full.reshape((self.A, self.B, self.Xd, self.Xd))

        return function_return

    # def get_ab_reccs(self, h, w, K=5, N=25000, return_conf=False):
    #     ''' Recommended colors at point (h,w)
    #     Call this after calling net_forward
    #     '''
    #     if not self.dist_ab_set:
    #         print('Need to set prediction first')
    #         return 0
    #
    #     # randomly sample from pdf
    #     cmf = np.cumsum(self.dist_ab[:, h, w]) # CMF
    #     cmf = cmf / cmf[-1]
    #     cmf_bins = cmf
    #
    #     # randomly sample N points
    #     rnd_pts = np.random.uniform(low=0, high=1.0, size=N)
    #     inds = np.digitize(rnd_pts, bins=cmf_bins)
    #     rnd_pts_ab = self.pts_in_hull[inds, :]
    #
    #     # run k-means
    #     kmeans = KMeans(n_clusters=K).fit(rnd_pts_ab)
    #
    #     # sort by cluster occupancy
    #     k_label_cnt = np.histogram(kmeans.labels_, np.arange(0, K + 1))[0]
    #     k_inds = np.argsort(k_label_cnt, axis=0)[::-1]
    #
    #     cluster_per = 1. * k_label_cnt[k_inds] / N # percentage of points within cluster
    #     cluster_centers = kmeans.cluster_centers_[k_inds, :] # cluster centers
    #
    #     # cluster_centers = np.random.uniform(low=-100,high=100,size=(N,2))
    #     if return_conf:
    #         return cluster_centers, cluster_per
    #     else:
    #         return cluster_centers

    def compute_entropy(self):
        """Store sum(p * log p) over the first axis of self.dist_ab.

        NOTE: this is the *negative* entropy (it is <= 0 when dist_ab holds
        probabilities, which is assumed here). Zero-probability bins contribute
        0 (the limit of p log p) instead of NaN.
        """
        p = self.dist_ab
        positive = p > 0
        safe_log = np.log(np.where(positive, p, 1.))
        self.dist_entropy = np.sum(np.where(positive, p * safe_log, 0.), axis=0)

    # def plot_dist_grid(self, h, w):
    #     # Plots distribution at a given point
    #     plt.figure()
    #     plt.imshow(self.dist_ab_grid[:, :, h, w], extent=[-110, 110, 110, -110], interpolation='nearest')
    #     plt.colorbar()
    #     plt.ylabel('a')
    #     plt.xlabel('b')

    # def plot_dist_entropy(self):
    #     # Plots distribution at a given point
    #     plt.figure()
    #     plt.imshow(-self.dist_entropy, interpolation='nearest')
    #     plt.colorbar()
class ColorizeImageCaffe(ColorizeImageBase):
    """Colorizer backed by a Caffe network taking a stacked L/ab/mask input."""

    def __init__(self, Xd=256):
        print('ColorizeImageCaffe instantiated')
        ColorizeImageBase.__init__(self, Xd)
        self.l_norm = 1.
        self.ab_norm = 1.
        self.l_mean = 50.
        self.ab_mean = 0.
        self.mask_mult = 110.

        self.pred_ab_layer = 'pred_ab'  # predicted ab layer

        # Load grid properties (path is relative to the current working directory)
        self.pts_in_hull_path = './data/color_bins/pts_in_hull.npy'
        self.pts_in_hull = np.load(self.pts_in_hull_path)  # 313x2, in-gamut

    # ***** Net preparation *****
    def prep_net(self, gpu_id, prototxt_path='', caffemodel_path=''):
        """Load the Caffe net and initialise cluster-centre / upsampling weights.

        Args:
            gpu_id: -1 selects CPU mode; otherwise the Caffe GPU device id.
            prototxt_path: network definition file.
            caffemodel_path: trained weights file.
        """
        import caffe
        print('gpu_id = %d, net_path = %s, model_path = %s' % (gpu_id, prototxt_path, caffemodel_path))
        if gpu_id == -1:
            caffe.set_mode_cpu()
        else:
            caffe.set_device(gpu_id)
            caffe.set_mode_gpu()
        self.gpu_id = gpu_id
        self.net = caffe.Net(prototxt_path, caffemodel_path, caffe.TEST)
        self.net_set = True

        # automatically set cluster centers: when pred_ab's weights are 4-D
        # with 313 input channels, load the 313 in-gamut ab points into them.
        # The weights are written in place through the numpy view .data returns.
        pred_ab_weights = self.net.params[self.pred_ab_layer][0].data
        if pred_ab_weights.ndim == 4 and pred_ab_weights.shape[1] == 313:
            print('Setting ab cluster centers in layer: %s' % self.pred_ab_layer)
            pred_ab_weights[:, :, 0, 0] = self.pts_in_hull.T

        # automatically set upsampling kernel on every layer named '*_us'
        # (looks like a 2x bilinear interpolation kernel -- confirm vs prototxt)
        us_kernel = np.array(((.25, .5, .25, 0), (.5, 1., .5, 0), (.25, .5, .25, 0), (0, 0, 0, 0)))[np.newaxis, :, :]
        for layer in self.net._layer_names:
            if layer.endswith('_us'):
                print('Setting upsampling layer kernel: %s' % layer)
                self.net.params[layer][0].data[:, 0, :, :] = us_kernel

    # ***** Call forward *****
    def net_forward(self, input_ab, input_mask):
        """Run the Caffe net and return the RGB result (or -1 on failure).

        INPUTS
            input_ab    2xXxX input color patches (non-normalized)
            input_mask  1xXxX input mask, indicating which points have been provided
        Assumes self.img_l_mc has been set.
        """
        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        # Stack L, ab and mask along the channel axis into the input blob.
        net_input_prepped = np.concatenate((self.img_l_mc, self.input_ab_mc, self.input_mask_mult), axis=0)

        self.net.blobs['data_l_ab_mask'].data[...] = net_input_prepped
        self.net.forward()

        # return prediction
        self.output_rgb = lab2rgb_transpose(self.img_l, self.net.blobs[self.pred_ab_layer].data[0, :, :, :])

        self._set_out_ab_()
        return self.output_rgb
class ColorizeImageCaffeGlobDist(ColorizeImageCaffe):
    # Caffe colorization, with additional global histogram as input
    def __init__(self, Xd=256):
        ColorizeImageCaffe.__init__(self, Xd)
        self.glob_mask_mult = 1.
        self.glob_layer = 'glob_ab_313_mask'

    def net_forward(self, input_ab, input_mask, glob_dist=-1):
        """Run the net, optionally conditioned on a global ab histogram.

        Args:
            input_ab, input_mask: as in ``ColorizeImageCaffe.net_forward``.
            glob_dist: 313-entry histogram, or the sentinel -1 to run without
                one (the global blob, including its mask channel, is zeroed).

        Returns:
            The RGB result, or -1 when no image or net is set.
        """
        # Only touch the net's blobs when a net exists; otherwise the parent
        # reports the problem and returns -1.
        if self.net_set:
            glob_blob = self.net.blobs[self.glob_layer].data
            if np.array(glob_dist).flatten()[0] == -1:  # run without this, zero it out
                glob_blob[0, :, 0, 0] = 0.
            else:  # run conditioned on global histogram
                glob_blob[0, :-1, 0, 0] = glob_dist
                glob_blob[0, -1, 0, 0] = self.glob_mask_mult

        # The parent sets self.output_rgb and refreshes output_ab itself.
        result = ColorizeImageCaffe.net_forward(self, input_ab, input_mask)
        if np.array(result).flatten()[0] == -1:  # errored out
            return -1
        return self.output_rgb
class ColorizeImageCaffeDist(ColorizeImageCaffe):
    # caffe model which includes distribution prediction
    def __init__(self, Xd=256):
        ColorizeImageCaffe.__init__(self, Xd)
        self.dist_ab_set = False
        self.scale_S_layer = 'scale_S'
        self.dist_ab_S_layer = 'dist_ab_S'  # softened distribution layer
        # Paths are relative to the current working directory.
        self.pts_grid = np.load('./data/color_bins/pts_grid.npy')  # 529x2, all points
        self.in_hull = np.load('./data/color_bins/in_hull.npy')  # 529 bool
        self.AB = self.pts_grid.shape[0]  # 529
        self.A = int(np.sqrt(self.AB))  # 23
        self.B = int(np.sqrt(self.AB))  # 23
        self.dist_ab_full = np.zeros((self.AB, self.Xd, self.Xd))
        self.dist_ab_grid = np.zeros((self.A, self.B, self.Xd, self.Xd))
        self.dist_entropy = np.zeros((self.Xd, self.Xd))

    def prep_net(self, gpu_id, prototxt_path='', caffemodel_path='', S=.2):
        """Load the net, then write ``S`` into the ``scale_S`` layer's weights."""
        ColorizeImageCaffe.prep_net(self, gpu_id, prototxt_path=prototxt_path, caffemodel_path=caffemodel_path)
        self.S = S
        self.net.params[self.scale_S_layer][0].data[...] = S

    def net_forward(self, input_ab, input_mask):
        """Run the net, store the softened ab distribution, return the RGB result.

        INPUTS
            input_ab    2xXxX input color patches (non-normalized)
            input_mask  1xXxX input mask, indicating which points have been provided
        Assumes self.img_l_mc has been set. Returns -1 on failure.
        """
        function_return = ColorizeImageCaffe.net_forward(self, input_ab, input_mask)
        if np.array(function_return).flatten()[0] == -1:  # errored out
            return -1

        # set distribution
        # in-gamut, CxXxX, C = 313. Copy: blob .data is a view into the net's
        # buffer and would be silently overwritten by the next forward pass.
        self.dist_ab = self.net.blobs[self.dist_ab_S_layer].data[0, :, :, :].copy()
        self.dist_ab_set = True

        # full grid, ABxXxX, AB = 529
        self.dist_ab_full[self.in_hull, :, :] = self.dist_ab

        # gridded, AxBxXxX, A = 23
        self.dist_ab_grid = self.dist_ab_full.reshape((self.A, self.B, self.Xd, self.Xd))

        return function_return

    # def get_ab_reccs(self, h, w, K=5, N=25000, return_conf=False):
    #     ''' Recommended colors at point (h,w)
    #     Call this after calling net_forward
    #     '''
    #     if not self.dist_ab_set:
    #         print('Need to set prediction first')
    #         return 0
    #
    #     # randomly sample from pdf
    #     cmf = np.cumsum(self.dist_ab[:, h, w]) # CMF
    #     cmf = cmf / cmf[-1]
    #     cmf_bins = cmf
    #
    #     # randomly sample N points
    #     rnd_pts = np.random.uniform(low=0, high=1.0, size=N)
    #     inds = np.digitize(rnd_pts, bins=cmf_bins)
    #     rnd_pts_ab = self.pts_in_hull[inds, :]
    #
    #     # run k-means
    #     kmeans = KMeans(n_clusters=K).fit(rnd_pts_ab)
    #
    #     # sort by cluster occupancy
    #     k_label_cnt = np.histogram(kmeans.labels_, np.arange(0, K + 1))[0]
    #     k_inds = np.argsort(k_label_cnt, axis=0)[::-1]
    #
    #     cluster_per = 1. * k_label_cnt[k_inds] / N # percentage of points within cluster
    #     cluster_centers = kmeans.cluster_centers_[k_inds, :] # cluster centers
    #
    #     # cluster_centers = np.random.uniform(low=-100,high=100,size=(N,2))
    #     if return_conf:
    #         return cluster_centers, cluster_per
    #     else:
    #         return cluster_centers

    def compute_entropy(self):
        """Store sum(p * log p) over the first axis of self.dist_ab.

        NOTE: this is the *negative* entropy (it is <= 0 when dist_ab holds
        probabilities, which is assumed here). Zero-probability bins contribute
        0 (the limit of p log p) instead of NaN.
        """
        p = self.dist_ab
        positive = p > 0
        safe_log = np.log(np.where(positive, p, 1.))
        self.dist_entropy = np.sum(np.where(positive, p * safe_log, 0.), axis=0)

    # def plot_dist_grid(self, h, w):
    #     # Plots distribution at a given point
    #     plt.figure()
    #     plt.imshow(self.dist_ab_grid[:, :, h, w], extent=[-110, 110, 110, -110], interpolation='nearest')
    #     plt.colorbar()
    #     plt.ylabel('a')
    #     plt.xlabel('b')

    # def plot_dist_entropy(self):
    #     # Plots distribution at a given point
    #     plt.figure()
    #     plt.imshow(-self.dist_entropy, interpolation='nearest')
    #     plt.colorbar()