mirror of
https://github.com/kritiksoman/GIMP-ML
synced 2024-10-31 09:20:18 +00:00
julyUpdate
This commit is contained in:
parent 17f43e3c4e
commit fdf3b77842
50
gimp-plugins/MiDaS/LICENSE
Normal file
@@ -0,0 +1,50 @@
MIT License

Copyright (c) 2020 Virginia Tech Vision and Learning Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

------------------ LICENSE FOR MiDaS --------------------

MIT License

Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------- LICENSE FOR EdgeConnect --------------------------------

Attribution-NonCommercial 4.0 International
192
gimp-plugins/MiDaS/MiDaS_utils.py
Normal file
@@ -0,0 +1,192 @@
"""Utils for monoDepth.
"""
import sys
import re
import numpy as np
import cv2
import torch
# import imageio


def read_pfm(path):
    """Read pfm file.

    Args:
        path (str): path to file

    Returns:
        tuple: (data, scale)
    """
    with open(path, "rb") as file:

        color = None
        width = None
        height = None
        scale = None
        endian = None

        header = file.readline().rstrip()
        if header.decode("ascii") == "PF":
            color = True
        elif header.decode("ascii") == "Pf":
            color = False
        else:
            raise Exception("Not a PFM file: " + path)

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
        if dim_match:
            width, height = list(map(int, dim_match.groups()))
        else:
            raise Exception("Malformed PFM header.")

        scale = float(file.readline().decode("ascii").rstrip())
        if scale < 0:
            # little-endian
            endian = "<"
            scale = -scale
        else:
            # big-endian
            endian = ">"

        data = np.fromfile(file, endian + "f")
        shape = (height, width, 3) if color else (height, width)

        data = np.reshape(data, shape)
        data = np.flipud(data)

        return data, scale


def write_pfm(path, image, scale=1):
    """Write pfm file.

    Args:
        path (str): path to file
        image (array): data
        scale (int, optional): Scale. Defaults to 1.
    """

    with open(path, "wb") as file:
        color = None

        if image.dtype.name != "float32":
            raise Exception("Image dtype must be float32.")

        image = np.flipud(image)

        if len(image.shape) == 3 and image.shape[2] == 3:  # color image
            color = True
        elif (
            len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
        ):  # greyscale
            color = False
        else:
            raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")

        file.write("PF\n" if color else "Pf\n".encode())
        file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))

        endian = image.dtype.byteorder

        if endian == "<" or endian == "=" and sys.byteorder == "little":
            scale = -scale

        file.write("%f\n".encode() % scale)

        image.tofile(file)


def read_image(path):
    """Read image and output RGB image (0-1).

    Args:
        path (str): path to file

    Returns:
        array: RGB image (0-1)
    """
    img = cv2.imread(path)

    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    return img


def resize_image(img):
    """Resize image and make it fit for network.

    Args:
        img (array): image

    Returns:
        tensor: data ready for network
    """
    height_orig = img.shape[0]
    width_orig = img.shape[1]
    unit_scale = 384.

    if width_orig > height_orig:
        scale = width_orig / unit_scale
    else:
        scale = height_orig / unit_scale

    height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
    width = (np.ceil(width_orig / scale / 32) * 32).astype(int)

    img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

    img_resized = (
        torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
    )
    img_resized = img_resized.unsqueeze(0)

    return img_resized


def resize_depth(depth, width, height):
    """Resize depth map and bring to CPU (numpy).

    Args:
        depth (tensor): depth
        width (int): image width
        height (int): image height

    Returns:
        array: processed depth
    """
    depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
    depth = cv2.blur(depth.numpy(), (3, 3))
    depth_resized = cv2.resize(
        depth, (width, height), interpolation=cv2.INTER_AREA
    )

    return depth_resized


def write_depth(path, depth, bits=1):
    """Write depth map to pfm and png file.

    Args:
        path (str): filepath without extension
        depth (array): depth
    """
    # write_pfm(path + ".pfm", depth.astype(np.float32))

    depth_min = depth.min()
    depth_max = depth.max()

    max_val = (2**(8*bits))-1

    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        out = 0

    if bits == 1:
        cv2.imwrite(path + ".png", out.astype("uint8"))
    elif bits == 2:
        cv2.imwrite(path + ".png", out.astype("uint16"))

    return
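A minimal sketch of how write_depth could be called outside GIMP to dump a prediction as a 16-bit PNG; the import path and file names are assumptions, not part of the commit.

import numpy as np
from MiDaS import MiDaS_utils  # assumed import path

depth = np.random.rand(192, 256).astype(np.float32)  # stand-in for a network prediction
MiDaS_utils.write_depth('depth_out', depth, bits=2)  # writes depth_out.png scaled to the uint16 range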
BIN
gimp-plugins/MiDaS/MiDaS_utils.pyc
Normal file
Binary file not shown.
0
gimp-plugins/MiDaS/__init__.py
Normal file
BIN
gimp-plugins/MiDaS/__init__.pyc
Normal file
Binary file not shown.
186
gimp-plugins/MiDaS/monodepth_net.py
Normal file
@@ -0,0 +1,186 @@
"""MonoDepthNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn
from torchvision import models


class MonoDepthNet(nn.Module):
    """Network for monocular depth estimation.
    """

    def __init__(self, path=None, features=256):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
        """
        super(MonoDepthNet, self).__init__()

        resnet = models.resnet50(pretrained=False)

        self.pretrained = nn.Module()
        self.scratch = nn.Module()
        self.pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                               resnet.maxpool, resnet.layer1)

        self.pretrained.layer2 = resnet.layer2
        self.pretrained.layer3 = resnet.layer3
        self.pretrained.layer4 = resnet.layer4

        # adjust channel number of feature maps
        self.scratch.layer1_rn = nn.Conv2d(256, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.scratch.layer2_rn = nn.Conv2d(512, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.scratch.layer3_rn = nn.Conv2d(1024, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.scratch.layer4_rn = nn.Conv2d(2048, features, kernel_size=3, stride=1, padding=1, bias=False)

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        # adaptive output module: 2 convolutions and upsampling
        self.scratch.output_conv = nn.Sequential(nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
                                                 nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1),
                                                 Interpolate(scale_factor=2, mode='bilinear'))

        # load model
        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return out

    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path)

        self.load_state_dict(parameters)


class Interpolate(nn.Module):
    """Interpolation module.
    """

    def __init__(self, scale_factor, mode):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False)

        return x


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(ResidualConvUnit, self).__init__()

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit(xs[1])

        output = self.resConfUnit(output)
        output = nn.functional.interpolate(output, scale_factor=2,
                                           mode='bilinear', align_corners=True)

        return output
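A small sanity-check sketch, assuming the module is importable as MiDaS.monodepth_net: with no checkpoint path the ResNet-50 backbone keeps random weights, and a 384x384 input should come back as a 1x1x384x384 depth map.

import torch
from MiDaS.monodepth_net import MonoDepthNet  # assumed import path

net = MonoDepthNet()  # path=None: no weights are loaded
net.eval()
with torch.no_grad():
    depth = net(torch.randn(1, 3, 384, 384))
print(depth.shape)  # expected: torch.Size([1, 1, 384, 384])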
BIN
gimp-plugins/MiDaS/monodepth_net.pyc
Normal file
Binary file not shown.
78
gimp-plugins/MiDaS/run.py
Normal file
@@ -0,0 +1,78 @@
"""Compute depth maps for images in the input folder.
"""
# import os
# import glob
import torch
# from monodepth_net import MonoDepthNet
# import utils
# import matplotlib.pyplot as plt
import numpy as np
import cv2
# import imageio


def run_depth(img, model_path, Net, utils, target_w=None):
    """Run MonoDepthNN to compute depth maps.

    Args:
        img (array): RGB image in [0, 1]
        model_path (str): path to saved model
        Net: network class (e.g. MonoDepthNet)
        utils: module providing resize_image/resize_depth
    """
    # print("initialize")

    # select device
    device = torch.device("cpu")
    # print("device: %s" % device)

    # load network
    model = Net(model_path)
    model.to(device)
    model.eval()

    # get input
    # img_names = glob.glob(os.path.join(input_path, "*"))
    # num_images = len(img_names)

    # create output folder
    # os.makedirs(output_path, exist_ok=True)

    # print("start processing")

    # for ind, img_name in enumerate(img_names):

    #     print(" processing {} ({}/{})".format(img_name, ind + 1, num_images))

    # input
    # img = utils.read_image(img_name)
    w = img.shape[1]
    scale = 640. / max(img.shape[0], img.shape[1])
    target_height, target_width = int(round(img.shape[0] * scale)), int(round(img.shape[1] * scale))
    img_input = utils.resize_image(img)
    # print(img_input.shape)
    img_input = img_input.to(device)
    # compute
    with torch.no_grad():
        out = model.forward(img_input)

    depth = utils.resize_depth(out, target_width, target_height)
    img = cv2.resize((img * 255).astype(np.uint8), (target_width, target_height), interpolation=cv2.INTER_AREA)

    # np.save(filename + '.npy', depth)
    # utils.write_depth(filename, depth, bits=2)
    depth_min = depth.min()
    depth_max = depth.max()
    bits = 1
    max_val = (2 ** (8 * bits)) - 1

    if depth_max - depth_min > np.finfo("float").eps:
        out = max_val * (depth - depth_min) / (depth_max - depth_min)
    else:
        out = 0
    out = out.astype("uint8")
    # cv2.imwrite("out.png", out)
    return out
    # print("finished")
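A hedged end-to-end sketch of how run_depth is meant to be wired together with MiDaS_utils and MonoDepthNet; the weight path 'MiDaS/model.pt', the input file name, and the import paths are assumptions.

import cv2
from MiDaS import MiDaS_utils
from MiDaS.monodepth_net import MonoDepthNet
from MiDaS.run import run_depth

img = MiDaS_utils.read_image('input.jpg')  # RGB, float in [0, 1]
depth = run_depth(img, 'MiDaS/model.pt', MonoDepthNet, MiDaS_utils)  # uint8 depth map
cv2.imwrite('depth.png', depth)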
BIN
gimp-plugins/MiDaS/run.pyc
Normal file
Binary file not shown.
BIN
gimp-plugins/color_palette.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 34 KiB |
@ -1,116 +0,0 @@
|
||||
import os
|
||||
baseLoc = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||
|
||||
|
||||
from gimpfu import *
|
||||
import sys
|
||||
|
||||
sys.path.extend([baseLoc+'gimpenv/lib/python2.7',baseLoc+'gimpenv/lib/python2.7/site-packages',baseLoc+'gimpenv/lib/python2.7/site-packages/setuptools',baseLoc+'neural-colorization'])
|
||||
|
||||
|
||||
import torch
|
||||
from model import generator
|
||||
from torch.autograd import Variable
|
||||
from scipy.ndimage import zoom
|
||||
from PIL import Image
|
||||
from argparse import Namespace
|
||||
import numpy as np
|
||||
# from skimage.color import rgb2yuv,yuv2rgb
|
||||
import cv2
|
||||
|
||||
def getcolor(input_image):
|
||||
p = np.repeat(input_image, 3, axis=2)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
g_available=1
|
||||
else:
|
||||
g_available=-1
|
||||
|
||||
args=Namespace(model=baseLoc+'neural-colorization/model.pth',gpu=g_available)
|
||||
|
||||
G = generator()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
G=G.cuda()
|
||||
G.load_state_dict(torch.load(args.model))
|
||||
else:
|
||||
G.load_state_dict(torch.load(args.model,map_location=torch.device('cpu')))
|
||||
|
||||
p = p.astype(np.float32)
|
||||
p = p / 255
|
||||
img_yuv = cv2.cvtColor(p, cv2.COLOR_RGB2YUV)
|
||||
# img_yuv = rgb2yuv(p)
|
||||
H,W,_ = img_yuv.shape
|
||||
infimg = np.expand_dims(np.expand_dims(img_yuv[...,0], axis=0), axis=0)
|
||||
img_variable = Variable(torch.Tensor(infimg-0.5))
|
||||
if args.gpu>=0:
|
||||
img_variable=img_variable.cuda(args.gpu)
|
||||
res = G(img_variable)
|
||||
uv=res.cpu().detach().numpy()
|
||||
uv[:,0,:,:] *= 0.436
|
||||
uv[:,1,:,:] *= 0.615
|
||||
(_,_,H1,W1) = uv.shape
|
||||
uv = zoom(uv,(1,1,float(H)/H1,float(W)/W1))
|
||||
yuv = np.concatenate([infimg,uv],axis=1)[0]
|
||||
# rgb=yuv2rgb(yuv.transpose(1,2,0))
|
||||
# out=(rgb.clip(min=0,max=1)*255)[:,:,[0,1,2]]
|
||||
rgb = cv2.cvtColor(yuv.transpose(1, 2, 0)*255, cv2.COLOR_YUV2RGB)
|
||||
rgb = rgb.clip(min=0,max=255)
|
||||
out = rgb.astype(np.uint8)
|
||||
|
||||
return out
|
||||
|
||||
def channelData(layer):#convert gimp image to numpy
|
||||
region=layer.get_pixel_rgn(0, 0, layer.width,layer.height)
|
||||
pixChars=region[:,:] # Take whole layer
|
||||
bpp=region.bpp
|
||||
# return np.frombuffer(pixChars,dtype=np.uint8).reshape(len(pixChars)/bpp,bpp)
|
||||
return np.frombuffer(pixChars,dtype=np.uint8).reshape(layer.height,layer.width,bpp)
|
||||
|
||||
def createResultLayer(image,name,result):
|
||||
rlBytes=np.uint8(result).tobytes();
|
||||
rl=gimp.Layer(image,name,image.width,image.height,image.active_layer.type,100,NORMAL_MODE)
|
||||
region=rl.get_pixel_rgn(0, 0, rl.width,rl.height,True)
|
||||
region[:,:]=rlBytes
|
||||
image.add_layer(rl,0)
|
||||
gimp.displays_flush()
|
||||
|
||||
def genNewImg(name,layer_np):
|
||||
h,w,d=layer_np.shape
|
||||
img=pdb.gimp_image_new(w, h, RGB)
|
||||
display=pdb.gimp_display_new(img)
|
||||
|
||||
rlBytes=np.uint8(layer_np).tobytes();
|
||||
rl=gimp.Layer(img,name,img.width,img.height,RGB,100,NORMAL_MODE)
|
||||
region=rl.get_pixel_rgn(0, 0, rl.width,rl.height,True)
|
||||
region[:,:]=rlBytes
|
||||
|
||||
pdb.gimp_image_insert_layer(img, rl, None, 0)
|
||||
gimp.displays_flush()
|
||||
|
||||
def colorize(img, layer) :
|
||||
gimp.progress_init("Coloring " + layer.name + "...")
|
||||
|
||||
imgmat = channelData(layer)
|
||||
cpy=getcolor(imgmat)
|
||||
|
||||
genNewImg(layer.name+'_colored',cpy)
|
||||
|
||||
|
||||
|
||||
register(
|
||||
"colorize",
|
||||
"colorize",
|
||||
"Generate monocular disparity map based on deep learning.",
|
||||
"Kritik Soman",
|
||||
"Your",
|
||||
"2020",
|
||||
"colorize...",
|
||||
"*", # Alternately use RGB, RGB*, GRAY*, INDEXED etc.
|
||||
[ (PF_IMAGE, "image", "Input image", None),
|
||||
(PF_DRAWABLE, "drawable", "Input drawable", None),
|
||||
],
|
||||
[],
|
||||
colorize, menu="<Image>/Layer/GIML-ML")
|
||||
|
||||
main()
|
60
gimp-plugins/colorpalette.py
Executable file
@@ -0,0 +1,60 @@
import os
baseLoc = os.path.dirname(os.path.realpath(__file__)) + '/'
from gimpfu import *
import sys
sys.path.extend([baseLoc + 'gimpenv/lib/python2.7', baseLoc + 'gimpenv/lib/python2.7/site-packages',
                 baseLoc + 'gimpenv/lib/python2.7/site-packages/setuptools'])

import cv2
import numpy as np

def channelData(layer):  # convert gimp image to numpy
    region = layer.get_pixel_rgn(0, 0, layer.width, layer.height)
    pixChars = region[:, :]  # Take whole layer
    bpp = region.bpp
    return np.frombuffer(pixChars, dtype=np.uint8).reshape(layer.height, layer.width, bpp)

def createResultLayer(image, name, result):
    rlBytes = np.uint8(result).tobytes();
    rl = gimp.Layer(image, name, image.width, image.height, image.active_layer.type, 100, NORMAL_MODE)
    region = rl.get_pixel_rgn(0, 0, rl.width, rl.height, True)
    region[:, :] = rlBytes
    image.add_layer(rl, 0)
    gimp.displays_flush()

def genNewImg(name, layer_np):
    h, w, d = layer_np.shape
    img = pdb.gimp_image_new(w, h, RGB)
    display = pdb.gimp_display_new(img)

    rlBytes = np.uint8(layer_np).tobytes();
    rl = gimp.Layer(img, name, img.width, img.height, RGB, 100, NORMAL_MODE)
    region = rl.get_pixel_rgn(0, 0, rl.width, rl.height, True)
    region[:, :] = rlBytes

    pdb.gimp_image_insert_layer(img, rl, None, 0)

    gimp.displays_flush()


def colorpalette(img, layer):
    cpy = cv2.cvtColor(cv2.imread(baseLoc+'color_palette.png'), cv2.COLOR_BGR2RGB)
    genNewImg('palette', cpy)


register(
    "colorpalette",
    "colorpalette",
    "colorpalette.",
    "Kritik Soman",
    "Your",
    "2020",
    "colorpalette...",
    "*",  # Alternately use RGB, RGB*, GRAY*, INDEXED etc.
    [(PF_IMAGE, "image", "Input image", None),
     (PF_DRAWABLE, "drawable", "Input drawable", None),
     ],
    [],
    colorpalette, menu="<Image>/Layer/GIML-ML")

main()
85
gimp-plugins/deepcolor.py
Executable file
@@ -0,0 +1,85 @@
import os
baseLoc = os.path.dirname(os.path.realpath(__file__))+'/'
from gimpfu import *
import sys
sys.path.extend([baseLoc+'gimpenv/lib/python2.7',baseLoc+'gimpenv/lib/python2.7/site-packages',baseLoc+'gimpenv/lib/python2.7/site-packages/setuptools',baseLoc+'ideepcolor'])
import numpy as np
import torch
import cv2
from data import colorize_image as CI

def createResultLayer(image, name, result):
    rlBytes = np.uint8(result).tobytes();
    rl = gimp.Layer(image, name, image.width, image.height)
    # ,image.active_layer.type,100,NORMAL_MODE)
    region = rl.get_pixel_rgn(0, 0, rl.width, rl.height, True)
    region[:, :] = rlBytes
    image.add_layer(rl, 0)
    gimp.displays_flush()

def channelData(layer):  # convert gimp image to numpy
    region = layer.get_pixel_rgn(0, 0, layer.width, layer.height)
    pixChars = region[:, :]  # Take whole layer
    bpp = region.bpp
    return np.frombuffer(pixChars, dtype=np.uint8).reshape(layer.height, layer.width, bpp)


def deepcolor(tmp1, tmp2, ilayerimg, ilayerc):
    layerimg = channelData(ilayerimg)
    layerc = channelData(ilayerc)

    if ilayerimg.name == ilayerc.name:  # if local color hints are not provided by user
        mask = np.zeros((1, 256, 256))  # giving no user points, so mask is all 0's
        input_ab = np.zeros((2, 256, 256))  # ab values of user points, default to 0 for no input
    else:
        if layerc.shape[2] == 3:  # error
            pdb.gimp_message("Alpha channel missing in " + ilayerc.name + " !")
            return
        else:
            input_ab = cv2.cvtColor(layerc[:, :, 0:3].astype(np.float32)/255, cv2.COLOR_RGB2LAB)
            mask = layerc[:, :, 3] > 0
            mask = mask.astype(np.uint8)
            input_ab = cv2.resize(input_ab, (256, 256))
            mask = cv2.resize(mask, (256, 256))
            mask = mask[np.newaxis, :, :]
            input_ab = input_ab[:, :, 1:3].transpose((2, 0, 1))

    if layerimg.shape[2] == 4:  # remove alpha channel in image if present
        layerimg = layerimg[:, :, 0:3]

    if torch.cuda.is_available():
        gimp.progress_init("(Using GPU) Running deepcolor for " + ilayerimg.name + "...")
        gpu_id = 0
    else:
        gimp.progress_init("(Using CPU) Running deepcolor for " + ilayerimg.name + "...")
        gpu_id = None

    colorModel = CI.ColorizeImageTorch(Xd=256)
    colorModel.prep_net(gpu_id, baseLoc + 'ideepcolor/models/pytorch/caffemodel.pth')
    colorModel.load_image(layerimg)  # load an image

    img_out = colorModel.net_forward(input_ab, mask)  # run model, returns 256x256 image
    img_out_fullres = colorModel.get_img_fullres()  # get image at full resolution

    createResultLayer(tmp1, 'new_' + ilayerimg.name, img_out_fullres)


register(
    "deepcolor",
    "deepcolor",
    "Running deepcolor.",
    "Kritik Soman",
    "Your",
    "2020",
    "deepcolor...",
    "*",  # Alternately use RGB, RGB*, GRAY*, INDEXED etc.
    [ (PF_IMAGE, "image", "Input image", None),
      (PF_DRAWABLE, "drawable", "Input drawable", None),
      (PF_LAYER, "drawinglayer", "Original Image:", None),
      (PF_LAYER, "drawinglayer", "Color Mask:", None),
    ],
    [],
    deepcolor, menu="<Image>/Layer/GIML-ML")

main()
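For reference, a short sketch of the hint encoding this plugin expects: the Color Mask layer is an RGBA image whose opaque pixels are user color points, which deepcolor turns into a 2x256x256 ab tensor plus a 1x256x256 mask. The scribble coordinates and color below are arbitrary, illustrative values.

import numpy as np
import cv2

hints = np.zeros((256, 256, 4), np.uint8)    # transparent RGBA hint layer
hints[100:105, 100:105] = (255, 0, 0, 255)   # one opaque red scribble
ab = cv2.cvtColor(hints[:, :, 0:3].astype(np.float32) / 255, cv2.COLOR_RGB2LAB)
input_ab = ab[:, :, 1:3].transpose((2, 0, 1))             # 2x256x256 ab hints
mask = (hints[:, :, 3] > 0).astype(np.uint8)[np.newaxis]  # 1x256x256 point mask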
Binary file not shown.
Binary file not shown.
21
gimp-plugins/ideepcolor/LICENSE
Normal file
@@ -0,0 +1,21 @@
The MIT License (MIT)

Copyright (c) 2017 Jun-Yan Zhu and Richard Zhang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
0
gimp-plugins/ideepcolor/data/__init__.py
Normal file
BIN
gimp-plugins/ideepcolor/data/__init__.pyc
Normal file
Binary file not shown.
BIN
gimp-plugins/ideepcolor/data/color_bins/in_hull.npy
Normal file
Binary file not shown.
BIN
gimp-plugins/ideepcolor/data/color_bins/pts_grid.npy
Normal file
Binary file not shown.
BIN
gimp-plugins/ideepcolor/data/color_bins/pts_in_hull.npy
Normal file
Binary file not shown.
565
gimp-plugins/ideepcolor/data/colorize_image.py
Normal file
@@ -0,0 +1,565 @@
import numpy as np
# import matplotlib.pyplot as plt
# from skimage import color
# from sklearn.cluster import KMeans
import os
import cv2
from scipy.ndimage.interpolation import zoom


def create_temp_directory(path_template, N=1e8):
    print(path_template)
    cur_path = path_template % np.random.randint(0, N)
    while(os.path.exists(cur_path)):
        cur_path = path_template % np.random.randint(0, N)
    print('Creating directory: %s' % cur_path)
    os.mkdir(cur_path)
    return cur_path


def lab2rgb_transpose(img_l, img_ab):
    ''' INPUTS
            img_l     1xXxX     [0,100]
            img_ab    2xXxX     [-100,100]
        OUTPUTS
            returned value is XxXx3 '''
    pred_lab = np.concatenate((img_l, img_ab), axis=0).transpose((1, 2, 0))
    # im = color.lab2rgb(pred_lab)
    im = cv2.cvtColor(pred_lab.astype('float32'), cv2.COLOR_LAB2RGB)
    pred_rgb = (np.clip(im, 0, 1) * 255).astype('uint8')
    return pred_rgb


def rgb2lab_transpose(img_rgb):
    ''' INPUTS
            img_rgb XxXx3
        OUTPUTS
            returned value is 3xXxX '''
    # im = color.rgb2lab(img_rgb)
    im = cv2.cvtColor(img_rgb.astype(np.float32)/255, cv2.COLOR_RGB2LAB)
    return im.transpose((2, 0, 1))


class ColorizeImageBase():
    def __init__(self, Xd=256, Xfullres_max=10000):
        self.Xd = Xd
        self.img_l_set = False
        self.net_set = False
        self.Xfullres_max = Xfullres_max  # maximum size of maximum dimension
        self.img_just_set = False  # this will be true whenever image is just loaded
        # net_forward can set this to False if they want

    def prep_net(self):
        raise Exception("Should be implemented by base class")

    # ***** Image prepping *****
    def load_image(self, im):
        # rgb image [CxXdxXd]
        self.img_rgb_fullres = im.copy()
        self._set_img_lab_fullres_()

        im = cv2.resize(im, (self.Xd, self.Xd))
        self.img_rgb = im.copy()
        # self.img_rgb = sp.misc.imresize(plt.imread(input_path),(self.Xd,self.Xd)).transpose((2,0,1))

        self.img_l_set = True

        # convert into lab space
        self._set_img_lab_()
        self._set_img_lab_mc_()

    def set_image(self, input_image):
        self.img_rgb_fullres = input_image.copy()
        self._set_img_lab_fullres_()

        self.img_l_set = True

        self.img_rgb = input_image
        # convert into lab space
        self._set_img_lab_()
        self._set_img_lab_mc_()

    def net_forward(self, input_ab, input_mask):
        # INPUTS
        #     ab    2xXxX    input color patches (non-normalized)
        #     mask  1xXxX    input mask, indicating which points have been provided
        # assumes self.img_l_mc has been set

        if(not self.img_l_set):
            print('I need to have an image!')
            return -1
        if(not self.net_set):
            print('I need to have a net!')
            return -1

        self.input_ab = input_ab
        self.input_ab_mc = (input_ab - self.ab_mean) / self.ab_norm
        self.input_mask = input_mask
        self.input_mask_mult = input_mask * self.mask_mult
        return 0

    def get_result_PSNR(self, result=-1, return_SE_map=False):
        if np.array((result)).flatten()[0] == -1:
            cur_result = self.get_img_forward()
        else:
            cur_result = result.copy()
        SE_map = (1. * self.img_rgb - cur_result)**2
        cur_MSE = np.mean(SE_map)
        cur_PSNR = 20 * np.log10(255. / np.sqrt(cur_MSE))
        if return_SE_map:
            return(cur_PSNR, SE_map)
        else:
            return cur_PSNR

    def get_img_forward(self):
        # get image with point estimate
        return self.output_rgb

    def get_img_gray(self):
        # Get black and white image
        return lab2rgb_transpose(self.img_l, np.zeros((2, self.Xd, self.Xd)))

    def get_img_gray_fullres(self):
        # Get black and white image
        return lab2rgb_transpose(self.img_l_fullres, np.zeros((2, self.img_l_fullres.shape[1], self.img_l_fullres.shape[2])))

    def get_img_fullres(self):
        # This assumes self.img_l_fullres, self.output_ab are set.
        # Typically, this means that set_image() and net_forward()
        # have been called.
        # bilinear upsample
        zoom_factor = (1, 1. * self.img_l_fullres.shape[1] / self.output_ab.shape[1], 1. * self.img_l_fullres.shape[2] / self.output_ab.shape[2])
        output_ab_fullres = zoom(self.output_ab, zoom_factor, order=1)

        return lab2rgb_transpose(self.img_l_fullres, output_ab_fullres)

    def get_input_img_fullres(self):
        zoom_factor = (1, 1. * self.img_l_fullres.shape[1] / self.input_ab.shape[1], 1. * self.img_l_fullres.shape[2] / self.input_ab.shape[2])
        input_ab_fullres = zoom(self.input_ab, zoom_factor, order=1)
        return lab2rgb_transpose(self.img_l_fullres, input_ab_fullres)

    def get_input_img(self):
        return lab2rgb_transpose(self.img_l, self.input_ab)

    def get_img_mask(self):
        # Get black and white image
        return lab2rgb_transpose(100. * (1 - self.input_mask), np.zeros((2, self.Xd, self.Xd)))

    def get_img_mask_fullres(self):
        # Get black and white image
        zoom_factor = (1, 1. * self.img_l_fullres.shape[1] / self.input_ab.shape[1], 1. * self.img_l_fullres.shape[2] / self.input_ab.shape[2])
        input_mask_fullres = zoom(self.input_mask, zoom_factor, order=0)
        return lab2rgb_transpose(100. * (1 - input_mask_fullres), np.zeros((2, input_mask_fullres.shape[1], input_mask_fullres.shape[2])))

    def get_sup_img(self):
        return lab2rgb_transpose(50 * self.input_mask, self.input_ab)

    def get_sup_fullres(self):
        zoom_factor = (1, 1. * self.img_l_fullres.shape[1] / self.output_ab.shape[1], 1. * self.img_l_fullres.shape[2] / self.output_ab.shape[2])
        input_mask_fullres = zoom(self.input_mask, zoom_factor, order=0)
        input_ab_fullres = zoom(self.input_ab, zoom_factor, order=0)
        return lab2rgb_transpose(50 * input_mask_fullres, input_ab_fullres)

    # ***** Private functions *****
    def _set_img_lab_fullres_(self):
        # adjust full resolution image to be within maximum dimension is within Xfullres_max
        Xfullres = self.img_rgb_fullres.shape[0]
        Yfullres = self.img_rgb_fullres.shape[1]
        if Xfullres > self.Xfullres_max or Yfullres > self.Xfullres_max:
            if Xfullres > Yfullres:
                zoom_factor = 1. * self.Xfullres_max / Xfullres
            else:
                zoom_factor = 1. * self.Xfullres_max / Yfullres
            self.img_rgb_fullres = zoom(self.img_rgb_fullres, (zoom_factor, zoom_factor, 1), order=1)

        self.img_lab_fullres = cv2.cvtColor(self.img_rgb_fullres.astype(np.float32) / 255, cv2.COLOR_RGB2LAB).transpose((2, 0, 1))
        # self.img_lab_fullres = color.rgb2lab(self.img_rgb_fullres).transpose((2, 0, 1))
        self.img_l_fullres = self.img_lab_fullres[[0], :, :]
        self.img_ab_fullres = self.img_lab_fullres[1:, :, :]

    def _set_img_lab_(self):
        # set self.img_lab from self.im_rgb
        self.img_lab = cv2.cvtColor(self.img_rgb.astype(np.float32) / 255, cv2.COLOR_RGB2LAB).transpose((2, 0, 1))
        # self.img_lab = color.rgb2lab(self.img_rgb).transpose((2, 0, 1))
        self.img_l = self.img_lab[[0], :, :]
        self.img_ab = self.img_lab[1:, :, :]

    def _set_img_lab_mc_(self):
        # set self.img_lab_mc from self.img_lab
        # lab image, mean centered [XxYxX]
        self.img_lab_mc = self.img_lab / np.array((self.l_norm, self.ab_norm, self.ab_norm))[:, np.newaxis, np.newaxis] - np.array(
            (self.l_mean / self.l_norm, self.ab_mean / self.ab_norm, self.ab_mean / self.ab_norm))[:, np.newaxis, np.newaxis]
        self._set_img_l_()

    def _set_img_l_(self):
        self.img_l_mc = self.img_lab_mc[[0], :, :]
        self.img_l_set = True

    def _set_img_ab_(self):
        self.img_ab_mc = self.img_lab_mc[[1, 2], :, :]

    def _set_out_ab_(self):
        self.output_lab = rgb2lab_transpose(self.output_rgb)
        self.output_ab = self.output_lab[1:, :, :]


class ColorizeImageTorch(ColorizeImageBase):
    def __init__(self, Xd=256, maskcent=False):
        print('ColorizeImageTorch instantiated')
        ColorizeImageBase.__init__(self, Xd)
        self.l_norm = 1.
        self.ab_norm = 1.
        self.l_mean = 50.
        self.ab_mean = 0.
        self.mask_mult = 1.
        self.mask_cent = .5 if maskcent else 0

        # Load grid properties
        self.pts_in_hull = np.array(np.meshgrid(np.arange(-110, 120, 10), np.arange(-110, 120, 10))).reshape((2, 529)).T

    # ***** Net preparation *****
    def prep_net(self, gpu_id=None, path='', dist=False):
        import torch
        import models.pytorch.model as model
        print('path = %s' % path)
        print('Model set! dist mode? ', dist)
        self.net = model.SIGGRAPHGenerator(dist=dist)
        state_dict = torch.load(path)
        if hasattr(state_dict, '_metadata'):
            del state_dict._metadata

        # patch InstanceNorm checkpoints prior to 0.4
        for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
            self.__patch_instance_norm_state_dict(state_dict, self.net, key.split('.'))
        self.net.load_state_dict(state_dict)
        if gpu_id != None:
            self.net.cuda()
        self.net.eval()
        self.net_set = True

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # ***** Call forward *****
    def net_forward(self, input_ab, input_mask):
        # INPUTS
        #     ab    2xXxX    input color patches (non-normalized)
        #     mask  1xXxX    input mask, indicating which points have been provided
        # assumes self.img_l_mc has been set

        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        # net_input_prepped = np.concatenate((self.img_l_mc, self.input_ab_mc, self.input_mask_mult), axis=0)

        # return prediction
        # self.net.blobs['data_l_ab_mask'].data[...] = net_input_prepped
        # embed()
        output_ab = self.net.forward(self.img_l_mc, self.input_ab_mc, self.input_mask_mult, self.mask_cent)[0, :, :, :].cpu().data.numpy()
        self.output_rgb = lab2rgb_transpose(self.img_l, output_ab)
        # self.output_rgb = lab2rgb_transpose(self.img_l, self.net.blobs[self.pred_ab_layer].data[0, :, :, :])

        self._set_out_ab_()
        return self.output_rgb

    def get_img_forward(self):
        # get image with point estimate
        return self.output_rgb

    def get_img_gray(self):
        # Get black and white image
        return lab2rgb_transpose(self.img_l, np.zeros((2, self.Xd, self.Xd)))


class ColorizeImageTorchDist(ColorizeImageTorch):
    def __init__(self, Xd=256, maskcent=False):
        ColorizeImageTorch.__init__(self, Xd)
        self.dist_ab_set = False
        self.pts_grid = np.array(np.meshgrid(np.arange(-110, 120, 10), np.arange(-110, 120, 10))).reshape((2, 529)).T
        self.in_hull = np.ones(529, dtype=bool)
        self.AB = self.pts_grid.shape[0]  # 529
        self.A = int(np.sqrt(self.AB))  # 23
        self.B = int(np.sqrt(self.AB))  # 23
        self.dist_ab_full = np.zeros((self.AB, self.Xd, self.Xd))
        self.dist_ab_grid = np.zeros((self.A, self.B, self.Xd, self.Xd))
        self.dist_entropy = np.zeros((self.Xd, self.Xd))
        self.mask_cent = .5 if maskcent else 0

    def prep_net(self, gpu_id=None, path='', dist=True, S=.2):
        ColorizeImageTorch.prep_net(self, gpu_id=gpu_id, path=path, dist=dist)
        # set S somehow

    def net_forward(self, input_ab, input_mask):
        # INPUTS
        #     ab    2xXxX    input color patches (non-normalized)
        #     mask  1xXxX    input mask, indicating which points have been provided
        # assumes self.img_l_mc has been set

        # embed()
        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        # set distribution
        (function_return, self.dist_ab) = self.net.forward(self.img_l_mc, self.input_ab_mc, self.input_mask_mult, self.mask_cent)
        function_return = function_return[0, :, :, :].cpu().data.numpy()
        self.dist_ab = self.dist_ab[0, :, :, :].cpu().data.numpy()
        self.dist_ab_set = True

        # full grid, ABxXxX, AB = 529
        self.dist_ab_full[self.in_hull, :, :] = self.dist_ab

        # gridded, AxBxXxX, A = 23
        self.dist_ab_grid = self.dist_ab_full.reshape((self.A, self.B, self.Xd, self.Xd))

        # return
        return function_return

    # def get_ab_reccs(self, h, w, K=5, N=25000, return_conf=False):
    #     ''' Recommended colors at point (h,w)
    #     Call this after calling net_forward
    #     '''
    #     if not self.dist_ab_set:
    #         print('Need to set prediction first')
    #         return 0
    #
    #     # randomly sample from pdf
    #     cmf = np.cumsum(self.dist_ab[:, h, w])  # CMF
    #     cmf = cmf / cmf[-1]
    #     cmf_bins = cmf
    #
    #     # randomly sample N points
    #     rnd_pts = np.random.uniform(low=0, high=1.0, size=N)
    #     inds = np.digitize(rnd_pts, bins=cmf_bins)
    #     rnd_pts_ab = self.pts_in_hull[inds, :]
    #
    #     # run k-means
    #     kmeans = KMeans(n_clusters=K).fit(rnd_pts_ab)
    #
    #     # sort by cluster occupancy
    #     k_label_cnt = np.histogram(kmeans.labels_, np.arange(0, K + 1))[0]
    #     k_inds = np.argsort(k_label_cnt, axis=0)[::-1]
    #
    #     cluster_per = 1. * k_label_cnt[k_inds] / N  # percentage of points within cluster
    #     cluster_centers = kmeans.cluster_centers_[k_inds, :]  # cluster centers
    #
    #     # cluster_centers = np.random.uniform(low=-100,high=100,size=(N,2))
    #     if return_conf:
    #         return cluster_centers, cluster_per
    #     else:
    #         return cluster_centers

    def compute_entropy(self):
        # compute the distribution entropy (really slow right now)
        self.dist_entropy = np.sum(self.dist_ab * np.log(self.dist_ab), axis=0)

    # def plot_dist_grid(self, h, w):
    #     # Plots distribution at a given point
    #     plt.figure()
    #     plt.imshow(self.dist_ab_grid[:, :, h, w], extent=[-110, 110, 110, -110], interpolation='nearest')
    #     plt.colorbar()
    #     plt.ylabel('a')
    #     plt.xlabel('b')

    # def plot_dist_entropy(self):
    #     # Plots distribution at a given point
    #     plt.figure()
    #     plt.imshow(-self.dist_entropy, interpolation='nearest')
    #     plt.colorbar()


class ColorizeImageCaffe(ColorizeImageBase):
    def __init__(self, Xd=256):
        print('ColorizeImageCaffe instantiated')
        ColorizeImageBase.__init__(self, Xd)
        self.l_norm = 1.
        self.ab_norm = 1.
        self.l_mean = 50.
        self.ab_mean = 0.
        self.mask_mult = 110.

        self.pred_ab_layer = 'pred_ab'  # predicted ab layer

        # Load grid properties
        self.pts_in_hull_path = './data/color_bins/pts_in_hull.npy'
        self.pts_in_hull = np.load(self.pts_in_hull_path)  # 313x2, in-gamut

    # ***** Net preparation *****
    def prep_net(self, gpu_id, prototxt_path='', caffemodel_path=''):
        import caffe
        print('gpu_id = %d, net_path = %s, model_path = %s' % (gpu_id, prototxt_path, caffemodel_path))
        if gpu_id == -1:
            caffe.set_mode_cpu()
        else:
            caffe.set_device(gpu_id)
            caffe.set_mode_gpu()
        self.gpu_id = gpu_id
        self.net = caffe.Net(prototxt_path, caffemodel_path, caffe.TEST)
        self.net_set = True

        # automatically set cluster centers
        if len(self.net.params[self.pred_ab_layer][0].data[...].shape) == 4 and self.net.params[self.pred_ab_layer][0].data[...].shape[1] == 313:
            print('Setting ab cluster centers in layer: %s' % self.pred_ab_layer)
            self.net.params[self.pred_ab_layer][0].data[:, :, 0, 0] = self.pts_in_hull.T

        # automatically set upsampling kernel
        for layer in self.net._layer_names:
            if layer[-3:] == '_us':
                print('Setting upsampling layer kernel: %s' % layer)
                self.net.params[layer][0].data[:, 0, :, :] = np.array(((.25, .5, .25, 0), (.5, 1., .5, 0), (.25, .5, .25, 0), (0, 0, 0, 0)))[np.newaxis, :, :]

    # ***** Call forward *****
    def net_forward(self, input_ab, input_mask):
        # INPUTS
        #     ab    2xXxX    input color patches (non-normalized)
        #     mask  1xXxX    input mask, indicating which points have been provided
        # assumes self.img_l_mc has been set

        if ColorizeImageBase.net_forward(self, input_ab, input_mask) == -1:
            return -1

        net_input_prepped = np.concatenate((self.img_l_mc, self.input_ab_mc, self.input_mask_mult), axis=0)

        self.net.blobs['data_l_ab_mask'].data[...] = net_input_prepped
        self.net.forward()

        # return prediction
        self.output_rgb = lab2rgb_transpose(self.img_l, self.net.blobs[self.pred_ab_layer].data[0, :, :, :])

        self._set_out_ab_()
        return self.output_rgb

    def get_img_forward(self):
        # get image with point estimate
        return self.output_rgb

    def get_img_gray(self):
        # Get black and white image
        return lab2rgb_transpose(self.img_l, np.zeros((2, self.Xd, self.Xd)))


class ColorizeImageCaffeGlobDist(ColorizeImageCaffe):
    # Caffe colorization, with additional global histogram as input
    def __init__(self, Xd=256):
        ColorizeImageCaffe.__init__(self, Xd)
        self.glob_mask_mult = 1.
        self.glob_layer = 'glob_ab_313_mask'

    def net_forward(self, input_ab, input_mask, glob_dist=-1):
        # glob_dist is 313 array, or -1
        if np.array(glob_dist).flatten()[0] == -1:  # run without this, zero it out
            self.net.blobs[self.glob_layer].data[0, :-1, 0, 0] = 0.
            self.net.blobs[self.glob_layer].data[0, -1, 0, 0] = 0.
        else:  # run conditioned on global histogram
            self.net.blobs[self.glob_layer].data[0, :-1, 0, 0] = glob_dist
            self.net.blobs[self.glob_layer].data[0, -1, 0, 0] = self.glob_mask_mult

        self.output_rgb = ColorizeImageCaffe.net_forward(self, input_ab, input_mask)
        self._set_out_ab_()
        return self.output_rgb


class ColorizeImageCaffeDist(ColorizeImageCaffe):
    # caffe model which includes distribution prediction
    def __init__(self, Xd=256):
        ColorizeImageCaffe.__init__(self, Xd)
        self.dist_ab_set = False
        self.scale_S_layer = 'scale_S'
        self.dist_ab_S_layer = 'dist_ab_S'  # softened distribution layer
        self.pts_grid = np.load('./data/color_bins/pts_grid.npy')  # 529x2, all points
        self.in_hull = np.load('./data/color_bins/in_hull.npy')  # 529 bool
        self.AB = self.pts_grid.shape[0]  # 529
        self.A = int(np.sqrt(self.AB))  # 23
        self.B = int(np.sqrt(self.AB))  # 23
        self.dist_ab_full = np.zeros((self.AB, self.Xd, self.Xd))
        self.dist_ab_grid = np.zeros((self.A, self.B, self.Xd, self.Xd))
        self.dist_entropy = np.zeros((self.Xd, self.Xd))

    def prep_net(self, gpu_id, prototxt_path='', caffemodel_path='', S=.2):
        ColorizeImageCaffe.prep_net(self, gpu_id, prototxt_path=prototxt_path, caffemodel_path=caffemodel_path)
        self.S = S
        self.net.params[self.scale_S_layer][0].data[...] = S

    def net_forward(self, input_ab, input_mask):
        # INPUTS
        #     ab    2xXxX    input color patches (non-normalized)
        #     mask  1xXxX    input mask, indicating which points have been provided
        # assumes self.img_l_mc has been set

        function_return = ColorizeImageCaffe.net_forward(self, input_ab, input_mask)
        if np.array(function_return).flatten()[0] == -1:  # errored out
            return -1

        # set distribution
        # in-gamut, CxXxX, C = 313
        self.dist_ab = self.net.blobs[self.dist_ab_S_layer].data[0, :, :, :]
        self.dist_ab_set = True

        # full grid, ABxXxX, AB = 529
        self.dist_ab_full[self.in_hull, :, :] = self.dist_ab

        # gridded, AxBxXxX, A = 23
        self.dist_ab_grid = self.dist_ab_full.reshape((self.A, self.B, self.Xd, self.Xd))

        # return
        return function_return

    # def get_ab_reccs(self, h, w, K=5, N=25000, return_conf=False):
    #     ''' Recommended colors at point (h,w)
    #     Call this after calling net_forward
    #     '''
    #     if not self.dist_ab_set:
    #         print('Need to set prediction first')
    #         return 0
    #
    #     # randomly sample from pdf
    #     cmf = np.cumsum(self.dist_ab[:, h, w])  # CMF
    #     cmf = cmf / cmf[-1]
    #     cmf_bins = cmf
    #
    #     # randomly sample N points
    #     rnd_pts = np.random.uniform(low=0, high=1.0, size=N)
    #     inds = np.digitize(rnd_pts, bins=cmf_bins)
    #     rnd_pts_ab = self.pts_in_hull[inds, :]
    #
    #     # run k-means
    #     kmeans = KMeans(n_clusters=K).fit(rnd_pts_ab)
    #
    #     # sort by cluster occupancy
    #     k_label_cnt = np.histogram(kmeans.labels_, np.arange(0, K + 1))[0]
    #     k_inds = np.argsort(k_label_cnt, axis=0)[::-1]
    #
    #     cluster_per = 1. * k_label_cnt[k_inds] / N  # percentage of points within cluster
    #     cluster_centers = kmeans.cluster_centers_[k_inds, :]  # cluster centers
    #
    #     # cluster_centers = np.random.uniform(low=-100,high=100,size=(N,2))
    #     if return_conf:
    #         return cluster_centers, cluster_per
    #     else:
    #         return cluster_centers

    def compute_entropy(self):
        # compute the distribution entropy (really slow right now)
        self.dist_entropy = np.sum(self.dist_ab * np.log(self.dist_ab), axis=0)

    # def plot_dist_grid(self, h, w):
    #     # Plots distribution at a given point
    #     plt.figure()
    #     plt.imshow(self.dist_ab_grid[:, :, h, w], extent=[-110, 110, 110, -110], interpolation='nearest')
    #     plt.colorbar()
    #     plt.ylabel('a')
    #     plt.xlabel('b')

    # def plot_dist_entropy(self):
    #     # Plots distribution at a given point
    #     plt.figure()
    #     plt.imshow(-self.dist_entropy, interpolation='nearest')
    #     plt.colorbar()
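A minimal usage sketch of the PyTorch path, mirroring what deepcolor.py does; the checkpoint location and the input file name are assumptions, and passing all-zero hints reproduces the automatic (no-scribble) colorization.

import cv2
import numpy as np
from data import colorize_image as CI

color_model = CI.ColorizeImageTorch(Xd=256)
color_model.prep_net(None, 'ideepcolor/models/pytorch/caffemodel.pth')  # CPU; assumed path
img = cv2.cvtColor(cv2.imread('photo.jpg'), cv2.COLOR_BGR2RGB)
color_model.load_image(img)
color_model.net_forward(np.zeros((2, 256, 256)), np.zeros((1, 256, 256)))  # no user hints
out_full = color_model.get_img_fullres()
cv2.imwrite('colorized.png', cv2.cvtColor(out_full, cv2.COLOR_RGB2BGR))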
BIN
gimp-plugins/ideepcolor/data/colorize_image.pyc
Normal file
Binary file not shown.
90
gimp-plugins/ideepcolor/data/lab_gamut.py
Normal file
@@ -0,0 +1,90 @@
import numpy as np
from skimage import color
import warnings


def qcolor2lab_1d(qc):
    # take 1d numpy array and do color conversion
    c = np.array([qc.red(), qc.green(), qc.blue()], np.uint8)
    return rgb2lab_1d(c)


def rgb2lab_1d(in_rgb):
    # take 1d numpy array and do color conversion
    # print('in_rgb', in_rgb)
    return color.rgb2lab(in_rgb[np.newaxis, np.newaxis, :]).flatten()


def lab2rgb_1d(in_lab, clip=True, dtype='uint8'):
    warnings.filterwarnings("ignore")
    tmp_rgb = color.lab2rgb(in_lab[np.newaxis, np.newaxis, :]).flatten()
    if clip:
        tmp_rgb = np.clip(tmp_rgb, 0, 1)
    if dtype == 'uint8':
        tmp_rgb = np.round(tmp_rgb * 255).astype('uint8')
    return tmp_rgb


def snap_ab(input_l, input_rgb, return_type='rgb'):
    ''' given an input lightness and rgb, snap the color into a region where l,a,b is in-gamut
    '''
    T = 20
    warnings.filterwarnings("ignore")
    input_lab = rgb2lab_1d(np.array(input_rgb))  # convert input to lab
    conv_lab = input_lab.copy()  # keep ab from input
    for t in range(T):
        conv_lab[0] = input_l  # overwrite input l with input ab
        old_lab = conv_lab
        tmp_rgb = color.lab2rgb(conv_lab[np.newaxis, np.newaxis, :]).flatten()
        tmp_rgb = np.clip(tmp_rgb, 0, 1)
        conv_lab = color.rgb2lab(tmp_rgb[np.newaxis, np.newaxis, :]).flatten()
        dif_lab = np.sum(np.abs(conv_lab - old_lab))
        if dif_lab < 1:
            break
        # print(conv_lab)

    conv_rgb_ingamut = lab2rgb_1d(conv_lab, clip=True, dtype='uint8')
    if (return_type == 'rgb'):
        return conv_rgb_ingamut

    elif(return_type == 'lab'):
        conv_lab_ingamut = rgb2lab_1d(conv_rgb_ingamut)
        return conv_lab_ingamut


class abGrid():
    def __init__(self, gamut_size=110, D=1):
        self.D = D
        self.vals_b, self.vals_a = np.meshgrid(np.arange(-gamut_size, gamut_size + D, D),
                                               np.arange(-gamut_size, gamut_size + D, D))
        self.pts_full_grid = np.concatenate((self.vals_a[:, :, np.newaxis], self.vals_b[:, :, np.newaxis]), axis=2)
        self.A = self.pts_full_grid.shape[0]
        self.B = self.pts_full_grid.shape[1]
        self.AB = self.A * self.B
        self.gamut_size = gamut_size

    def update_gamut(self, l_in):
        warnings.filterwarnings("ignore")
        thresh = 1.0
        pts_lab = np.concatenate((l_in + np.zeros((self.A, self.B, 1)), self.pts_full_grid), axis=2)
        self.pts_rgb = (255 * np.clip(color.lab2rgb(pts_lab), 0, 1)).astype('uint8')
        pts_lab_back = color.rgb2lab(self.pts_rgb)
        pts_lab_diff = np.linalg.norm(pts_lab - pts_lab_back, axis=2)

        self.mask = pts_lab_diff < thresh
        mask3 = np.tile(self.mask[..., np.newaxis], [1, 1, 3])
        self.masked_rgb = self.pts_rgb.copy()
        self.masked_rgb[np.invert(mask3)] = 255
        return self.masked_rgb, self.mask

    def ab2xy(self, a, b):
        y = self.gamut_size + a
        x = self.gamut_size + b
        # print('ab2xy (%d, %d) -> (%d, %d)' % (a, b, x, y))
        return x, y

    def xy2ab(self, x, y):
        a = y - self.gamut_size
        b = x - self.gamut_size
        # print('xy2ab (%d, %d) -> (%d, %d)' % (x, y, a, b))
        return a, b
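A brief sketch of the two entry points, assuming the module is importable as data.lab_gamut: snap_ab pulls an (L, RGB) pair back into the displayable gamut, and abGrid.update_gamut masks the ab plane at a given lightness.

import numpy as np
from data.lab_gamut import snap_ab, abGrid  # assumed import path

rgb_in_gamut = snap_ab(80, np.array([0, 255, 0], np.uint8))  # saturated green at L=80
grid = abGrid(gamut_size=110, D=1)
masked_rgb, mask = grid.update_gamut(50)  # 221x221 mask of in-gamut (a, b) pairs at L=50
print(rgb_in_gamut, mask.sum())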
0
gimp-plugins/ideepcolor/models/__init__.py
Normal file
BIN
gimp-plugins/ideepcolor/models/__init__.pyc
Normal file
Binary file not shown.
0
gimp-plugins/ideepcolor/models/pytorch/__init__.py
Normal file
BIN
gimp-plugins/ideepcolor/models/pytorch/__init__.pyc
Normal file
Binary file not shown.
175
gimp-plugins/ideepcolor/models/pytorch/model.py
Normal file
@ -0,0 +1,175 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SIGGRAPHGenerator(nn.Module):
|
||||
def __init__(self, dist=False):
|
||||
super(SIGGRAPHGenerator, self).__init__()
|
||||
self.dist = dist
|
||||
use_bias = True
|
||||
norm_layer = nn.BatchNorm2d
|
||||
|
||||
# Conv1
|
||||
model1 = [nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model1 += [nn.ReLU(True), ]
|
||||
model1 += [nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model1 += [nn.ReLU(True), ]
|
||||
model1 += [norm_layer(64), ]
|
||||
# add a subsampling operation
|
||||
|
||||
# Conv2
|
||||
model2 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model2 += [nn.ReLU(True), ]
|
||||
model2 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model2 += [nn.ReLU(True), ]
|
||||
model2 += [norm_layer(128), ]
|
||||
# add a subsampling layer operation
|
||||
|
||||
# Conv3
|
||||
model3 = [nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model3 += [nn.ReLU(True), ]
|
||||
model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model3 += [nn.ReLU(True), ]
|
||||
model3 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model3 += [nn.ReLU(True), ]
|
||||
model3 += [norm_layer(256), ]
|
||||
# add a subsampling layer operation
|
||||
|
||||
# Conv4
|
||||
model4 = [nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model4 += [nn.ReLU(True), ]
|
||||
model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model4 += [nn.ReLU(True), ]
|
||||
model4 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model4 += [nn.ReLU(True), ]
|
||||
model4 += [norm_layer(512), ]
|
||||
|
||||
# Conv5
|
||||
model5 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
|
||||
model5 += [nn.ReLU(True), ]
|
||||
model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
|
||||
model5 += [nn.ReLU(True), ]
|
||||
model5 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
|
||||
model5 += [nn.ReLU(True), ]
|
||||
model5 += [norm_layer(512), ]
|
||||
|
||||
# Conv6
|
||||
model6 = [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
|
||||
model6 += [nn.ReLU(True), ]
|
||||
model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
|
||||
model6 += [nn.ReLU(True), ]
|
||||
model6 += [nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=use_bias), ]
|
||||
model6 += [nn.ReLU(True), ]
|
||||
model6 += [norm_layer(512), ]
|
||||
|
||||
# Conv7
|
||||
model7 = [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model7 += [nn.ReLU(True), ]
|
||||
model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model7 += [nn.ReLU(True), ]
|
||||
model7 += [nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model7 += [nn.ReLU(True), ]
|
||||
model7 += [norm_layer(512), ]
|
||||
|
||||
# Conv8
|
||||
model8up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=use_bias)]
|
||||
model3short8 = [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
|
||||
model8 = [nn.ReLU(True), ]
|
||||
model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model8 += [nn.ReLU(True), ]
|
||||
model8 += [nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model8 += [nn.ReLU(True), ]
|
||||
model8 += [norm_layer(256), ]
|
||||
|
||||
# Conv9
|
||||
model9up = [nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ]
|
||||
model2short9 = [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
# add the two feature maps above
|
||||
|
||||
model9 = [nn.ReLU(True), ]
|
||||
model9 += [nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
model9 += [nn.ReLU(True), ]
|
||||
model9 += [norm_layer(128), ]
|
||||
|
||||
# Conv10
|
||||
model10up = [nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=use_bias), ]
|
||||
model1short10 = [nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=use_bias), ]
|
||||
# add the two feature maps above
|
||||
|
||||
model10 = [nn.ReLU(True), ]
|
||||
model10 += [nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=use_bias), ]
|
||||
model10 += [nn.LeakyReLU(negative_slope=.2), ]
|
||||
|
||||
# classification output
|
||||
model_class = [nn.Conv2d(256, 529, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), ]
|
||||
|
||||
# regression output
|
||||
model_out = [nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=use_bias), ]
|
||||
model_out += [nn.Tanh()]
|
||||
|
||||
self.model1 = nn.Sequential(*model1)
|
||||
self.model2 = nn.Sequential(*model2)
|
||||
self.model3 = nn.Sequential(*model3)
|
||||
self.model4 = nn.Sequential(*model4)
|
||||
self.model5 = nn.Sequential(*model5)
|
||||
self.model6 = nn.Sequential(*model6)
|
||||
self.model7 = nn.Sequential(*model7)
|
||||
self.model8up = nn.Sequential(*model8up)
|
||||
self.model8 = nn.Sequential(*model8)
|
||||
self.model9up = nn.Sequential(*model9up)
|
||||
self.model9 = nn.Sequential(*model9)
|
||||
self.model10up = nn.Sequential(*model10up)
|
||||
self.model10 = nn.Sequential(*model10)
|
||||
self.model3short8 = nn.Sequential(*model3short8)
|
||||
self.model2short9 = nn.Sequential(*model2short9)
|
||||
self.model1short10 = nn.Sequential(*model1short10)
|
||||
|
||||
self.model_class = nn.Sequential(*model_class)
|
||||
self.model_out = nn.Sequential(*model_out)
|
||||
|
||||
self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='nearest'), ])
|
||||
self.softmax = nn.Sequential(*[nn.Softmax(dim=1), ])
|
||||
|
||||
def forward(self, input_A, input_B, mask_B, maskcent=0):
|
||||
# input_A \in [-50,+50]
|
||||
# input_B \in [-110, +110]
|
||||
# mask_B \in [0, +1.0]
|
||||
|
||||
input_A = torch.Tensor(input_A)[None, :, :, :]
|
||||
input_B = torch.Tensor(input_B)[None, :, :, :]
|
||||
mask_B = torch.Tensor(mask_B)[None, :, :, :]
|
||||
mask_B = mask_B - maskcent
|
||||
|
||||
# input_A = torch.Tensor(input_A).cuda()[None, :, :, :]
|
||||
# input_B = torch.Tensor(input_B).cuda()[None, :, :, :]
|
||||
# mask_B = torch.Tensor(mask_B).cuda()[None, :, :, :]
|
||||
|
||||
conv1_2 = self.model1(torch.cat((input_A / 100., input_B / 110., mask_B), dim=1))
|
||||
conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
|
||||
conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
|
||||
conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
|
||||
conv5_3 = self.model5(conv4_3)
|
||||
conv6_3 = self.model6(conv5_3)
|
||||
conv7_3 = self.model7(conv6_3)
|
||||
|
||||
conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
|
||||
conv8_3 = self.model8(conv8_up)
|
||||
|
||||
if(self.dist):
|
||||
out_cl = self.upsample4(self.softmax(self.model_class(conv8_3) * .2))
|
||||
|
||||
conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
|
||||
conv9_3 = self.model9(conv9_up)
|
||||
conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
|
||||
conv10_2 = self.model10(conv10_up)
|
||||
out_reg = self.model_out(conv10_2)  # scaled by 110 once in the return below, matching the non-dist branch
|
||||
|
||||
return (out_reg * 110, out_cl)
|
||||
else:
|
||||
conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
|
||||
conv9_3 = self.model9(conv9_up)
|
||||
conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
|
||||
conv10_2 = self.model10(conv10_up)
|
||||
out_reg = self.model_out(conv10_2)
|
||||
return out_reg * 110
|
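A minimal sketch of how the generator above can be driven (hypothetical shapes and values, following the range comments in forward(); this is not taken from the plugin code):

import numpy as np
import torch

net = SIGGRAPHGenerator(dist=False)
net.eval()

H, W = 256, 256                                   # any size divisible by 8 so the skip connections line up
input_A = np.zeros((1, H, W), dtype=np.float32)   # L channel, roughly [-50, +50]
input_B = np.zeros((2, H, W), dtype=np.float32)   # user ab hints, roughly [-110, +110]
mask_B = np.zeros((1, H, W), dtype=np.float32)    # 1 where a hint is given, 0 elsewhere

with torch.no_grad():
    ab_pred = net(input_A, input_B, mask_B)       # (1, 2, H, W) predicted ab channels, scaled to [-110, 110]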
BIN
gimp-plugins/ideepcolor/models/pytorch/model.pyc
Normal file
Binary file not shown.
@ -1,106 +1,51 @@
|
||||
import os
|
||||
baseLoc = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||
|
||||
baseLoc = os.path.dirname(os.path.realpath(__file__)) + '/'
|
||||
|
||||
from gimpfu import *
|
||||
import sys
|
||||
sys.path.extend([baseLoc+'gimpenv/lib/python2.7',baseLoc+'gimpenv/lib/python2.7/site-packages',baseLoc+'gimpenv/lib/python2.7/site-packages/setuptools',baseLoc+'monodepth2'])
|
||||
|
||||
sys.path.extend([baseLoc + 'gimpenv/lib/python2.7', baseLoc + 'gimpenv/lib/python2.7/site-packages',
|
||||
baseLoc + 'gimpenv/lib/python2.7/site-packages/setuptools', baseLoc + 'MiDaS'])
|
||||
|
||||
import PIL.Image as pil
|
||||
import networks
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
import os
|
||||
from run import run_depth
|
||||
from monodepth_net import MonoDepthNet
|
||||
import MiDaS_utils as MiDaS_utils
|
||||
import numpy as np
|
||||
import cv2
|
||||
# import matplotlib as mpl
|
||||
# import matplotlib.cm as cm
|
||||
|
||||
def getMonoDepth(input_image):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
loc=baseLoc+'monodepth2/'
|
||||
image = input_image / 255.0
|
||||
out = run_depth(image, baseLoc+'MiDaS/model.pt', MonoDepthNet, MiDaS_utils, target_w=640)
|
||||
out = np.repeat(out[:, :, np.newaxis], 3, axis=2)
|
||||
d1,d2 = input_image.shape[:2]
|
||||
out = cv2.resize(out,(d2,d1))
|
||||
# cv2.imwrite("/Users/kritiksoman/PycharmProjects/new/out.png", out)
|
||||
return out
|
||||
|
||||
model_path = os.path.join(loc+"models", 'mono+stereo_640x192')
|
||||
encoder_path = os.path.join(model_path, "encoder.pth")
|
||||
depth_decoder_path = os.path.join(model_path, "depth.pth")
|
||||
|
||||
# LOADING PRETRAINED MODEL
|
||||
encoder = networks.ResnetEncoder(18, False)
|
||||
loaded_dict_enc = torch.load(encoder_path, map_location=device)
|
||||
|
||||
# extract the height and width of image that this model was trained with
|
||||
feed_height = loaded_dict_enc['height']
|
||||
feed_width = loaded_dict_enc['width']
|
||||
filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in encoder.state_dict()}
|
||||
encoder.load_state_dict(filtered_dict_enc)
|
||||
encoder.to(device)
|
||||
encoder.eval()
|
||||
|
||||
depth_decoder = networks.DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
|
||||
|
||||
loaded_dict = torch.load(depth_decoder_path, map_location=device)
|
||||
depth_decoder.load_state_dict(loaded_dict)
|
||||
|
||||
depth_decoder.to(device)
|
||||
depth_decoder.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
input_image = pil.fromarray(input_image)
|
||||
# input_image = pil.open(image_path).convert('RGB')
|
||||
original_width, original_height = input_image.size
|
||||
input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS)
|
||||
input_image = transforms.ToTensor()(input_image).unsqueeze(0)
|
||||
|
||||
# PREDICTION
|
||||
input_image = input_image.to(device)
|
||||
features = encoder(input_image)
|
||||
outputs = depth_decoder(features)
|
||||
|
||||
disp = outputs[("disp", 0)]
|
||||
disp_resized = torch.nn.functional.interpolate(
|
||||
disp, (original_height, original_width), mode="bilinear", align_corners=False)
|
||||
|
||||
# Saving colormapped depth image
|
||||
disp_resized_np = disp_resized.squeeze().cpu().numpy()
|
||||
vmax = np.percentile(disp_resized_np, 95)
|
||||
vmin = disp_resized_np.min()
|
||||
disp_resized_np = vmin + (disp_resized_np - vmin) * (vmax - vmin) / (disp_resized_np.max() - vmin)
|
||||
disp_resized_np = (255 * (disp_resized_np - vmin) / (vmax - vmin)).astype(np.uint8)
|
||||
colormapped_im = cv2.applyColorMap(disp_resized_np, cv2.COLORMAP_HOT)
|
||||
colormapped_im = cv2.cvtColor(colormapped_im, cv2.COLOR_BGR2RGB)
|
||||
# normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax)
|
||||
# mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
|
||||
# colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8)
|
||||
return colormapped_im
|
||||
|
||||
def channelData(layer):#convert gimp image to numpy
|
||||
region=layer.get_pixel_rgn(0, 0, layer.width,layer.height)
|
||||
pixChars=region[:,:] # Take whole layer
|
||||
bpp=region.bpp
|
||||
def channelData(layer): # convert gimp image to numpy
|
||||
region = layer.get_pixel_rgn(0, 0, layer.width, layer.height)
|
||||
pixChars = region[:, :] # Take whole layer
|
||||
bpp = region.bpp
|
||||
# return np.frombuffer(pixChars,dtype=np.uint8).reshape(len(pixChars)/bpp,bpp)
|
||||
return np.frombuffer(pixChars,dtype=np.uint8).reshape(layer.height,layer.width,bpp)
|
||||
return np.frombuffer(pixChars, dtype=np.uint8).reshape(layer.height, layer.width, bpp)
|
||||
|
||||
def createResultLayer(image,name,result):
|
||||
rlBytes=np.uint8(result).tobytes();
|
||||
rl=gimp.Layer(image,name,image.width,image.height,image.active_layer.type,100,NORMAL_MODE)
|
||||
region=rl.get_pixel_rgn(0, 0, rl.width,rl.height,True)
|
||||
region[:,:]=rlBytes
|
||||
image.add_layer(rl,0)
|
||||
|
||||
def createResultLayer(image, name, result):
|
||||
rlBytes = np.uint8(result).tobytes()
|
||||
rl = gimp.Layer(image, name, image.width, image.height, image.active_layer.type, 100, NORMAL_MODE)
|
||||
region = rl.get_pixel_rgn(0, 0, rl.width, rl.height, True)
|
||||
region[:, :] = rlBytes
|
||||
image.add_layer(rl, 0)
|
||||
gimp.displays_flush()
|
||||
|
||||
def MonoDepth(img, layer) :
|
||||
|
||||
def MonoDepth(img, layer):
|
||||
gimp.progress_init("Generating disparity map for " + layer.name + "...")
|
||||
|
||||
imgmat = channelData(layer)
|
||||
cpy=getMonoDepth(imgmat)
|
||||
|
||||
createResultLayer(img,'new_output',cpy)
|
||||
|
||||
|
||||
cpy = getMonoDepth(imgmat)
|
||||
createResultLayer(img, 'new_output', cpy)
|
||||
|
||||
|
||||
register(
|
||||
@ -111,10 +56,10 @@ register(
|
||||
"Your",
|
||||
"2020",
|
||||
"MonoDepth...",
|
||||
"*", # Alternately use RGB, RGB*, GRAY*, INDEXED etc.
|
||||
[ (PF_IMAGE, "image", "Input image", None),
|
||||
(PF_DRAWABLE, "drawable", "Input drawable", None),
|
||||
],
|
||||
"*", # Alternately use RGB, RGB*, GRAY*, INDEXED etc.
|
||||
[(PF_IMAGE, "image", "Input image", None),
|
||||
(PF_DRAWABLE, "drawable", "Input drawable", None),
|
||||
],
|
||||
[],
|
||||
MonoDepth, menu="<Image>/Layer/GIML-ML")
|
||||
|
||||
|
@ -1,181 +0,0 @@
|
||||
Copyright © Niantic, Inc. 2018. Patent Pending.
|
||||
|
||||
All rights reserved.
|
||||
|
||||
|
||||
|
||||
================================================================================
|
||||
|
||||
|
||||
|
||||
This Software is licensed under the terms of the following Monodepth2 license
|
||||
which allows for non-commercial use only. For any other use of the software not
|
||||
covered by the terms of this license, please contact partnerships@nianticlabs.com
|
||||
|
||||
|
||||
|
||||
================================================================================
|
||||
|
||||
|
||||
|
||||
Monodepth v2 License
|
||||
|
||||
|
||||
This Agreement is made by and between the Licensor and the Licensee as
|
||||
defined and identified below.
|
||||
|
||||
|
||||
1. Definitions.
|
||||
|
||||
In this Agreement (“the Agreement”) the following words shall have the
|
||||
following meanings:
|
||||
|
||||
"Authors" shall mean C. Godard, O. Mac Aodha, M. Firman, G. Brostow
|
||||
"Licensee" Shall mean the person or organization agreeing to use the
|
||||
Software in accordance with these terms and conditions.
|
||||
"Licensor" shall mean Niantic Inc., a company organized and existing under
|
||||
the laws of Delaware, whose principal place of business is at 1 Ferry Building,
|
||||
Suite 200, San Francisco, 94111.
|
||||
"Software" shall mean the MonoDepth v2 Software uploaded by Licensor to the
|
||||
GitHub repository at [URL] on [DATE] in source code or object code form and any
|
||||
accompanying documentation as well as any modifications or additions uploaded
|
||||
to the same GitHub repository by Licensor.
|
||||
|
||||
|
||||
2. License.
|
||||
|
||||
2.1 The Licensor has all necessary rights to grant a license under: (i)
|
||||
copyright and rights in the nature of copyright subsisting in the Software; and
|
||||
(ii) certain patent rights resulting from a patent application filed by the
|
||||
Licensor in the United States in connection with the Software. The Licensor
|
||||
grants the Licensee for the duration of this Agreement, a free of charge,
|
||||
non-sublicenseable, non-exclusive, non-transferable copyright and patent
|
||||
license (in consequence of said patent application) to use the Software for
|
||||
non-commercial purpose only, including teaching and research at educational
|
||||
institutions and research at not-for-profit research institutions in accordance
|
||||
with the provisions of this Agreement. Non-commercial use expressly excludes
|
||||
any profit-making or commercial activities, including without limitation sale,
|
||||
license, manufacture or development of commercial products, use in
|
||||
commercially-sponsored research, use at a laboratory or other facility owned or
|
||||
controlled (whether in whole or in part) by a commercial entity, provision of
|
||||
consulting service, use for or on behalf of any commercial entity, and use in
|
||||
research where a commercial party obtains rights to research results or any
|
||||
other benefit. Any use of the Software for any purpose other than
|
||||
non-commercial research shall automatically terminate this License.
|
||||
|
||||
|
||||
2.2 The Licensee is permitted to make modifications to the Software
|
||||
provided that any distribution of such modifications is in accordance with
|
||||
Clause 3.
|
||||
|
||||
2.3 Except as expressly permitted by this Agreement and save to the
|
||||
extent and in the circumstances expressly required to be permitted by law, the
|
||||
Licensee is not permitted to rent, lease, sell, offer to sell, or loan the
|
||||
Software or its associated documentation.
|
||||
|
||||
|
||||
3. Redistribution and modifications
|
||||
|
||||
3.1 The Licensee may reproduce and distribute copies of the Software, with
|
||||
or without modifications, in source format only and only to this same GitHub
|
||||
repository , and provided that any and every distribution is accompanied by an
|
||||
unmodified copy of this License and that the following copyright notice is
|
||||
always displayed in an obvious manner: Copyright © Niantic, Inc. 2018. All
|
||||
rights reserved.
|
||||
|
||||
|
||||
3.2 In the case where the Software has been modified, any distribution must
|
||||
include prominent notices indicating which files have been changed.
|
||||
|
||||
3.3 The Licensee shall cause any work that it distributes or publishes,
|
||||
that in whole or in part contains or is derived from the Software or any part
|
||||
thereof (“Work based on the Software”), to be licensed as a whole at no charge
|
||||
to all third parties entitled to a license to the Software under the terms of
|
||||
this License and on the same terms provided in this License.
|
||||
|
||||
|
||||
4. Duration.
|
||||
|
||||
This Agreement is effective until the Licensee terminates it by destroying
|
||||
the Software, any Work based on the Software, and its documentation together
|
||||
with all copies. It will also terminate automatically if the Licensee fails to
|
||||
abide by its terms. Upon automatic termination the Licensee agrees to destroy
|
||||
all copies of the Software, Work based on the Software, and its documentation.
|
||||
|
||||
|
||||
5. Disclaimer of Warranties.
|
||||
|
||||
The Software is provided as is. To the maximum extent permitted by law,
|
||||
Licensor provides no warranties or conditions of any kind, either express or
|
||||
implied, including without limitation, any warranties or condition of title,
|
||||
non-infringement or fitness for a particular purpose.
|
||||
|
||||
|
||||
6. LIMITATION OF LIABILITY.
|
||||
|
||||
IN NO EVENT SHALL THE LICENSOR AND/OR AUTHORS BE LIABLE FOR ANY DIRECT,
|
||||
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING
|
||||
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
|
||||
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
|
||||
OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
|
||||
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
7. Indemnity.
|
||||
|
||||
The Licensee shall indemnify the Licensor and/or Authors against all third
|
||||
party claims that may be asserted against or suffered by the Licensor and/or
|
||||
Authors and which relate to use of the Software by the Licensee.
|
||||
|
||||
|
||||
8. Intellectual Property.
|
||||
|
||||
8.1 As between the Licensee and Licensor, copyright and all other
|
||||
intellectual property rights subsisting in or in connection with the Software
|
||||
and supporting information shall remain at all times the property of the
|
||||
Licensor. The Licensee shall acquire no rights in any such material except as
|
||||
expressly provided in this Agreement.
|
||||
|
||||
8.2 No permission is granted to use the trademarks or product names of the
|
||||
Licensor except as required for reasonable and customary use in describing the
|
||||
origin of the Software and for the purposes of abiding by the terms of Clause
|
||||
3.1.
|
||||
|
||||
8.3 The Licensee shall promptly notify the Licensor of any improvement or
|
||||
new use of the Software (“Improvements”) in sufficient detail for Licensor to
|
||||
evaluate the Improvements. The Licensee hereby grants the Licensor and its
|
||||
affiliates a non-exclusive, fully paid-up, royalty-free, irrevocable and
|
||||
perpetual license to all Improvements for non-commercial academic research and
|
||||
teaching purposes upon creation of such improvements.
|
||||
|
||||
8.4 The Licensee grants an exclusive first option to the Licensor to be
|
||||
exercised by the Licensor within three (3) years of the date of notification of
|
||||
an Improvement under Clause 8.3 to use any the Improvement for commercial
|
||||
purposes on terms to be negotiated and agreed by Licensee and Licensor in good
|
||||
faith within a period of six (6) months from the date of exercise of the said
|
||||
option (including without limitation any royalty share in net income from such
|
||||
commercialization payable to the Licensee, as the case may be).
|
||||
|
||||
|
||||
9. Acknowledgements.
|
||||
|
||||
The Licensee shall acknowledge the Authors and use of the Software in the
|
||||
publication of any work that uses, or results that are achieved through, the
|
||||
use of the Software. The following citation shall be included in the
|
||||
acknowledgement: “Digging Into Self-Supervised Monocular Depth Estimation,
|
||||
by C. Godard, O. Mac Aodha, M. Firman, G. Brostow, arXiv:1806.01260”.
|
||||
|
||||
|
||||
10. Governing Law.
|
||||
|
||||
This Agreement shall be governed by, construed and interpreted in
|
||||
accordance with English law and the parties submit to the exclusive
|
||||
jurisdiction of the English courts.
|
||||
|
||||
|
||||
11. Termination.
|
||||
|
||||
Upon termination of this Agreement, the licenses granted hereunder will
|
||||
terminate and Sections 5, 6, 7, 8, 9, 10 and 11 shall survive any termination
|
||||
of this Agreement.
|
@ -1,230 +0,0 @@
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from layers import disp_to_depth
|
||||
from utils import readlines
|
||||
from options import MonodepthOptions
|
||||
import datasets
|
||||
import networks
|
||||
|
||||
cv2.setNumThreads(0) # This speeds up evaluation 5x on our unix systems (OpenCV 3.3.1)
|
||||
|
||||
|
||||
splits_dir = os.path.join(os.path.dirname(__file__), "splits")
|
||||
|
||||
# Models which were trained with stereo supervision were trained with a nominal
|
||||
# baseline of 0.1 units. The KITTI rig has a baseline of 54cm. Therefore,
|
||||
# to convert our stereo predictions to real-world scale we multiply our depths by 5.4.
|
||||
STEREO_SCALE_FACTOR = 5.4
|
||||
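(The factor above is just the ratio of the real baseline to the nominal training baseline: 0.54 m / 0.1 = 5.4, so metric depth for stereo-trained models is recovered as depth_metres = STEREO_SCALE_FACTOR * predicted_depth. This restates the comment above; it adds nothing beyond it.)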
|
||||
|
||||
def compute_errors(gt, pred):
|
||||
"""Computation of error metrics between predicted and ground truth depths
|
||||
"""
|
||||
thresh = np.maximum((gt / pred), (pred / gt))
|
||||
a1 = (thresh < 1.25 ).mean()
|
||||
a2 = (thresh < 1.25 ** 2).mean()
|
||||
a3 = (thresh < 1.25 ** 3).mean()
|
||||
|
||||
rmse = (gt - pred) ** 2
|
||||
rmse = np.sqrt(rmse.mean())
|
||||
|
||||
rmse_log = (np.log(gt) - np.log(pred)) ** 2
|
||||
rmse_log = np.sqrt(rmse_log.mean())
|
||||
|
||||
abs_rel = np.mean(np.abs(gt - pred) / gt)
|
||||
|
||||
sq_rel = np.mean(((gt - pred) ** 2) / gt)
|
||||
|
||||
return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
|
||||
|
||||
|
||||
def batch_post_process_disparity(l_disp, r_disp):
|
||||
"""Apply the disparity post-processing method as introduced in Monodepthv1
|
||||
"""
|
||||
_, h, w = l_disp.shape
|
||||
m_disp = 0.5 * (l_disp + r_disp)
|
||||
l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
|
||||
l_mask = (1.0 - np.clip(20 * (l - 0.05), 0, 1))[None, ...]
|
||||
r_mask = l_mask[:, :, ::-1]
|
||||
return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp
|
||||
|
||||
|
||||
def evaluate(opt):
|
||||
"""Evaluates a pretrained model using a specified test set
|
||||
"""
|
||||
MIN_DEPTH = 1e-3
|
||||
MAX_DEPTH = 80
|
||||
|
||||
assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \
|
||||
"Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo"
|
||||
|
||||
if opt.ext_disp_to_eval is None:
|
||||
|
||||
opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder)
|
||||
|
||||
assert os.path.isdir(opt.load_weights_folder), \
|
||||
"Cannot find a folder at {}".format(opt.load_weights_folder)
|
||||
|
||||
print("-> Loading weights from {}".format(opt.load_weights_folder))
|
||||
|
||||
filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt"))
|
||||
encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth")
|
||||
decoder_path = os.path.join(opt.load_weights_folder, "depth.pth")
|
||||
|
||||
encoder_dict = torch.load(encoder_path)
|
||||
|
||||
dataset = datasets.KITTIRAWDataset(opt.data_path, filenames,
|
||||
encoder_dict['height'], encoder_dict['width'],
|
||||
[0], 4, is_train=False)
|
||||
dataloader = DataLoader(dataset, 16, shuffle=False, num_workers=opt.num_workers,
|
||||
pin_memory=True, drop_last=False)
|
||||
|
||||
encoder = networks.ResnetEncoder(opt.num_layers, False)
|
||||
depth_decoder = networks.DepthDecoder(encoder.num_ch_enc)
|
||||
|
||||
model_dict = encoder.state_dict()
|
||||
encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict})
|
||||
depth_decoder.load_state_dict(torch.load(decoder_path))
|
||||
|
||||
encoder.cuda()
|
||||
encoder.eval()
|
||||
depth_decoder.cuda()
|
||||
depth_decoder.eval()
|
||||
|
||||
pred_disps = []
|
||||
|
||||
print("-> Computing predictions with size {}x{}".format(
|
||||
encoder_dict['width'], encoder_dict['height']))
|
||||
|
||||
with torch.no_grad():
|
||||
for data in dataloader:
|
||||
input_color = data[("color", 0, 0)].cuda()
|
||||
|
||||
if opt.post_process:
|
||||
# Post-processed results require each image to have two forward passes
|
||||
input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0)
|
||||
|
||||
output = depth_decoder(encoder(input_color))
|
||||
|
||||
pred_disp, _ = disp_to_depth(output[("disp", 0)], opt.min_depth, opt.max_depth)
|
||||
pred_disp = pred_disp.cpu()[:, 0].numpy()
|
||||
|
||||
if opt.post_process:
|
||||
N = pred_disp.shape[0] // 2
|
||||
pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1])
|
||||
|
||||
pred_disps.append(pred_disp)
|
||||
|
||||
pred_disps = np.concatenate(pred_disps)
|
||||
|
||||
else:
|
||||
# Load predictions from file
|
||||
print("-> Loading predictions from {}".format(opt.ext_disp_to_eval))
|
||||
pred_disps = np.load(opt.ext_disp_to_eval)
|
||||
|
||||
if opt.eval_eigen_to_benchmark:
|
||||
eigen_to_benchmark_ids = np.load(
|
||||
os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy"))
|
||||
|
||||
pred_disps = pred_disps[eigen_to_benchmark_ids]
|
||||
|
||||
if opt.save_pred_disps:
|
||||
output_path = os.path.join(
|
||||
opt.load_weights_folder, "disps_{}_split.npy".format(opt.eval_split))
|
||||
print("-> Saving predicted disparities to ", output_path)
|
||||
np.save(output_path, pred_disps)
|
||||
|
||||
if opt.no_eval:
|
||||
print("-> Evaluation disabled. Done.")
|
||||
quit()
|
||||
|
||||
elif opt.eval_split == 'benchmark':
|
||||
save_dir = os.path.join(opt.load_weights_folder, "benchmark_predictions")
|
||||
print("-> Saving out benchmark predictions to {}".format(save_dir))
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
|
||||
for idx in range(len(pred_disps)):
|
||||
disp_resized = cv2.resize(pred_disps[idx], (1216, 352))
|
||||
depth = STEREO_SCALE_FACTOR / disp_resized
|
||||
depth = np.clip(depth, 0, 80)
|
||||
depth = np.uint16(depth * 256)
|
||||
save_path = os.path.join(save_dir, "{:010d}.png".format(idx))
|
||||
cv2.imwrite(save_path, depth)
|
||||
|
||||
print("-> No ground truth is available for the KITTI benchmark, so not evaluating. Done.")
|
||||
quit()
|
||||
|
||||
gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz")
|
||||
gt_depths = np.load(gt_path, fix_imports=True, encoding='latin1')["data"]
|
||||
|
||||
print("-> Evaluating")
|
||||
|
||||
if opt.eval_stereo:
|
||||
print(" Stereo evaluation - "
|
||||
"disabling median scaling, scaling by {}".format(STEREO_SCALE_FACTOR))
|
||||
opt.disable_median_scaling = True
|
||||
opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR
|
||||
else:
|
||||
print(" Mono evaluation - using median scaling")
|
||||
|
||||
errors = []
|
||||
ratios = []
|
||||
|
||||
for i in range(pred_disps.shape[0]):
|
||||
|
||||
gt_depth = gt_depths[i]
|
||||
gt_height, gt_width = gt_depth.shape[:2]
|
||||
|
||||
pred_disp = pred_disps[i]
|
||||
pred_disp = cv2.resize(pred_disp, (gt_width, gt_height))
|
||||
pred_depth = 1 / pred_disp
|
||||
|
||||
if opt.eval_split == "eigen":
|
||||
mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH)
|
||||
|
||||
crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height,
|
||||
0.03594771 * gt_width, 0.96405229 * gt_width]).astype(np.int32)
|
||||
crop_mask = np.zeros(mask.shape)
|
||||
crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
|
||||
mask = np.logical_and(mask, crop_mask)
|
||||
|
||||
else:
|
||||
mask = gt_depth > 0
|
||||
|
||||
pred_depth = pred_depth[mask]
|
||||
gt_depth = gt_depth[mask]
|
||||
|
||||
pred_depth *= opt.pred_depth_scale_factor
|
||||
if not opt.disable_median_scaling:
|
||||
ratio = np.median(gt_depth) / np.median(pred_depth)
|
||||
ratios.append(ratio)
|
||||
pred_depth *= ratio
|
||||
|
||||
pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH
|
||||
pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH
|
||||
|
||||
errors.append(compute_errors(gt_depth, pred_depth))
|
||||
|
||||
if not opt.disable_median_scaling:
|
||||
ratios = np.array(ratios)
|
||||
med = np.median(ratios)
|
||||
print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med)))
|
||||
|
||||
mean_errors = np.array(errors).mean(0)
|
||||
|
||||
print("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3"))
|
||||
print(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\")
|
||||
print("\n-> Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
options = MonodepthOptions()
|
||||
evaluate(options.parse())
|
@ -1,134 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from layers import transformation_from_parameters
|
||||
from utils import readlines
|
||||
from options import MonodepthOptions
|
||||
from datasets import KITTIOdomDataset
|
||||
import networks
|
||||
|
||||
|
||||
# from https://github.com/tinghuiz/SfMLearner
|
||||
def dump_xyz(source_to_target_transformations):
|
||||
xyzs = []
|
||||
cam_to_world = np.eye(4)
|
||||
xyzs.append(cam_to_world[:3, 3])
|
||||
for source_to_target_transformation in source_to_target_transformations:
|
||||
cam_to_world = np.dot(cam_to_world, source_to_target_transformation)
|
||||
xyzs.append(cam_to_world[:3, 3])
|
||||
return xyzs
|
||||
|
||||
|
||||
# from https://github.com/tinghuiz/SfMLearner
|
||||
def compute_ate(gtruth_xyz, pred_xyz_o):
|
||||
|
||||
# Make sure that the first matched frames align (no need for rotational alignment as
|
||||
# all the predicted/ground-truth snippets have been converted to use the same coordinate
|
||||
# system with the first frame of the snippet being the origin).
|
||||
offset = gtruth_xyz[0] - pred_xyz_o[0]
|
||||
pred_xyz = pred_xyz_o + offset[None, :]
|
||||
|
||||
# Optimize the scaling factor
|
||||
scale = np.sum(gtruth_xyz * pred_xyz) / np.sum(pred_xyz ** 2)
|
||||
alignment_error = pred_xyz * scale - gtruth_xyz
|
||||
rmse = np.sqrt(np.sum(alignment_error ** 2)) / gtruth_xyz.shape[0]
|
||||
return rmse
|
||||
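(For reference, the scale optimised in compute_ate has a closed form: minimising ||s * pred_xyz - gtruth_xyz||^2 over s gives s = sum(gtruth_xyz * pred_xyz) / sum(pred_xyz ** 2), which is exactly the expression computed above.)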
|
||||
|
||||
def evaluate(opt):
|
||||
"""Evaluate odometry on the KITTI dataset
|
||||
"""
|
||||
assert os.path.isdir(opt.load_weights_folder), \
|
||||
"Cannot find a folder at {}".format(opt.load_weights_folder)
|
||||
|
||||
assert opt.eval_split == "odom_9" or opt.eval_split == "odom_10", \
|
||||
"eval_split should be either odom_9 or odom_10"
|
||||
|
||||
sequence_id = int(opt.eval_split.split("_")[1])
|
||||
|
||||
filenames = readlines(
|
||||
os.path.join(os.path.dirname(__file__), "splits", "odom",
|
||||
"test_files_{:02d}.txt".format(sequence_id)))
|
||||
|
||||
dataset = KITTIOdomDataset(opt.data_path, filenames, opt.height, opt.width,
|
||||
[0, 1], 4, is_train=False)
|
||||
dataloader = DataLoader(dataset, opt.batch_size, shuffle=False,
|
||||
num_workers=opt.num_workers, pin_memory=True, drop_last=False)
|
||||
|
||||
pose_encoder_path = os.path.join(opt.load_weights_folder, "pose_encoder.pth")
|
||||
pose_decoder_path = os.path.join(opt.load_weights_folder, "pose.pth")
|
||||
|
||||
pose_encoder = networks.ResnetEncoder(opt.num_layers, False, 2)
|
||||
pose_encoder.load_state_dict(torch.load(pose_encoder_path))
|
||||
|
||||
pose_decoder = networks.PoseDecoder(pose_encoder.num_ch_enc, 1, 2)
|
||||
pose_decoder.load_state_dict(torch.load(pose_decoder_path))
|
||||
|
||||
pose_encoder.cuda()
|
||||
pose_encoder.eval()
|
||||
pose_decoder.cuda()
|
||||
pose_decoder.eval()
|
||||
|
||||
pred_poses = []
|
||||
|
||||
print("-> Computing pose predictions")
|
||||
|
||||
opt.frame_ids = [0, 1] # pose network only takes two frames as input
|
||||
|
||||
with torch.no_grad():
|
||||
for inputs in dataloader:
|
||||
for key, ipt in inputs.items():
|
||||
inputs[key] = ipt.cuda()
|
||||
|
||||
all_color_aug = torch.cat([inputs[("color_aug", i, 0)] for i in opt.frame_ids], 1)
|
||||
|
||||
features = [pose_encoder(all_color_aug)]
|
||||
axisangle, translation = pose_decoder(features)
|
||||
|
||||
pred_poses.append(
|
||||
transformation_from_parameters(axisangle[:, 0], translation[:, 0]).cpu().numpy())
|
||||
|
||||
pred_poses = np.concatenate(pred_poses)
|
||||
|
||||
gt_poses_path = os.path.join(opt.data_path, "poses", "{:02d}.txt".format(sequence_id))
|
||||
gt_global_poses = np.loadtxt(gt_poses_path).reshape(-1, 3, 4)
|
||||
gt_global_poses = np.concatenate(
|
||||
(gt_global_poses, np.zeros((gt_global_poses.shape[0], 1, 4))), 1)
|
||||
gt_global_poses[:, 3, 3] = 1
|
||||
gt_xyzs = gt_global_poses[:, :3, 3]
|
||||
|
||||
gt_local_poses = []
|
||||
for i in range(1, len(gt_global_poses)):
|
||||
gt_local_poses.append(
|
||||
np.linalg.inv(np.dot(np.linalg.inv(gt_global_poses[i - 1]), gt_global_poses[i])))
|
||||
|
||||
ates = []
|
||||
num_frames = gt_xyzs.shape[0]
|
||||
track_length = 5
|
||||
for i in range(0, num_frames - 1):
|
||||
local_xyzs = np.array(dump_xyz(pred_poses[i:i + track_length - 1]))
|
||||
gt_local_xyzs = np.array(dump_xyz(gt_local_poses[i:i + track_length - 1]))
|
||||
|
||||
ates.append(compute_ate(gt_local_xyzs, local_xyzs))
|
||||
|
||||
print("\n Trajectory error: {:0.3f}, std: {:0.3f}\n".format(np.mean(ates), np.std(ates)))
|
||||
|
||||
save_path = os.path.join(opt.load_weights_folder, "poses.npy")
|
||||
np.save(save_path, pred_poses)
|
||||
print("-> Predictions saved to", save_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
options = MonodepthOptions()
|
||||
evaluate(options.parse())
|
@ -1,65 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import PIL.Image as pil
|
||||
|
||||
from utils import readlines
|
||||
from kitti_utils import generate_depth_map
|
||||
|
||||
|
||||
def export_gt_depths_kitti():
|
||||
|
||||
parser = argparse.ArgumentParser(description='export_gt_depth')
|
||||
|
||||
parser.add_argument('--data_path',
|
||||
type=str,
|
||||
help='path to the root of the KITTI data',
|
||||
required=True)
|
||||
parser.add_argument('--split',
|
||||
type=str,
|
||||
help='which split to export gt from',
|
||||
required=True,
|
||||
choices=["eigen", "eigen_benchmark"])
|
||||
opt = parser.parse_args()
|
||||
|
||||
split_folder = os.path.join(os.path.dirname(__file__), "splits", opt.split)
|
||||
lines = readlines(os.path.join(split_folder, "test_files.txt"))
|
||||
|
||||
print("Exporting ground truth depths for {}".format(opt.split))
|
||||
|
||||
gt_depths = []
|
||||
for line in lines:
|
||||
|
||||
folder, frame_id, _ = line.split()
|
||||
frame_id = int(frame_id)
|
||||
|
||||
if opt.split == "eigen":
|
||||
calib_dir = os.path.join(opt.data_path, folder.split("/")[0])
|
||||
velo_filename = os.path.join(opt.data_path, folder,
|
||||
"velodyne_points/data", "{:010d}.bin".format(frame_id))
|
||||
gt_depth = generate_depth_map(calib_dir, velo_filename, 2, True)
|
||||
elif opt.split == "eigen_benchmark":
|
||||
gt_depth_path = os.path.join(opt.data_path, folder, "proj_depth",
|
||||
"groundtruth", "image_02", "{:010d}.png".format(frame_id))
|
||||
gt_depth = np.array(pil.open(gt_depth_path)).astype(np.float32) / 256
|
||||
|
||||
gt_depths.append(gt_depth.astype(np.float32))
|
||||
|
||||
output_path = os.path.join(split_folder, "gt_depths.npz")
|
||||
|
||||
print("Saving to {}".format(opt.split))
|
||||
|
||||
np.savez_compressed(output_path, data=np.array(gt_depths))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
export_gt_depths_kitti()
|
@ -1,269 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def disp_to_depth(disp, min_depth, max_depth):
|
||||
"""Convert network's sigmoid output into depth prediction
|
||||
The formula for this conversion is given in the 'additional considerations'
|
||||
section of the paper.
|
||||
"""
|
||||
min_disp = 1 / max_depth
|
||||
max_disp = 1 / min_depth
|
||||
scaled_disp = min_disp + (max_disp - min_disp) * disp
|
||||
depth = 1 / scaled_disp
|
||||
return scaled_disp, depth
|
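A quick numerical check of the mapping above (hypothetical values, not from this commit): with min_depth=0.1 and max_depth=100, min_disp=0.01 and max_disp=10, so a sigmoid output of 0 corresponds to 100 m and an output of 1 to 0.1 m.

scaled_disp, depth = disp_to_depth(torch.tensor([0.0, 0.5, 1.0]), min_depth=0.1, max_depth=100)
# depth is approximately tensor([100.0000, 0.1998, 0.1000])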
||||
|
||||
|
||||
def transformation_from_parameters(axisangle, translation, invert=False):
|
||||
"""Convert the network's (axisangle, translation) output into a 4x4 matrix
|
||||
"""
|
||||
R = rot_from_axisangle(axisangle)
|
||||
t = translation.clone()
|
||||
|
||||
if invert:
|
||||
R = R.transpose(1, 2)
|
||||
t *= -1
|
||||
|
||||
T = get_translation_matrix(t)
|
||||
|
||||
if invert:
|
||||
M = torch.matmul(R, T)
|
||||
else:
|
||||
M = torch.matmul(T, R)
|
||||
|
||||
return M
|
||||
|
||||
|
||||
def get_translation_matrix(translation_vector):
|
||||
"""Convert a translation vector into a 4x4 transformation matrix
|
||||
"""
|
||||
T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
|
||||
|
||||
t = translation_vector.contiguous().view(-1, 3, 1)
|
||||
|
||||
T[:, 0, 0] = 1
|
||||
T[:, 1, 1] = 1
|
||||
T[:, 2, 2] = 1
|
||||
T[:, 3, 3] = 1
|
||||
T[:, :3, 3, None] = t
|
||||
|
||||
return T
|
||||
|
||||
|
||||
def rot_from_axisangle(vec):
|
||||
"""Convert an axisangle rotation into a 4x4 transformation matrix
|
||||
(adapted from https://github.com/Wallacoloo/printipi)
|
||||
Input 'vec' has to be Bx1x3
|
||||
"""
|
||||
angle = torch.norm(vec, 2, 2, True)
|
||||
axis = vec / (angle + 1e-7)
|
||||
|
||||
ca = torch.cos(angle)
|
||||
sa = torch.sin(angle)
|
||||
C = 1 - ca
|
||||
|
||||
x = axis[..., 0].unsqueeze(1)
|
||||
y = axis[..., 1].unsqueeze(1)
|
||||
z = axis[..., 2].unsqueeze(1)
|
||||
|
||||
xs = x * sa
|
||||
ys = y * sa
|
||||
zs = z * sa
|
||||
xC = x * C
|
||||
yC = y * C
|
||||
zC = z * C
|
||||
xyC = x * yC
|
||||
yzC = y * zC
|
||||
zxC = z * xC
|
||||
|
||||
rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
|
||||
|
||||
rot[:, 0, 0] = torch.squeeze(x * xC + ca)
|
||||
rot[:, 0, 1] = torch.squeeze(xyC - zs)
|
||||
rot[:, 0, 2] = torch.squeeze(zxC + ys)
|
||||
rot[:, 1, 0] = torch.squeeze(xyC + zs)
|
||||
rot[:, 1, 1] = torch.squeeze(y * yC + ca)
|
||||
rot[:, 1, 2] = torch.squeeze(yzC - xs)
|
||||
rot[:, 2, 0] = torch.squeeze(zxC - ys)
|
||||
rot[:, 2, 1] = torch.squeeze(yzC + xs)
|
||||
rot[:, 2, 2] = torch.squeeze(z * zC + ca)
|
||||
rot[:, 3, 3] = 1
|
||||
|
||||
return rot
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Layer to perform a convolution followed by ELU
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(ConvBlock, self).__init__()
|
||||
|
||||
self.conv = Conv3x3(in_channels, out_channels)
|
||||
self.nonlin = nn.ELU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv(x)
|
||||
out = self.nonlin(out)
|
||||
return out
|
||||
|
||||
|
||||
class Conv3x3(nn.Module):
|
||||
"""Layer to pad and convolve input
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, use_refl=True):
|
||||
super(Conv3x3, self).__init__()
|
||||
|
||||
if use_refl:
|
||||
self.pad = nn.ReflectionPad2d(1)
|
||||
else:
|
||||
self.pad = nn.ZeroPad2d(1)
|
||||
self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.pad(x)
|
||||
out = self.conv(out)
|
||||
return out
|
||||
|
||||
|
||||
class BackprojectDepth(nn.Module):
|
||||
"""Layer to transform a depth image into a point cloud
|
||||
"""
|
||||
def __init__(self, batch_size, height, width):
|
||||
super(BackprojectDepth, self).__init__()
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.height = height
|
||||
self.width = width
|
||||
|
||||
meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
|
||||
self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
|
||||
self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
|
||||
requires_grad=False)
|
||||
|
||||
self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
|
||||
requires_grad=False)
|
||||
|
||||
self.pix_coords = torch.unsqueeze(torch.stack(
|
||||
[self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
|
||||
self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
|
||||
self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
|
||||
requires_grad=False)
|
||||
|
||||
def forward(self, depth, inv_K):
|
||||
cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
|
||||
cam_points = depth.view(self.batch_size, 1, -1) * cam_points
|
||||
cam_points = torch.cat([cam_points, self.ones], 1)
|
||||
|
||||
return cam_points
|
||||
|
||||
|
||||
class Project3D(nn.Module):
|
||||
"""Layer which projects 3D points into a camera with intrinsics K and at position T
|
||||
"""
|
||||
def __init__(self, batch_size, height, width, eps=1e-7):
|
||||
super(Project3D, self).__init__()
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, points, K, T):
|
||||
P = torch.matmul(K, T)[:, :3, :]
|
||||
|
||||
cam_points = torch.matmul(P, points)
|
||||
|
||||
pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
|
||||
pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
|
||||
pix_coords = pix_coords.permute(0, 2, 3, 1)
|
||||
pix_coords[..., 0] /= self.width - 1
|
||||
pix_coords[..., 1] /= self.height - 1
|
||||
pix_coords = (pix_coords - 0.5) * 2
|
||||
return pix_coords
|
||||
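A minimal sketch of how the two layers above are meant to be chained (hypothetical shapes, identity intrinsics and pose; not taken from the training code):

B, H, W = 1, 192, 640
backproject = BackprojectDepth(B, H, W)
project = Project3D(B, H, W)

depth = torch.ones(B, 1, H, W)           # dummy depth map
K = torch.eye(4).unsqueeze(0)            # 4x4 camera intrinsics (batched); identity here only for illustration
inv_K = torch.inverse(K)
T = torch.eye(4).unsqueeze(0)            # relative camera pose

cam_points = backproject(depth, inv_K)   # (B, 4, H*W) homogeneous 3D points
pix_coords = project(cam_points, K, T)   # (B, H, W, 2) sampling grid in [-1, 1], usable with F.grid_sample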
|
||||
|
||||
def upsample(x):
|
||||
"""Upsample input tensor by a factor of 2
|
||||
"""
|
||||
return F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
|
||||
|
||||
def get_smooth_loss(disp, img):
|
||||
"""Computes the smoothness loss for a disparity image
|
||||
The color image is used for edge-aware smoothness
|
||||
"""
|
||||
grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
|
||||
grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
|
||||
|
||||
grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
|
||||
grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
|
||||
|
||||
grad_disp_x *= torch.exp(-grad_img_x)
|
||||
grad_disp_y *= torch.exp(-grad_img_y)
|
||||
|
||||
return grad_disp_x.mean() + grad_disp_y.mean()
|
||||
|
||||
|
||||
class SSIM(nn.Module):
|
||||
"""Layer to compute the SSIM loss between a pair of images
|
||||
"""
|
||||
def __init__(self):
|
||||
super(SSIM, self).__init__()
|
||||
self.mu_x_pool = nn.AvgPool2d(3, 1)
|
||||
self.mu_y_pool = nn.AvgPool2d(3, 1)
|
||||
self.sig_x_pool = nn.AvgPool2d(3, 1)
|
||||
self.sig_y_pool = nn.AvgPool2d(3, 1)
|
||||
self.sig_xy_pool = nn.AvgPool2d(3, 1)
|
||||
|
||||
self.refl = nn.ReflectionPad2d(1)
|
||||
|
||||
self.C1 = 0.01 ** 2
|
||||
self.C2 = 0.03 ** 2
|
||||
|
||||
def forward(self, x, y):
|
||||
x = self.refl(x)
|
||||
y = self.refl(y)
|
||||
|
||||
mu_x = self.mu_x_pool(x)
|
||||
mu_y = self.mu_y_pool(y)
|
||||
|
||||
sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
|
||||
sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
|
||||
sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
|
||||
|
||||
SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
|
||||
SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
|
||||
|
||||
return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
|
||||
|
||||
|
||||
def compute_depth_errors(gt, pred):
|
||||
"""Computation of error metrics between predicted and ground truth depths
|
||||
"""
|
||||
thresh = torch.max((gt / pred), (pred / gt))
|
||||
a1 = (thresh < 1.25 ).float().mean()
|
||||
a2 = (thresh < 1.25 ** 2).float().mean()
|
||||
a3 = (thresh < 1.25 ** 3).float().mean()
|
||||
|
||||
rmse = (gt - pred) ** 2
|
||||
rmse = torch.sqrt(rmse.mean())
|
||||
|
||||
rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
|
||||
rmse_log = torch.sqrt(rmse_log.mean())
|
||||
|
||||
abs_rel = torch.mean(torch.abs(gt - pred) / gt)
|
||||
|
||||
sq_rel = torch.mean((gt - pred) ** 2 / gt)
|
||||
|
||||
return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
|
Binary file not shown.
@ -1,4 +0,0 @@
|
||||
from .resnet_encoder import ResnetEncoder
|
||||
from .depth_decoder import DepthDecoder
|
||||
from .pose_decoder import PoseDecoder
|
||||
from .pose_cnn import PoseCNN
|
Binary file not shown.
@ -1,65 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from collections import OrderedDict
|
||||
from layers import *
|
||||
|
||||
|
||||
class DepthDecoder(nn.Module):
|
||||
def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True):
|
||||
super(DepthDecoder, self).__init__()
|
||||
|
||||
self.num_output_channels = num_output_channels
|
||||
self.use_skips = use_skips
|
||||
self.upsample_mode = 'nearest'
|
||||
self.scales = scales
|
||||
|
||||
self.num_ch_enc = num_ch_enc
|
||||
self.num_ch_dec = np.array([16, 32, 64, 128, 256])
|
||||
|
||||
# decoder
|
||||
self.convs = OrderedDict()
|
||||
for i in range(4, -1, -1):
|
||||
# upconv_0
|
||||
num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
|
||||
num_ch_out = self.num_ch_dec[i]
|
||||
self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
|
||||
|
||||
# upconv_1
|
||||
num_ch_in = self.num_ch_dec[i]
|
||||
if self.use_skips and i > 0:
|
||||
num_ch_in += self.num_ch_enc[i - 1]
|
||||
num_ch_out = self.num_ch_dec[i]
|
||||
self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
|
||||
|
||||
for s in self.scales:
|
||||
self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
|
||||
|
||||
self.decoder = nn.ModuleList(list(self.convs.values()))
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, input_features):
|
||||
self.outputs = {}
|
||||
|
||||
# decoder
|
||||
x = input_features[-1]
|
||||
for i in range(4, -1, -1):
|
||||
x = self.convs[("upconv", i, 0)](x)
|
||||
x = [upsample(x)]
|
||||
if self.use_skips and i > 0:
|
||||
x += [input_features[i - 1]]
|
||||
x = torch.cat(x, 1)
|
||||
x = self.convs[("upconv", i, 1)](x)
|
||||
if i in self.scales:
|
||||
self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))
|
||||
|
||||
return self.outputs
|
Binary file not shown.
@ -1,50 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class PoseCNN(nn.Module):
|
||||
def __init__(self, num_input_frames):
|
||||
super(PoseCNN, self).__init__()
|
||||
|
||||
self.num_input_frames = num_input_frames
|
||||
|
||||
self.convs = {}
|
||||
self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3)
|
||||
self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2)
|
||||
self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1)
|
||||
self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1)
|
||||
self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1)
|
||||
self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1)
|
||||
self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1)
|
||||
|
||||
self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1)
|
||||
|
||||
self.num_convs = len(self.convs)
|
||||
|
||||
self.relu = nn.ReLU(True)
|
||||
|
||||
self.net = nn.ModuleList(list(self.convs.values()))
|
||||
|
||||
def forward(self, out):
|
||||
|
||||
for i in range(self.num_convs):
|
||||
out = self.convs[i](out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.pose_conv(out)
|
||||
out = out.mean(3).mean(2)
|
||||
|
||||
out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6)
|
||||
|
||||
axisangle = out[..., :3]
|
||||
translation = out[..., 3:]
|
||||
|
||||
return axisangle, translation
|
Binary file not shown.
@ -1,54 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class PoseDecoder(nn.Module):
|
||||
def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
|
||||
super(PoseDecoder, self).__init__()
|
||||
|
||||
self.num_ch_enc = num_ch_enc
|
||||
self.num_input_features = num_input_features
|
||||
|
||||
if num_frames_to_predict_for is None:
|
||||
num_frames_to_predict_for = num_input_features - 1
|
||||
self.num_frames_to_predict_for = num_frames_to_predict_for
|
||||
|
||||
self.convs = OrderedDict()
|
||||
self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
|
||||
self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
|
||||
self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
|
||||
self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
self.net = nn.ModuleList(list(self.convs.values()))
|
||||
|
||||
def forward(self, input_features):
|
||||
last_features = [f[-1] for f in input_features]
|
||||
|
||||
cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
|
||||
cat_features = torch.cat(cat_features, 1)
|
||||
|
||||
out = cat_features
|
||||
for i in range(3):
|
||||
out = self.convs[("pose", i)](out)
|
||||
if i != 2:
|
||||
out = self.relu(out)
|
||||
|
||||
out = out.mean(3).mean(2)
|
||||
|
||||
out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)
|
||||
|
||||
axisangle = out[..., :3]
|
||||
translation = out[..., 3:]
|
||||
|
||||
return axisangle, translation
|
Binary file not shown.
@ -1,98 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
|
||||
|
||||
class ResNetMultiImageInput(models.ResNet):
|
||||
"""Constructs a resnet model with varying number of input images.
|
||||
Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
|
||||
"""
|
||||
def __init__(self, block, layers, num_classes=1000, num_input_images=1):
|
||||
super(ResNetMultiImageInput, self).__init__(block, layers)
|
||||
self.inplanes = 64
|
||||
self.conv1 = nn.Conv2d(
|
||||
num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
|
||||
def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
|
||||
"""Constructs a ResNet model.
|
||||
Args:
|
||||
num_layers (int): Number of resnet layers. Must be 18 or 50
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
num_input_images (int): Number of frames stacked as input
|
||||
"""
|
||||
assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
|
||||
blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
|
||||
block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
|
||||
model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
|
||||
|
||||
if pretrained:
|
||||
loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
|
||||
loaded['conv1.weight'] = torch.cat(
|
||||
[loaded['conv1.weight']] * num_input_images, 1) / num_input_images
|
||||
model.load_state_dict(loaded)
|
||||
return model
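A hedged usage sketch follows (input sizes are assumed, and pretrained=False avoids any download): with num_input_images=2 the first convolution accepts six channels, and when pretrained weights are requested the ImageNet conv1 kernel is tiled across frames and divided by the frame count, exactly as the branch above does.

import torch
enc = resnet_multiimage_input(num_layers=18, pretrained=False, num_input_images=2)
x = torch.randn(2, 6, 192, 640)      # two RGB frames concatenated along the channel axis
out = enc.relu(enc.bn1(enc.conv1(x)))
print(out.shape)                     # torch.Size([2, 64, 96, 320])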
|
||||
|
||||
|
||||
class ResnetEncoder(nn.Module):
|
||||
"""Pytorch module for a resnet encoder
|
||||
"""
|
||||
def __init__(self, num_layers, pretrained, num_input_images=1):
|
||||
super(ResnetEncoder, self).__init__()
|
||||
|
||||
self.num_ch_enc = np.array([64, 64, 128, 256, 512])
|
||||
|
||||
resnets = {18: models.resnet18,
|
||||
34: models.resnet34,
|
||||
50: models.resnet50,
|
||||
101: models.resnet101,
|
||||
152: models.resnet152}
|
||||
|
||||
if num_layers not in resnets:
|
||||
raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
|
||||
|
||||
if num_input_images > 1:
|
||||
self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
|
||||
else:
|
||||
self.encoder = resnets[num_layers](pretrained)
|
||||
|
||||
if num_layers > 34:
|
||||
self.num_ch_enc[1:] *= 4
|
||||
|
||||
def forward(self, input_image):
|
||||
self.features = []
|
||||
x = (input_image - 0.45) / 0.225
|
||||
x = self.encoder.conv1(x)
|
||||
x = self.encoder.bn1(x)
|
||||
self.features.append(self.encoder.relu(x))
|
||||
self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
|
||||
self.features.append(self.encoder.layer2(self.features[-1]))
|
||||
self.features.append(self.encoder.layer3(self.features[-1]))
|
||||
self.features.append(self.encoder.layer4(self.features[-1]))
|
||||
|
||||
return self.features
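A hedged usage sketch (a 640x192 input is assumed, matching the option defaults later in this commit; pretrained=False so nothing is downloaded): the encoder returns five feature maps at 1/2 to 1/32 resolution whose channel counts follow num_ch_enc.

import torch
enc = ResnetEncoder(num_layers=18, pretrained=False)
img = torch.rand(1, 3, 192, 640)     # RGB in [0, 1], as ToTensor() would produce
with torch.no_grad():
    feats = enc(img)
print([tuple(f.shape) for f in feats])
# expected: [(1, 64, 96, 320), (1, 64, 48, 160), (1, 128, 24, 80),
#            (1, 256, 12, 40), (1, 512, 6, 20)]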
|
Binary file not shown.
@ -1,208 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import argparse
|
||||
|
||||
file_dir = os.path.dirname(__file__) # the directory that options.py resides in
|
||||
|
||||
|
||||
class MonodepthOptions:
|
||||
def __init__(self):
|
||||
self.parser = argparse.ArgumentParser(description="Monodepthv2 options")
|
||||
|
||||
# PATHS
|
||||
self.parser.add_argument("--data_path",
|
||||
type=str,
|
||||
help="path to the training data",
|
||||
default=os.path.join(file_dir, "kitti_data"))
|
||||
self.parser.add_argument("--log_dir",
|
||||
type=str,
|
||||
help="log directory",
|
||||
default=os.path.join(os.path.expanduser("~"), "tmp"))
|
||||
|
||||
# TRAINING options
|
||||
self.parser.add_argument("--model_name",
|
||||
type=str,
|
||||
help="the name of the folder to save the model in",
|
||||
default="mdp")
|
||||
self.parser.add_argument("--split",
|
||||
type=str,
|
||||
help="which training split to use",
|
||||
choices=["eigen_zhou", "eigen_full", "odom", "benchmark"],
|
||||
default="eigen_zhou")
|
||||
self.parser.add_argument("--num_layers",
|
||||
type=int,
|
||||
help="number of resnet layers",
|
||||
default=18,
|
||||
choices=[18, 34, 50, 101, 152])
|
||||
self.parser.add_argument("--dataset",
|
||||
type=str,
|
||||
help="dataset to train on",
|
||||
default="kitti",
|
||||
choices=["kitti", "kitti_odom", "kitti_depth", "kitti_test"])
|
||||
self.parser.add_argument("--png",
|
||||
help="if set, trains from raw KITTI png files (instead of jpgs)",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--height",
|
||||
type=int,
|
||||
help="input image height",
|
||||
default=192)
|
||||
self.parser.add_argument("--width",
|
||||
type=int,
|
||||
help="input image width",
|
||||
default=640)
|
||||
self.parser.add_argument("--disparity_smoothness",
|
||||
type=float,
|
||||
help="disparity smoothness weight",
|
||||
default=1e-3)
|
||||
self.parser.add_argument("--scales",
|
||||
nargs="+",
|
||||
type=int,
|
||||
help="scales used in the loss",
|
||||
default=[0, 1, 2, 3])
|
||||
self.parser.add_argument("--min_depth",
|
||||
type=float,
|
||||
help="minimum depth",
|
||||
default=0.1)
|
||||
self.parser.add_argument("--max_depth",
|
||||
type=float,
|
||||
help="maximum depth",
|
||||
default=100.0)
|
||||
self.parser.add_argument("--use_stereo",
|
||||
help="if set, uses stereo pair for training",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--frame_ids",
|
||||
nargs="+",
|
||||
type=int,
|
||||
help="frames to load",
|
||||
default=[0, -1, 1])
|
||||
|
||||
# OPTIMIZATION options
|
||||
self.parser.add_argument("--batch_size",
|
||||
type=int,
|
||||
help="batch size",
|
||||
default=12)
|
||||
self.parser.add_argument("--learning_rate",
|
||||
type=float,
|
||||
help="learning rate",
|
||||
default=1e-4)
|
||||
self.parser.add_argument("--num_epochs",
|
||||
type=int,
|
||||
help="number of epochs",
|
||||
default=20)
|
||||
self.parser.add_argument("--scheduler_step_size",
|
||||
type=int,
|
||||
help="step size of the scheduler",
|
||||
default=15)
|
||||
|
||||
# ABLATION options
|
||||
self.parser.add_argument("--v1_multiscale",
|
||||
help="if set, uses monodepth v1 multiscale",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--avg_reprojection",
|
||||
help="if set, uses average reprojection loss",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--disable_automasking",
|
||||
help="if set, doesn't do auto-masking",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--predictive_mask",
|
||||
help="if set, uses a predictive masking scheme as in Zhou et al",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--no_ssim",
|
||||
help="if set, disables ssim in the loss",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--weights_init",
|
||||
type=str,
|
||||
help="pretrained or scratch",
|
||||
default="pretrained",
|
||||
choices=["pretrained", "scratch"])
|
||||
self.parser.add_argument("--pose_model_input",
|
||||
type=str,
|
||||
help="how many images the pose network gets",
|
||||
default="pairs",
|
||||
choices=["pairs", "all"])
|
||||
self.parser.add_argument("--pose_model_type",
|
||||
type=str,
|
||||
help="normal or shared",
|
||||
default="separate_resnet",
|
||||
choices=["posecnn", "separate_resnet", "shared"])
|
||||
|
||||
# SYSTEM options
|
||||
self.parser.add_argument("--no_cuda",
|
||||
help="if set disables CUDA",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--num_workers",
|
||||
type=int,
|
||||
help="number of dataloader workers",
|
||||
default=12)
|
||||
|
||||
# LOADING options
|
||||
self.parser.add_argument("--load_weights_folder",
|
||||
type=str,
|
||||
help="name of model to load")
|
||||
self.parser.add_argument("--models_to_load",
|
||||
nargs="+",
|
||||
type=str,
|
||||
help="models to load",
|
||||
default=["encoder", "depth", "pose_encoder", "pose"])
|
||||
|
||||
# LOGGING options
|
||||
self.parser.add_argument("--log_frequency",
|
||||
type=int,
|
||||
help="number of batches between each tensorboard log",
|
||||
default=250)
|
||||
self.parser.add_argument("--save_frequency",
|
||||
type=int,
|
||||
help="number of epochs between each save",
|
||||
default=1)
|
||||
|
||||
# EVALUATION options
|
||||
self.parser.add_argument("--eval_stereo",
|
||||
help="if set evaluates in stereo mode",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--eval_mono",
|
||||
help="if set evaluates in mono mode",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--disable_median_scaling",
|
||||
help="if set disables median scaling in evaluation",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--pred_depth_scale_factor",
|
||||
help="if set multiplies predictions by this number",
|
||||
type=float,
|
||||
default=1)
|
||||
self.parser.add_argument("--ext_disp_to_eval",
|
||||
type=str,
|
||||
help="optional path to a .npy disparities file to evaluate")
|
||||
self.parser.add_argument("--eval_split",
|
||||
type=str,
|
||||
default="eigen",
|
||||
choices=[
|
||||
"eigen", "eigen_benchmark", "benchmark", "odom_9", "odom_10"],
|
||||
help="which split to run eval on")
|
||||
self.parser.add_argument("--save_pred_disps",
|
||||
help="if set saves predicted disparities",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--no_eval",
|
||||
help="if set disables evaluation",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--eval_eigen_to_benchmark",
|
||||
help="if set assume we are loading eigen results from npy but "
|
||||
"we want to evaluate using the new benchmark.",
|
||||
action="store_true")
|
||||
self.parser.add_argument("--eval_out_dir",
|
||||
help="if set will output the disparities to this folder",
|
||||
type=str)
|
||||
self.parser.add_argument("--post_process",
|
||||
help="if set will perform the flipping post processing "
|
||||
"from the original monodepth paper",
|
||||
action="store_true")
|
||||
|
||||
def parse(self):
|
||||
self.options = self.parser.parse_args()
|
||||
return self.options
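A hedged usage sketch: parsed with no extra command-line flags, the options object simply carries the defaults listed above.

opts = MonodepthOptions().parse()
print(opts.height, opts.width, opts.scales, opts.frame_ids)   # 192 640 [0, 1, 2, 3] [0, -1, 1]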
|
@ -1,160 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
import argparse
|
||||
import numpy as np
|
||||
import PIL.Image as pil
|
||||
# import cv2
|
||||
import matplotlib as mpl
|
||||
import matplotlib.cm as cm
|
||||
|
||||
import torch
|
||||
from torchvision import transforms, datasets
|
||||
|
||||
import networks
|
||||
from layers import disp_to_depth
|
||||
from utils import download_model_if_doesnt_exist
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Simple testing function for Monodepthv2 models.')
|
||||
|
||||
parser.add_argument('--image_path', type=str,
|
||||
help='path to a test image or folder of images', required=True)
|
||||
parser.add_argument('--model_name', type=str,
|
||||
help='name of a pretrained model to use',
|
||||
choices=[
|
||||
"mono_640x192",
|
||||
"stereo_640x192",
|
||||
"mono+stereo_640x192",
|
||||
"mono_no_pt_640x192",
|
||||
"stereo_no_pt_640x192",
|
||||
"mono+stereo_no_pt_640x192",
|
||||
"mono_1024x320",
|
||||
"stereo_1024x320",
|
||||
"mono+stereo_1024x320"])
|
||||
parser.add_argument('--ext', type=str,
|
||||
help='image extension to search for in folder', default="jpg")
|
||||
parser.add_argument("--no_cuda",
|
||||
help='if set, disables CUDA',
|
||||
action='store_true')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_simple(args):
|
||||
"""Function to predict for a single image or folder of images
|
||||
"""
|
||||
assert args.model_name is not None, \
|
||||
"You must specify the --model_name parameter; see README.md for an example"
|
||||
|
||||
if torch.cuda.is_available() and not args.no_cuda:
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
download_model_if_doesnt_exist(args.model_name)
|
||||
model_path = os.path.join("models", args.model_name)
|
||||
print("-> Loading model from ", model_path)
|
||||
encoder_path = os.path.join(model_path, "encoder.pth")
|
||||
depth_decoder_path = os.path.join(model_path, "depth.pth")
|
||||
|
||||
# LOADING PRETRAINED MODEL
|
||||
print(" Loading pretrained encoder")
|
||||
encoder = networks.ResnetEncoder(18, False)
|
||||
loaded_dict_enc = torch.load(encoder_path, map_location=device)
|
||||
|
||||
# extract the height and width of the image that this model was trained with
|
||||
feed_height = loaded_dict_enc['height']
|
||||
feed_width = loaded_dict_enc['width']
|
||||
filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in encoder.state_dict()}
|
||||
encoder.load_state_dict(filtered_dict_enc)
|
||||
encoder.to(device)
|
||||
encoder.eval()
|
||||
|
||||
print(" Loading pretrained decoder")
|
||||
depth_decoder = networks.DepthDecoder(
|
||||
num_ch_enc=encoder.num_ch_enc, scales=range(4))
|
||||
|
||||
loaded_dict = torch.load(depth_decoder_path, map_location=device)
|
||||
depth_decoder.load_state_dict(loaded_dict)
|
||||
|
||||
depth_decoder.to(device)
|
||||
depth_decoder.eval()
|
||||
|
||||
# FINDING INPUT IMAGES
|
||||
if os.path.isfile(args.image_path):
|
||||
# Only testing on a single image
|
||||
paths = [args.image_path]
|
||||
output_directory = os.path.dirname(args.image_path)
|
||||
elif os.path.isdir(args.image_path):
|
||||
# Searching folder for images
|
||||
paths = glob.glob(os.path.join(args.image_path, '*.{}'.format(args.ext)))
|
||||
output_directory = args.image_path
|
||||
else:
|
||||
raise Exception("Can not find args.image_path: {}".format(args.image_path))
|
||||
|
||||
print("-> Predicting on {:d} test images".format(len(paths)))
|
||||
|
||||
# PREDICTING ON EACH IMAGE IN TURN
|
||||
with torch.no_grad():
|
||||
for idx, image_path in enumerate(paths):
|
||||
|
||||
if image_path.endswith("_disp.jpg"):
|
||||
# don't try to predict disparity for a disparity image!
|
||||
continue
|
||||
|
||||
# Load image and preprocess
|
||||
# input_image = cv2.imread(image_path)
|
||||
# input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
|
||||
input_image = pil.open(image_path).convert('RGB')
|
||||
original_width, original_height = input_image.size
|
||||
# input_image = cv2.resize(input_image, (feed_width, feed_height))
|
||||
input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS)
|
||||
input_image = transforms.ToTensor()(input_image).unsqueeze(0)
|
||||
|
||||
# PREDICTION
|
||||
input_image = input_image.to(device)
|
||||
features = encoder(input_image)
|
||||
outputs = depth_decoder(features)
|
||||
|
||||
disp = outputs[("disp", 0)]
|
||||
disp_resized = torch.nn.functional.interpolate(
|
||||
disp, (original_height, original_width), mode="bilinear", align_corners=False)
|
||||
|
||||
# Saving numpy file
|
||||
output_name = os.path.splitext(os.path.basename(image_path))[0]
|
||||
name_dest_npy = os.path.join(output_directory, "{}_disp.npy".format(output_name))
|
||||
scaled_disp, _ = disp_to_depth(disp, 0.1, 100)
|
||||
np.save(name_dest_npy, scaled_disp.cpu().numpy())
|
||||
|
||||
# Saving colormapped depth image
|
||||
disp_resized_np = disp_resized.squeeze().cpu().numpy()
|
||||
vmax = np.percentile(disp_resized_np, 95)
|
||||
normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax)
|
||||
mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
|
||||
colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8)
|
||||
im = pil.fromarray(colormapped_im)
|
||||
|
||||
name_dest_im = os.path.join(output_directory, "{}_disp.jpeg".format(output_name))
|
||||
im.save(name_dest_im)
|
||||
# cv2.imwrite('/Users/kritiksoman/Downloads/gimp-plugins/out5.jpg',cv2.cvtColor(colormapped_im, cv2.COLOR_RGB2BGR))
|
||||
|
||||
print(" Processed {:d} of {:d} images - saved prediction to {}".format(
|
||||
idx + 1, len(paths), name_dest_im))
|
||||
|
||||
print('-> Done!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
test_simple(args)
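The script saves both a colormapped disparity JPEG and a .npy of the scaled disparity obtained from disp_to_depth, which lives in layers.py and is not part of this diff. As a point of reference, the conversion follows the standard monodepth2 formulation sketched below (illustrative only): the sigmoid disparity is mapped into [1/max_depth, 1/min_depth] and inverted to obtain depth.

def disp_to_depth(disp, min_depth, max_depth):
    min_disp = 1.0 / max_depth            # 0.01 for max_depth = 100
    max_disp = 1.0 / min_depth            # 10.0 for min_depth = 0.1
    scaled_disp = min_disp + (max_disp - min_disp) * disp
    depth = 1.0 / scaled_disp
    return scaled_disp, depth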
|
@ -1,18 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
from trainer import Trainer
|
||||
from options import MonodepthOptions
|
||||
|
||||
options = MonodepthOptions()
|
||||
opts = options.parse()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trainer = Trainer(opts)
|
||||
trainer.train()
|
@ -1,630 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
import json
|
||||
|
||||
from utils import *
|
||||
from kitti_utils import *
|
||||
from layers import *
|
||||
|
||||
import datasets
|
||||
import networks
|
||||
from IPython import embed
|
||||
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, options):
|
||||
self.opt = options
|
||||
self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)
|
||||
|
||||
# checking height and width are multiples of 32
|
||||
assert self.opt.height % 32 == 0, "'height' must be a multiple of 32"
|
||||
assert self.opt.width % 32 == 0, "'width' must be a multiple of 32"
|
||||
|
||||
self.models = {}
|
||||
self.parameters_to_train = []
|
||||
|
||||
self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")
|
||||
|
||||
self.num_scales = len(self.opt.scales)
|
||||
self.num_input_frames = len(self.opt.frame_ids)
|
||||
self.num_pose_frames = 2 if self.opt.pose_model_input == "pairs" else self.num_input_frames
|
||||
|
||||
assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0"
|
||||
|
||||
self.use_pose_net = not (self.opt.use_stereo and self.opt.frame_ids == [0])
|
||||
|
||||
if self.opt.use_stereo:
|
||||
self.opt.frame_ids.append("s")
|
||||
|
||||
self.models["encoder"] = networks.ResnetEncoder(
|
||||
self.opt.num_layers, self.opt.weights_init == "pretrained")
|
||||
self.models["encoder"].to(self.device)
|
||||
self.parameters_to_train += list(self.models["encoder"].parameters())
|
||||
|
||||
self.models["depth"] = networks.DepthDecoder(
|
||||
self.models["encoder"].num_ch_enc, self.opt.scales)
|
||||
self.models["depth"].to(self.device)
|
||||
self.parameters_to_train += list(self.models["depth"].parameters())
|
||||
|
||||
if self.use_pose_net:
|
||||
if self.opt.pose_model_type == "separate_resnet":
|
||||
self.models["pose_encoder"] = networks.ResnetEncoder(
|
||||
self.opt.num_layers,
|
||||
self.opt.weights_init == "pretrained",
|
||||
num_input_images=self.num_pose_frames)
|
||||
|
||||
self.models["pose_encoder"].to(self.device)
|
||||
self.parameters_to_train += list(self.models["pose_encoder"].parameters())
|
||||
|
||||
self.models["pose"] = networks.PoseDecoder(
|
||||
self.models["pose_encoder"].num_ch_enc,
|
||||
num_input_features=1,
|
||||
num_frames_to_predict_for=2)
|
||||
|
||||
elif self.opt.pose_model_type == "shared":
|
||||
self.models["pose"] = networks.PoseDecoder(
|
||||
self.models["encoder"].num_ch_enc, self.num_pose_frames)
|
||||
|
||||
elif self.opt.pose_model_type == "posecnn":
|
||||
self.models["pose"] = networks.PoseCNN(
|
||||
self.num_input_frames if self.opt.pose_model_input == "all" else 2)
|
||||
|
||||
self.models["pose"].to(self.device)
|
||||
self.parameters_to_train += list(self.models["pose"].parameters())
|
||||
|
||||
if self.opt.predictive_mask:
|
||||
assert self.opt.disable_automasking, \
|
||||
"When using predictive_mask, please disable automasking with --disable_automasking"
|
||||
|
||||
# Our implementation of the predictive masking baseline has the same architecture
|
||||
# as our depth decoder. We predict a separate mask for each source frame.
|
||||
self.models["predictive_mask"] = networks.DepthDecoder(
|
||||
self.models["encoder"].num_ch_enc, self.opt.scales,
|
||||
num_output_channels=(len(self.opt.frame_ids) - 1))
|
||||
self.models["predictive_mask"].to(self.device)
|
||||
self.parameters_to_train += list(self.models["predictive_mask"].parameters())
|
||||
|
||||
self.model_optimizer = optim.Adam(self.parameters_to_train, self.opt.learning_rate)
|
||||
self.model_lr_scheduler = optim.lr_scheduler.StepLR(
|
||||
self.model_optimizer, self.opt.scheduler_step_size, 0.1)
|
||||
|
||||
if self.opt.load_weights_folder is not None:
|
||||
self.load_model()
|
||||
|
||||
print("Training model named:\n ", self.opt.model_name)
|
||||
print("Models and tensorboard events files are saved to:\n ", self.opt.log_dir)
|
||||
print("Training is using:\n ", self.device)
|
||||
|
||||
# data
|
||||
datasets_dict = {"kitti": datasets.KITTIRAWDataset,
|
||||
"kitti_odom": datasets.KITTIOdomDataset}
|
||||
self.dataset = datasets_dict[self.opt.dataset]
|
||||
|
||||
fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt")
|
||||
|
||||
train_filenames = readlines(fpath.format("train"))
|
||||
val_filenames = readlines(fpath.format("val"))
|
||||
img_ext = '.png' if self.opt.png else '.jpg'
|
||||
|
||||
num_train_samples = len(train_filenames)
|
||||
self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs
|
||||
|
||||
train_dataset = self.dataset(
|
||||
self.opt.data_path, train_filenames, self.opt.height, self.opt.width,
|
||||
self.opt.frame_ids, 4, is_train=True, img_ext=img_ext)
|
||||
self.train_loader = DataLoader(
|
||||
train_dataset, self.opt.batch_size, True,
|
||||
num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
|
||||
val_dataset = self.dataset(
|
||||
self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
|
||||
self.opt.frame_ids, 4, is_train=False, img_ext=img_ext)
|
||||
self.val_loader = DataLoader(
|
||||
val_dataset, self.opt.batch_size, True,
|
||||
num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
|
||||
self.val_iter = iter(self.val_loader)
|
||||
|
||||
self.writers = {}
|
||||
for mode in ["train", "val"]:
|
||||
self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))
|
||||
|
||||
if not self.opt.no_ssim:
|
||||
self.ssim = SSIM()
|
||||
self.ssim.to(self.device)
|
||||
|
||||
self.backproject_depth = {}
|
||||
self.project_3d = {}
|
||||
for scale in self.opt.scales:
|
||||
h = self.opt.height // (2 ** scale)
|
||||
w = self.opt.width // (2 ** scale)
|
||||
|
||||
self.backproject_depth[scale] = BackprojectDepth(self.opt.batch_size, h, w)
|
||||
self.backproject_depth[scale].to(self.device)
|
||||
|
||||
self.project_3d[scale] = Project3D(self.opt.batch_size, h, w)
|
||||
self.project_3d[scale].to(self.device)
|
||||
|
||||
self.depth_metric_names = [
|
||||
"de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1", "da/a2", "da/a3"]
|
||||
|
||||
print("Using split:\n ", self.opt.split)
|
||||
print("There are {:d} training items and {:d} validation items\n".format(
|
||||
len(train_dataset), len(val_dataset)))
|
||||
|
||||
self.save_opts()
|
||||
|
||||
def set_train(self):
|
||||
"""Convert all models to training mode
|
||||
"""
|
||||
for m in self.models.values():
|
||||
m.train()
|
||||
|
||||
def set_eval(self):
|
||||
"""Convert all models to testing/evaluation mode
|
||||
"""
|
||||
for m in self.models.values():
|
||||
m.eval()
|
||||
|
||||
def train(self):
|
||||
"""Run the entire training pipeline
|
||||
"""
|
||||
self.epoch = 0
|
||||
self.step = 0
|
||||
self.start_time = time.time()
|
||||
for self.epoch in range(self.opt.num_epochs):
|
||||
self.run_epoch()
|
||||
if (self.epoch + 1) % self.opt.save_frequency == 0:
|
||||
self.save_model()
|
||||
|
||||
def run_epoch(self):
|
||||
"""Run a single epoch of training and validation
|
||||
"""
|
||||
self.model_lr_scheduler.step()
|
||||
|
||||
print("Training")
|
||||
self.set_train()
|
||||
|
||||
for batch_idx, inputs in enumerate(self.train_loader):
|
||||
|
||||
before_op_time = time.time()
|
||||
|
||||
outputs, losses = self.process_batch(inputs)
|
||||
|
||||
self.model_optimizer.zero_grad()
|
||||
losses["loss"].backward()
|
||||
self.model_optimizer.step()
|
||||
|
||||
duration = time.time() - before_op_time
|
||||
|
||||
# log less frequently after the first 2000 steps to save time & disk space
|
||||
early_phase = batch_idx % self.opt.log_frequency == 0 and self.step < 2000
|
||||
late_phase = self.step % 2000 == 0
|
||||
|
||||
if early_phase or late_phase:
|
||||
self.log_time(batch_idx, duration, losses["loss"].cpu().data)
|
||||
|
||||
if "depth_gt" in inputs:
|
||||
self.compute_depth_losses(inputs, outputs, losses)
|
||||
|
||||
self.log("train", inputs, outputs, losses)
|
||||
self.val()
|
||||
|
||||
self.step += 1
|
||||
|
||||
def process_batch(self, inputs):
|
||||
"""Pass a minibatch through the network and generate images and losses
|
||||
"""
|
||||
for key, ipt in inputs.items():
|
||||
inputs[key] = ipt.to(self.device)
|
||||
|
||||
if self.opt.pose_model_type == "shared":
|
||||
# If we are using a shared encoder for both depth and pose (as advocated
|
||||
# in monodepthv1), then all images are fed separately through the depth encoder.
|
||||
all_color_aug = torch.cat([inputs[("color_aug", i, 0)] for i in self.opt.frame_ids])
|
||||
all_features = self.models["encoder"](all_color_aug)
|
||||
all_features = [torch.split(f, self.opt.batch_size) for f in all_features]
|
||||
|
||||
features = {}
|
||||
for i, k in enumerate(self.opt.frame_ids):
|
||||
features[k] = [f[i] for f in all_features]
|
||||
|
||||
outputs = self.models["depth"](features[0])
|
||||
else:
|
||||
# Otherwise, we only feed the image with frame_id 0 through the depth encoder
|
||||
features = self.models["encoder"](inputs["color_aug", 0, 0])
|
||||
outputs = self.models["depth"](features)
|
||||
|
||||
if self.opt.predictive_mask:
|
||||
outputs["predictive_mask"] = self.models["predictive_mask"](features)
|
||||
|
||||
if self.use_pose_net:
|
||||
outputs.update(self.predict_poses(inputs, features))
|
||||
|
||||
self.generate_images_pred(inputs, outputs)
|
||||
losses = self.compute_losses(inputs, outputs)
|
||||
|
||||
return outputs, losses
|
||||
|
||||
def predict_poses(self, inputs, features):
|
||||
"""Predict poses between input frames for monocular sequences.
|
||||
"""
|
||||
outputs = {}
|
||||
if self.num_pose_frames == 2:
|
||||
# In this setting, we compute the pose to each source frame via a
|
||||
# separate forward pass through the pose network.
|
||||
|
||||
# select what features the pose network takes as input
|
||||
if self.opt.pose_model_type == "shared":
|
||||
pose_feats = {f_i: features[f_i] for f_i in self.opt.frame_ids}
|
||||
else:
|
||||
pose_feats = {f_i: inputs["color_aug", f_i, 0] for f_i in self.opt.frame_ids}
|
||||
|
||||
for f_i in self.opt.frame_ids[1:]:
|
||||
if f_i != "s":
|
||||
# To maintain ordering we always pass frames in temporal order
|
||||
if f_i < 0:
|
||||
pose_inputs = [pose_feats[f_i], pose_feats[0]]
|
||||
else:
|
||||
pose_inputs = [pose_feats[0], pose_feats[f_i]]
|
||||
|
||||
if self.opt.pose_model_type == "separate_resnet":
|
||||
pose_inputs = [self.models["pose_encoder"](torch.cat(pose_inputs, 1))]
|
||||
elif self.opt.pose_model_type == "posecnn":
|
||||
pose_inputs = torch.cat(pose_inputs, 1)
|
||||
|
||||
axisangle, translation = self.models["pose"](pose_inputs)
|
||||
outputs[("axisangle", 0, f_i)] = axisangle
|
||||
outputs[("translation", 0, f_i)] = translation
|
||||
|
||||
# Invert the matrix if the frame id is negative
|
||||
outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters(
|
||||
axisangle[:, 0], translation[:, 0], invert=(f_i < 0))
|
||||
|
||||
else:
|
||||
# Here we input all frames to the pose net (and predict all poses) together
|
||||
if self.opt.pose_model_type in ["separate_resnet", "posecnn"]:
|
||||
pose_inputs = torch.cat(
|
||||
[inputs[("color_aug", i, 0)] for i in self.opt.frame_ids if i != "s"], 1)
|
||||
|
||||
if self.opt.pose_model_type == "separate_resnet":
|
||||
pose_inputs = [self.models["pose_encoder"](pose_inputs)]
|
||||
|
||||
elif self.opt.pose_model_type == "shared":
|
||||
pose_inputs = [features[i] for i in self.opt.frame_ids if i != "s"]
|
||||
|
||||
axisangle, translation = self.models["pose"](pose_inputs)
|
||||
|
||||
for i, f_i in enumerate(self.opt.frame_ids[1:]):
|
||||
if f_i != "s":
|
||||
outputs[("axisangle", 0, f_i)] = axisangle
|
||||
outputs[("translation", 0, f_i)] = translation
|
||||
outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters(
|
||||
axisangle[:, i], translation[:, i])
|
||||
|
||||
return outputs
|
||||
|
||||
def val(self):
|
||||
"""Validate the model on a single minibatch
|
||||
"""
|
||||
self.set_eval()
|
||||
try:
|
||||
inputs = self.val_iter.next()
|
||||
except StopIteration:
|
||||
self.val_iter = iter(self.val_loader)
|
||||
inputs = self.val_iter.next()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs, losses = self.process_batch(inputs)
|
||||
|
||||
if "depth_gt" in inputs:
|
||||
self.compute_depth_losses(inputs, outputs, losses)
|
||||
|
||||
self.log("val", inputs, outputs, losses)
|
||||
del inputs, outputs, losses
|
||||
|
||||
self.set_train()
|
||||
|
||||
def generate_images_pred(self, inputs, outputs):
|
||||
"""Generate the warped (reprojected) color images for a minibatch.
|
||||
Generated images are saved into the `outputs` dictionary.
|
||||
"""
|
||||
for scale in self.opt.scales:
|
||||
disp = outputs[("disp", scale)]
|
||||
if self.opt.v1_multiscale:
|
||||
source_scale = scale
|
||||
else:
|
||||
disp = F.interpolate(
|
||||
disp, [self.opt.height, self.opt.width], mode="bilinear", align_corners=False)
|
||||
source_scale = 0
|
||||
|
||||
_, depth = disp_to_depth(disp, self.opt.min_depth, self.opt.max_depth)
|
||||
|
||||
outputs[("depth", 0, scale)] = depth
|
||||
|
||||
for i, frame_id in enumerate(self.opt.frame_ids[1:]):
|
||||
|
||||
if frame_id == "s":
|
||||
T = inputs["stereo_T"]
|
||||
else:
|
||||
T = outputs[("cam_T_cam", 0, frame_id)]
|
||||
|
||||
# from the authors of https://arxiv.org/abs/1712.00175
|
||||
if self.opt.pose_model_type == "posecnn":
|
||||
|
||||
axisangle = outputs[("axisangle", 0, frame_id)]
|
||||
translation = outputs[("translation", 0, frame_id)]
|
||||
|
||||
inv_depth = 1 / depth
|
||||
mean_inv_depth = inv_depth.mean(3, True).mean(2, True)
|
||||
|
||||
T = transformation_from_parameters(
|
||||
axisangle[:, 0], translation[:, 0] * mean_inv_depth[:, 0], frame_id < 0)
|
||||
|
||||
cam_points = self.backproject_depth[source_scale](
|
||||
depth, inputs[("inv_K", source_scale)])
|
||||
pix_coords = self.project_3d[source_scale](
|
||||
cam_points, inputs[("K", source_scale)], T)
|
||||
|
||||
outputs[("sample", frame_id, scale)] = pix_coords
|
||||
|
||||
outputs[("color", frame_id, scale)] = F.grid_sample(
|
||||
inputs[("color", frame_id, source_scale)],
|
||||
outputs[("sample", frame_id, scale)],
|
||||
padding_mode="border")
|
||||
|
||||
if not self.opt.disable_automasking:
|
||||
outputs[("color_identity", frame_id, scale)] = \
|
||||
inputs[("color", frame_id, source_scale)]
|
||||
|
||||
def compute_reprojection_loss(self, pred, target):
|
||||
"""Computes reprojection loss between a batch of predicted and target images
|
||||
"""
|
||||
abs_diff = torch.abs(target - pred)
|
||||
l1_loss = abs_diff.mean(1, True)
|
||||
|
||||
if self.opt.no_ssim:
|
||||
reprojection_loss = l1_loss
|
||||
else:
|
||||
ssim_loss = self.ssim(pred, target).mean(1, True)
|
||||
reprojection_loss = 0.85 * ssim_loss + 0.15 * l1_loss
|
||||
|
||||
return reprojection_loss
|
||||
|
||||
def compute_losses(self, inputs, outputs):
|
||||
"""Compute the reprojection and smoothness losses for a minibatch
|
||||
"""
|
||||
losses = {}
|
||||
total_loss = 0
|
||||
|
||||
for scale in self.opt.scales:
|
||||
loss = 0
|
||||
reprojection_losses = []
|
||||
|
||||
if self.opt.v1_multiscale:
|
||||
source_scale = scale
|
||||
else:
|
||||
source_scale = 0
|
||||
|
||||
disp = outputs[("disp", scale)]
|
||||
color = inputs[("color", 0, scale)]
|
||||
target = inputs[("color", 0, source_scale)]
|
||||
|
||||
for frame_id in self.opt.frame_ids[1:]:
|
||||
pred = outputs[("color", frame_id, scale)]
|
||||
reprojection_losses.append(self.compute_reprojection_loss(pred, target))
|
||||
|
||||
reprojection_losses = torch.cat(reprojection_losses, 1)
|
||||
|
||||
if not self.opt.disable_automasking:
|
||||
identity_reprojection_losses = []
|
||||
for frame_id in self.opt.frame_ids[1:]:
|
||||
pred = inputs[("color", frame_id, source_scale)]
|
||||
identity_reprojection_losses.append(
|
||||
self.compute_reprojection_loss(pred, target))
|
||||
|
||||
identity_reprojection_losses = torch.cat(identity_reprojection_losses, 1)
|
||||
|
||||
if self.opt.avg_reprojection:
|
||||
identity_reprojection_loss = identity_reprojection_losses.mean(1, keepdim=True)
|
||||
else:
|
||||
# save both images, and do min all at once below
|
||||
identity_reprojection_loss = identity_reprojection_losses
|
||||
|
||||
elif self.opt.predictive_mask:
|
||||
# use the predicted mask
|
||||
mask = outputs["predictive_mask"]["disp", scale]
|
||||
if not self.opt.v1_multiscale:
|
||||
mask = F.interpolate(
|
||||
mask, [self.opt.height, self.opt.width],
|
||||
mode="bilinear", align_corners=False)
|
||||
|
||||
reprojection_losses *= mask
|
||||
|
||||
# add a loss pushing mask to 1 (using nn.BCELoss for stability)
|
||||
weighting_loss = 0.2 * nn.BCELoss()(mask, torch.ones(mask.shape).cuda())
|
||||
loss += weighting_loss.mean()
|
||||
|
||||
if self.opt.avg_reprojection:
|
||||
reprojection_loss = reprojection_losses.mean(1, keepdim=True)
|
||||
else:
|
||||
reprojection_loss = reprojection_losses
|
||||
|
||||
if not self.opt.disable_automasking:
|
||||
# add random numbers to break ties
|
||||
identity_reprojection_loss += torch.randn(
|
||||
identity_reprojection_loss.shape).cuda() * 0.00001
|
||||
|
||||
combined = torch.cat((identity_reprojection_loss, reprojection_loss), dim=1)
|
||||
else:
|
||||
combined = reprojection_loss
|
||||
|
||||
if combined.shape[1] == 1:
|
||||
to_optimise = combined
|
||||
else:
|
||||
to_optimise, idxs = torch.min(combined, dim=1)
|
||||
|
||||
if not self.opt.disable_automasking:
|
||||
outputs["identity_selection/{}".format(scale)] = (
|
||||
idxs > identity_reprojection_loss.shape[1] - 1).float()
|
||||
|
||||
loss += to_optimise.mean()
|
||||
|
||||
mean_disp = disp.mean(2, True).mean(3, True)
|
||||
norm_disp = disp / (mean_disp + 1e-7)
|
||||
smooth_loss = get_smooth_loss(norm_disp, color)
|
||||
|
||||
loss += self.opt.disparity_smoothness * smooth_loss / (2 ** scale)
|
||||
total_loss += loss
|
||||
losses["loss/{}".format(scale)] = loss
|
||||
|
||||
total_loss /= self.num_scales
|
||||
losses["loss"] = total_loss
|
||||
return losses
|
||||
|
||||
def compute_depth_losses(self, inputs, outputs, losses):
|
||||
"""Compute depth metrics, to allow monitoring during training
|
||||
|
||||
This isn't particularly accurate as it averages over the entire batch,
|
||||
so is only used to give an indication of validation performance
|
||||
"""
|
||||
depth_pred = outputs[("depth", 0, 0)]
|
||||
depth_pred = torch.clamp(F.interpolate(
|
||||
depth_pred, [375, 1242], mode="bilinear", align_corners=False), 1e-3, 80)
|
||||
depth_pred = depth_pred.detach()
|
||||
|
||||
depth_gt = inputs["depth_gt"]
|
||||
mask = depth_gt > 0
|
||||
|
||||
# garg/eigen crop
|
||||
crop_mask = torch.zeros_like(mask)
|
||||
crop_mask[:, :, 153:371, 44:1197] = 1
|
||||
mask = mask * crop_mask
|
||||
|
||||
depth_gt = depth_gt[mask]
|
||||
depth_pred = depth_pred[mask]
|
||||
depth_pred *= torch.median(depth_gt) / torch.median(depth_pred)
|
||||
|
||||
depth_pred = torch.clamp(depth_pred, min=1e-3, max=80)
|
||||
|
||||
depth_errors = compute_depth_errors(depth_gt, depth_pred)
|
||||
|
||||
for i, metric in enumerate(self.depth_metric_names):
|
||||
losses[metric] = np.array(depth_errors[i].cpu())
|
||||
|
||||
def log_time(self, batch_idx, duration, loss):
|
||||
"""Print a logging statement to the terminal
|
||||
"""
|
||||
samples_per_sec = self.opt.batch_size / duration
|
||||
time_sofar = time.time() - self.start_time
|
||||
training_time_left = (
|
||||
self.num_total_steps / self.step - 1.0) * time_sofar if self.step > 0 else 0
|
||||
print_string = "epoch {:>3} | batch {:>6} | examples/s: {:5.1f}" + \
|
||||
" | loss: {:.5f} | time elapsed: {} | time left: {}"
|
||||
print(print_string.format(self.epoch, batch_idx, samples_per_sec, loss,
|
||||
sec_to_hm_str(time_sofar), sec_to_hm_str(training_time_left)))
|
||||
|
||||
def log(self, mode, inputs, outputs, losses):
|
||||
"""Write an event to the tensorboard events file
|
||||
"""
|
||||
writer = self.writers[mode]
|
||||
for l, v in losses.items():
|
||||
writer.add_scalar("{}".format(l), v, self.step)
|
||||
|
||||
for j in range(min(4, self.opt.batch_size)):  # write a maximum of four images
|
||||
for s in self.opt.scales:
|
||||
for frame_id in self.opt.frame_ids:
|
||||
writer.add_image(
|
||||
"color_{}_{}/{}".format(frame_id, s, j),
|
||||
inputs[("color", frame_id, s)][j].data, self.step)
|
||||
if s == 0 and frame_id != 0:
|
||||
writer.add_image(
|
||||
"color_pred_{}_{}/{}".format(frame_id, s, j),
|
||||
outputs[("color", frame_id, s)][j].data, self.step)
|
||||
|
||||
writer.add_image(
|
||||
"disp_{}/{}".format(s, j),
|
||||
normalize_image(outputs[("disp", s)][j]), self.step)
|
||||
|
||||
if self.opt.predictive_mask:
|
||||
for f_idx, frame_id in enumerate(self.opt.frame_ids[1:]):
|
||||
writer.add_image(
|
||||
"predictive_mask_{}_{}/{}".format(frame_id, s, j),
|
||||
outputs["predictive_mask"][("disp", s)][j, f_idx][None, ...],
|
||||
self.step)
|
||||
|
||||
elif not self.opt.disable_automasking:
|
||||
writer.add_image(
|
||||
"automask_{}/{}".format(s, j),
|
||||
outputs["identity_selection/{}".format(s)][j][None, ...], self.step)
|
||||
|
||||
def save_opts(self):
|
||||
"""Save options to disk so we know what we ran this experiment with
|
||||
"""
|
||||
models_dir = os.path.join(self.log_path, "models")
|
||||
if not os.path.exists(models_dir):
|
||||
os.makedirs(models_dir)
|
||||
to_save = self.opt.__dict__.copy()
|
||||
|
||||
with open(os.path.join(models_dir, 'opt.json'), 'w') as f:
|
||||
json.dump(to_save, f, indent=2)
|
||||
|
||||
def save_model(self):
|
||||
"""Save model weights to disk
|
||||
"""
|
||||
save_folder = os.path.join(self.log_path, "models", "weights_{}".format(self.epoch))
|
||||
if not os.path.exists(save_folder):
|
||||
os.makedirs(save_folder)
|
||||
|
||||
for model_name, model in self.models.items():
|
||||
save_path = os.path.join(save_folder, "{}.pth".format(model_name))
|
||||
to_save = model.state_dict()
|
||||
if model_name == 'encoder':
|
||||
# save the sizes - these are needed at prediction time
|
||||
to_save['height'] = self.opt.height
|
||||
to_save['width'] = self.opt.width
|
||||
to_save['use_stereo'] = self.opt.use_stereo
|
||||
torch.save(to_save, save_path)
|
||||
|
||||
save_path = os.path.join(save_folder, "{}.pth".format("adam"))
|
||||
torch.save(self.model_optimizer.state_dict(), save_path)
|
||||
|
||||
def load_model(self):
|
||||
"""Load model(s) from disk
|
||||
"""
|
||||
self.opt.load_weights_folder = os.path.expanduser(self.opt.load_weights_folder)
|
||||
|
||||
assert os.path.isdir(self.opt.load_weights_folder), \
|
||||
"Cannot find folder {}".format(self.opt.load_weights_folder)
|
||||
print("loading model from folder {}".format(self.opt.load_weights_folder))
|
||||
|
||||
for n in self.opt.models_to_load:
|
||||
print("Loading {} weights...".format(n))
|
||||
path = os.path.join(self.opt.load_weights_folder, "{}.pth".format(n))
|
||||
model_dict = self.models[n].state_dict()
|
||||
pretrained_dict = torch.load(path)
|
||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.models[n].load_state_dict(model_dict)
|
||||
|
||||
# loading adam state
|
||||
optimizer_load_path = os.path.join(self.opt.load_weights_folder, "adam.pth")
|
||||
if os.path.isfile(optimizer_load_path):
|
||||
print("Loading Adam weights")
|
||||
optimizer_dict = torch.load(optimizer_load_path)
|
||||
self.model_optimizer.load_state_dict(optimizer_dict)
|
||||
else:
|
||||
print("Cannot find Adam weights so Adam is randomly initialized")
|
@ -1,114 +0,0 @@
|
||||
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
||||
#
|
||||
# This software is licensed under the terms of the Monodepth2 licence
|
||||
# which allows for non-commercial use only, the full terms of which are made
|
||||
# available in the LICENSE file.
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
import os
|
||||
import hashlib
|
||||
import zipfile
|
||||
from six.moves import urllib
|
||||
|
||||
|
||||
def readlines(filename):
|
||||
"""Read all the lines in a text file and return as a list
|
||||
"""
|
||||
with open(filename, 'r') as f:
|
||||
lines = f.read().splitlines()
|
||||
return lines
|
||||
|
||||
|
||||
def normalize_image(x):
|
||||
"""Rescale image pixels to span range [0, 1]
|
||||
"""
|
||||
ma = float(x.max().cpu().data)
|
||||
mi = float(x.min().cpu().data)
|
||||
d = ma - mi if ma != mi else 1e5
|
||||
return (x - mi) / d
|
||||
|
||||
|
||||
def sec_to_hm(t):
|
||||
"""Convert time in seconds to time in hours, minutes and seconds
|
||||
e.g. 10239 -> (2, 50, 39)
|
||||
"""
|
||||
t = int(t)
|
||||
s = t % 60
|
||||
t //= 60
|
||||
m = t % 60
|
||||
t //= 60
|
||||
return t, m, s
|
||||
|
||||
|
||||
def sec_to_hm_str(t):
|
||||
"""Convert time in seconds to a nice string
|
||||
e.g. 10239 -> '02h50m39s'
|
||||
"""
|
||||
h, m, s = sec_to_hm(t)
|
||||
return "{:02d}h{:02d}m{:02d}s".format(h, m, s)
|
||||
|
||||
|
||||
def download_model_if_doesnt_exist(model_name):
|
||||
"""If pretrained kitti model doesn't exist, download and unzip it
|
||||
"""
|
||||
# values are tuples of (<google cloud URL>, <md5 checksum>)
|
||||
download_paths = {
|
||||
"mono_640x192":
|
||||
("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_640x192.zip",
|
||||
"a964b8356e08a02d009609d9e3928f7c"),
|
||||
"stereo_640x192":
|
||||
("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_640x192.zip",
|
||||
"3dfb76bcff0786e4ec07ac00f658dd07"),
|
||||
"mono+stereo_640x192":
|
||||
("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_640x192.zip",
|
||||
"c024d69012485ed05d7eaa9617a96b81"),
|
||||
"mono_no_pt_640x192":
|
||||
("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_no_pt_640x192.zip",
|
||||
"9c2f071e35027c895a4728358ffc913a"),
|
||||
"stereo_no_pt_640x192":
|
||||
("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_no_pt_640x192.zip",
|
||||
"41ec2de112905f85541ac33a854742d1"),
|
||||
"mono+stereo_no_pt_640x192":
|
||||
("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_no_pt_640x192.zip",
|
||||
"46c3b824f541d143a45c37df65fbab0a"),
|
||||
"mono_1024x320":
|
||||
("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_1024x320.zip",
|
||||
"0ab0766efdfeea89a0d9ea8ba90e1e63"),
|
||||
"stereo_1024x320":
|
||||
("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_1024x320.zip",
|
||||
"afc2f2126d70cf3fdf26b550898b501a"),
|
||||
"mono+stereo_1024x320":
|
||||
("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_1024x320.zip",
|
||||
"cdc5fc9b23513c07d5b19235d9ef08f7"),
|
||||
}
|
||||
|
||||
if not os.path.exists("models"):
|
||||
os.makedirs("models")
|
||||
|
||||
model_path = os.path.join("models", model_name)
|
||||
|
||||
def check_file_matches_md5(checksum, fpath):
|
||||
if not os.path.exists(fpath):
|
||||
return False
|
||||
with open(fpath, 'rb') as f:
|
||||
current_md5checksum = hashlib.md5(f.read()).hexdigest()
|
||||
return current_md5checksum == checksum
|
||||
|
||||
# see if we have the model already downloaded...
|
||||
if not os.path.exists(os.path.join(model_path, "encoder.pth")):
|
||||
|
||||
model_url, required_md5checksum = download_paths[model_name]
|
||||
|
||||
if not check_file_matches_md5(required_md5checksum, model_path + ".zip"):
|
||||
print("-> Downloading pretrained model to {}".format(model_path + ".zip"))
|
||||
urllib.request.urlretrieve(model_url, model_path + ".zip")
|
||||
|
||||
if not check_file_matches_md5(required_md5checksum, model_path + ".zip"):
|
||||
print(" Failed to download a file which matches the checksum - quitting")
|
||||
quit()
|
||||
|
||||
print(" Unzipping model...")
|
||||
with zipfile.ZipFile(model_path + ".zip", 'r') as f:
|
||||
f.extractall(model_path)
|
||||
|
||||
print(" Model unzipped to {}".format(model_path))
|
Binary file not shown.
@ -1,14 +1,13 @@
|
||||
unzip weights.zip
|
||||
mkdir -p CelebAMask-HQ/MaskGAN_demo/checkpoints/label2face_512p
|
||||
mkdir -p monodepth2/models/mono+stereo_640x192
|
||||
mkdir -p pytorch-SRResNet/model
|
||||
mkdir deeplabv3
|
||||
|
||||
mv weights/colorize/* neural-colorization/
|
||||
mv weights/colorize/* ideepcolor/models/pytorch/
|
||||
mv weights/deblur/* DeblurGANv2/
|
||||
mv weights/deeplabv3/* deeplabv3
|
||||
mv weights/facegen/* CelebAMask-HQ/MaskGAN_demo/checkpoints/label2face_512p/
|
||||
mv weights/faceparse/* face-parsing.PyTorch/
|
||||
mv weights/monodepth/* monodepth2/models/mono+stereo_640x192/
|
||||
mv weights/MiDaS/* MiDaS/
|
||||
mv weights/super_resolution/* pytorch-SRResNet/model/
|
||||
rm -rf weights/
|
||||
|
@ -1,23 +0,0 @@
|
||||
Copyright (c) 2016, Richard Zhang, Phillip Isola, Alexei A. Efros
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
Binary file not shown.
@ -1,42 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
import argparse
|
||||
image_extensions = {'.jpg', '.jpeg', '.JPG', '.JPEG'}
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Put all places 365 images in single folder.")
|
||||
parser.add_argument("-i",
|
||||
"--input_dir",
|
||||
required=True,
|
||||
type=str,
|
||||
help="input folder: the folder containing unzipped places 365 files")
|
||||
parser.add_argument("-o",
|
||||
"--output_dir",
|
||||
required=True,
|
||||
type=str,
|
||||
help="output folder: the folder to put all images")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def genlist(image_dir):
|
||||
image_list = []
|
||||
for filename in os.listdir(image_dir):
|
||||
path = os.path.join(image_dir,filename)
|
||||
if os.path.isdir(path):
|
||||
image_list = image_list + genlist(path)
|
||||
else:
|
||||
ext = os.path.splitext(filename)[1]
|
||||
if ext in image_extensions:
|
||||
image_list.append(os.path.join(image_dir, filename))
|
||||
return image_list
|
||||
|
||||
|
||||
args = parse_args()
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
flist = genlist(args.input_dir)
|
||||
for i,p in enumerate(flist):
|
||||
if os.path.getsize(p) != 0:
|
||||
os.rename(p,os.path.join(args.output_dir,str(i)+'.jpg'))
|
||||
shutil.rmtree(args.input_dir)
|
||||
print('done')
|
@ -1,73 +0,0 @@
|
||||
import torch
|
||||
from model import generator
|
||||
from torch.autograd import Variable
|
||||
from scipy.ndimage import zoom
|
||||
import cv2
|
||||
import os
|
||||
from PIL import Image
|
||||
import argparse
|
||||
import numpy as np
|
||||
from skimage.color import rgb2yuv,yuv2rgb
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Colorize images")
|
||||
parser.add_argument("-i",
|
||||
"--input",
|
||||
type=str,
|
||||
required=True,
|
||||
help="input image/input dir")
|
||||
parser.add_argument("-o",
|
||||
"--output",
|
||||
type=str,
|
||||
required=True,
|
||||
help="output image/output dir")
|
||||
parser.add_argument("-m",
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="location for model (Generator)")
|
||||
parser.add_argument("--gpu",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="which GPU to use? [-1 for cpu]")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
args = parse_args()
|
||||
|
||||
G = generator()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# args.gpu>=0:
|
||||
G=G.cuda(args.gpu)
|
||||
G.load_state_dict(torch.load(args.model))
|
||||
else:
|
||||
G.load_state_dict(torch.load(args.model,map_location=torch.device('cpu')))
|
||||
|
||||
def inference(G,in_path,out_path):
|
||||
p=Image.open(in_path).convert('RGB')
|
||||
img_yuv = rgb2yuv(p)
|
||||
H,W,_ = img_yuv.shape
|
||||
infimg = np.expand_dims(np.expand_dims(img_yuv[...,0], axis=0), axis=0)
|
||||
img_variable = Variable(torch.Tensor(infimg-0.5))
|
||||
if args.gpu>=0:
|
||||
img_variable=img_variable.cuda(args.gpu)
|
||||
res = G(img_variable)
|
||||
uv=res.cpu().detach().numpy()
|
||||
uv[:,0,:,:] *= 0.436
|
||||
uv[:,1,:,:] *= 0.615
|
||||
(_,_,H1,W1) = uv.shape
|
||||
uv = zoom(uv,(1,1,H/H1,W/W1))
|
||||
yuv = np.concatenate([infimg,uv],axis=1)[0]
|
||||
rgb=yuv2rgb(yuv.transpose(1,2,0))
|
||||
cv2.imwrite(out_path,(rgb.clip(min=0,max=1)*256)[:,:,[2,1,0]])
|
||||
|
||||
|
||||
if not os.path.isdir(args.input):
|
||||
inference(G,args.input,args.output)
|
||||
else:
|
||||
if not os.path.exists(args.output):
|
||||
os.makedirs(args.output)
|
||||
for f in os.listdir(args.input):
|
||||
inference(G,os.path.join(args.input,f),os.path.join(args.output,f))
|
||||
|
@ -1,123 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import reduce
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
class shave_block(nn.Module):
|
||||
def __init__(self, s):
|
||||
super(shave_block, self).__init__()
|
||||
self.s=s
|
||||
def forward(self,x):
|
||||
return x[:,:,self.s:-self.s,self.s:-self.s]
|
||||
|
||||
class LambdaBase(nn.Sequential):
|
||||
def __init__(self, fn, *args):
|
||||
super(LambdaBase, self).__init__(*args)
|
||||
self.lambda_func = fn
|
||||
|
||||
def forward_prepare(self, input):
|
||||
output = []
|
||||
for module in self._modules.values():
|
||||
output.append(module(input))
|
||||
return output if output else input
|
||||
|
||||
class Lambda(LambdaBase):
|
||||
def forward(self, input):
|
||||
return self.lambda_func(self.forward_prepare(input))
|
||||
|
||||
class LambdaMap(LambdaBase):
|
||||
def forward(self, input):
|
||||
return list(map(self.lambda_func,self.forward_prepare(input)))
|
||||
|
||||
class LambdaReduce(LambdaBase):
|
||||
def forward(self, input):
|
||||
return reduce(self.lambda_func,self.forward_prepare(input))
|
||||
|
||||
def generator():
|
||||
G = nn.Sequential( # Sequential,
|
||||
nn.ReflectionPad2d((40, 40, 40, 40)),
|
||||
nn.Conv2d(1,32,(9, 9),(1, 1),(4, 4)),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32,64,(3, 3),(2, 2),(1, 1)),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64,128,(3, 3),(2, 2),(1, 1)),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(128,128,(3, 3)),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(128,128,(3, 3)),
|
||||
nn.BatchNorm2d(128),
|
||||
),
|
||||
shave_block(2),
|
||||
),
|
||||
LambdaReduce(lambda x,y: x+y), # CAddTable,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(128,128,(3, 3)),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(128,128,(3, 3)),
|
||||
nn.BatchNorm2d(128),
|
||||
),
|
||||
shave_block(2),
|
||||
),
|
||||
LambdaReduce(lambda x,y: x+y), # CAddTable,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(128,128,(3, 3)),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(128,128,(3, 3)),
|
||||
nn.BatchNorm2d(128),
|
||||
),
|
||||
shave_block(2),
|
||||
),
|
||||
LambdaReduce(lambda x,y: x+y), # CAddTable,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(128,128,(3, 3)),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(128,128,(3, 3)),
|
||||
nn.BatchNorm2d(128),
|
||||
),
|
||||
shave_block(2),
|
||||
),
|
||||
LambdaReduce(lambda x,y: x+y), # CAddTable,
|
||||
),
|
||||
nn.Sequential( # Sequential,
|
||||
LambdaMap(lambda x: x, # ConcatTable,
|
||||
nn.Sequential( # Sequential,
|
||||
nn.Conv2d(128,128,(3, 3)),
|
||||
nn.BatchNorm2d(128),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(128,128,(3, 3)),
|
||||
nn.BatchNorm2d(128),
|
||||
),
|
||||
shave_block(2),
|
||||
),
|
||||
LambdaReduce(lambda x,y: x+y), # CAddTable,
|
||||
),
|
||||
nn.ConvTranspose2d(128,64,(3, 3),(2, 2),(1, 1),(1, 1)),
|
||||
nn.BatchNorm2d(64),
|
||||
nn.ReLU(),
|
||||
nn.ConvTranspose2d(64,32,(3, 3),(2, 2),(1, 1),(1, 1)),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32,2,(9, 9),(1, 1),(4, 4)),
|
||||
nn.Tanh(),
|
||||
)
|
||||
return G
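A hedged usage sketch (sizes assumed): the generator maps one luma (Y) channel to two chroma channels in [-1, 1], and the initial 40-pixel reflection padding exactly compensates for the five unpadded residual blocks, so the output keeps the input's spatial size.

import torch
G = generator()
y = torch.randn(1, 1, 256, 256)      # one 256x256 grayscale (Y) image
with torch.no_grad():
    uv = G(y)
print(uv.shape)                      # torch.Size([1, 2, 256, 256])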
|
Binary file not shown.
@ -1,35 +0,0 @@
|
||||
from multiprocessing import Pool
|
||||
from PIL import Image
|
||||
import os
|
||||
import argparse
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Resize all colorful imgs to 256*256 for training")
|
||||
parser.add_argument("-d",
|
||||
"--dir",
|
||||
required=True,
|
||||
type=str,
|
||||
help="The directory includes all jpg images")
|
||||
parser.add_argument("-n",
|
||||
"--nprocesses",
|
||||
default=10,
|
||||
type=int,
|
||||
help="Using how many processes")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def doit(x):
|
||||
a=Image.open(x)
|
||||
if a.getbands()!=('R','G','B'):
|
||||
os.remove(x)
|
||||
return
|
||||
a.resize((256,256),Image.BICUBIC).save(x)
|
||||
return
|
||||
|
||||
args=parse_args()
|
||||
pool = Pool(processes=args.nprocesses)
|
||||
jpgs = []
|
||||
flist = os.listdir(args.dir)
|
||||
full_flist = [os.path.join(args.dir,x) for x in flist]
|
||||
pool.map(doit, full_flist)
|
||||
print('done')
|
@ -1,186 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import argparse
|
||||
from torch.autograd import Variable
|
||||
import torchvision.models as models
|
||||
import os
|
||||
from torch.utils import data
|
||||
from model import generator
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from skimage.color import rgb2yuv,yuv2rgb
|
||||
import cv2
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Train a GAN based model")
|
||||
parser.add_argument("-d",
|
||||
"--training_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Training directory (folder contains all 256*256 images)")
|
||||
parser.add_argument("-t",
|
||||
"--test_image",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Test image location")
|
||||
parser.add_argument("-c",
|
||||
"--checkpoint_location",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Place to save checkpoints")
|
||||
parser.add_argument("-e",
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=120,
|
||||
help="Epoches to run training")
|
||||
parser.add_argument("--gpu",
|
||||
type=int,
|
||||
default=0,
|
||||
help="which GPU to use?")
|
||||
parser.add_argument("-b",
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=20,
|
||||
help="batch size")
|
||||
parser.add_argument("-w",
|
||||
"--num_workers",
|
||||
type=int,
|
||||
default=6,
|
||||
help="Number of workers to fetch data")
|
||||
parser.add_argument("-p",
|
||||
"--pixel_loss_weights",
|
||||
type=float,
|
||||
default=1000.0,
|
||||
help="Pixel-wise loss weights")
|
||||
parser.add_argument("--g_every",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Training generator every k iteration")
|
||||
parser.add_argument("--g_lr",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="learning rate for generator")
|
||||
parser.add_argument("--d_lr",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="learning rate for discriminator")
|
||||
parser.add_argument("-i",
|
||||
"--checkpoint_every",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Save checkpoint every k iteration (checkpoints for same epoch will overwrite)")
|
||||
parser.add_argument("--d_init",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Init weights for discriminator")
|
||||
parser.add_argument("--g_init",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Init weights for generator")
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
# define data generator
|
||||
class img_data(data.Dataset):
|
||||
def __init__(self, path):
|
||||
files = os.listdir(path)
|
||||
self.files = [os.path.join(path,x) for x in files]
|
||||
def __len__(self):
|
||||
return len(self.files)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img = Image.open(self.files[index])
|
||||
yuv = rgb2yuv(img)
|
||||
y = yuv[...,0]-0.5
|
||||
u_t = yuv[...,1] / 0.43601035
|
||||
v_t = yuv[...,2] / 0.61497538
|
||||
return torch.Tensor(np.expand_dims(y,axis=0)),torch.Tensor(np.stack([u_t,v_t],axis=0))
|
||||
|
||||
|
||||
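A minimal check of the img_data wrapper (illustrative only; the folder path below is a placeholder) shows the tensor layout handed to the models: one 1-channel Y tensor and one 2-channel U/V tensor per sample.

ds = img_data('/path/to/256x256/jpgs')   # placeholder folder of resized RGB jpgs
y, uv = ds[0]
print(y.shape, uv.shape)                 # torch.Size([1, 256, 256]) torch.Size([2, 256, 256])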
args = parse_args()
if not os.path.exists(os.path.join(args.checkpoint_location, 'weights')):
    os.makedirs(os.path.join(args.checkpoint_location, 'weights'))

# Define G, same as torch version
G = generator().cuda(args.gpu)

# Define D: a ResNet-18 whose head is replaced by a single sigmoid "real" score
D = models.resnet18(pretrained=False, num_classes=2)
D.fc = nn.Sequential(nn.Linear(512, 1), nn.Sigmoid())
D = D.cuda(args.gpu)

trainset = img_data(args.training_dir)
params = {'batch_size': args.batch_size,
          'shuffle': True,
          'num_workers': args.num_workers}
training_generator = data.DataLoader(trainset, **params)
if args.test_image is not None:
    test_img = Image.open(args.test_image).convert('RGB').resize((256, 256))
    test_yuv = rgb2yuv(test_img)
    test_inf = test_yuv[..., 0].reshape(1, 1, 256, 256)
    test_var = Variable(torch.Tensor(test_inf - 0.5)).cuda(args.gpu)
if args.d_init is not None:
    D.load_state_dict(torch.load(args.d_init))
if args.g_init is not None:
    G.load_state_dict(torch.load(args.g_init))

# save test image for beginning
if args.test_image is not None:
    test_res = G(test_var)
    uv = test_res.cpu().detach().numpy()
    # Undo the U/V scaling applied in img_data before converting back to RGB.
    uv[:, 0, :, :] *= 0.436
    uv[:, 1, :, :] *= 0.615
    test_yuv = np.concatenate([test_inf, uv], axis=1).reshape(3, 256, 256)
    test_rgb = yuv2rgb(test_yuv.transpose(1, 2, 0))
    cv2.imwrite(os.path.join(args.checkpoint_location, 'test_init.jpg'), (test_rgb.clip(min=0, max=1) * 256)[:, :, [2, 1, 0]])
i = 0
adversarial_loss = torch.nn.BCELoss()
optimizer_G = torch.optim.Adam(G.parameters(), lr=args.g_lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(D.parameters(), lr=args.d_lr, betas=(0.5, 0.999))
for epoch in range(args.epoch):
    for y, uv in training_generator:
        # Adversarial ground truths
        valid = Variable(torch.Tensor(y.size(0), 1).fill_(1.0), requires_grad=False).cuda(args.gpu)
        fake = Variable(torch.Tensor(y.size(0), 1).fill_(0.0), requires_grad=False).cuda(args.gpu)

        yvar = Variable(y).cuda(args.gpu)
        uvvar = Variable(uv).cuda(args.gpu)
        real_imgs = torch.cat([yvar, uvvar], dim=1)

        optimizer_G.zero_grad()
        uvgen = G(yvar)
        # Generate a batch of images
        gen_imgs = torch.cat([yvar.detach(), uvgen], dim=1)

        # Loss measures generator's ability to fool the discriminator,
        # plus a pixel-wise MSE term on the predicted chroma channels
        g_loss_gan = adversarial_loss(D(gen_imgs), valid)
        g_loss = g_loss_gan + args.pixel_loss_weights * torch.mean((uvvar - uvgen) ** 2)
        if i % args.g_every == 0:
            g_loss.backward()
            optimizer_G.step()

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(D(real_imgs), valid)
        fake_loss = adversarial_loss(D(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        i += 1
        if i % args.checkpoint_every == 0:
            print("Epoch: %d: [D loss: %f] [G total loss: %f] [G GAN Loss: %f]" % (epoch, d_loss.item(), g_loss.item(), g_loss_gan.item()))

            torch.save(D.state_dict(), os.path.join(args.checkpoint_location, 'weights', 'D' + str(epoch) + '.pth'))
            torch.save(G.state_dict(), os.path.join(args.checkpoint_location, 'weights', 'G' + str(epoch) + '.pth'))
            if args.test_image is not None:
                test_res = G(test_var)
                uv = test_res.cpu().detach().numpy()
                uv[:, 0, :, :] *= 0.436
                uv[:, 1, :, :] *= 0.615
                test_yuv = np.concatenate([test_inf, uv], axis=1).reshape(3, 256, 256)
                test_rgb = yuv2rgb(test_yuv.transpose(1, 2, 0))
                cv2.imwrite(os.path.join(args.checkpoint_location, 'test_epoch_' + str(epoch) + '.jpg'), (test_rgb.clip(min=0, max=1) * 256)[:, :, [2, 1, 0]])
torch.save(D.state_dict(), os.path.join(args.checkpoint_location, 'D_final.pth'))
torch.save(G.state_dict(), os.path.join(args.checkpoint_location, 'G_final.pth'))
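Once training finishes, the saved generator can be reused for stand-alone colorization. The sketch below is illustrative only; it mirrors the test-image branch of the training script, and G_final.pth and input.jpg are assumed paths.

import numpy as np
import torch
import cv2
from PIL import Image
from skimage.color import rgb2yuv, yuv2rgb
from model import generator

G = generator()
G.load_state_dict(torch.load('G_final.pth', map_location='cpu'))
G.eval()

img = Image.open('input.jpg').convert('RGB').resize((256, 256))
y = rgb2yuv(np.asarray(img))[..., 0].reshape(1, 1, 256, 256)
with torch.no_grad():
    uv = G(torch.Tensor(y - 0.5)).numpy()
uv[:, 0, :, :] *= 0.436   # undo the U scaling used during training
uv[:, 1, :, :] *= 0.615   # undo the V scaling used during training
rgb = yuv2rgb(np.concatenate([y, uv], axis=1).reshape(3, 256, 256).transpose(1, 2, 0))
cv2.imwrite('colorized.jpg', (rgb.clip(0, 1) * 255)[:, :, [2, 1, 0]])   # RGB -> BGR for OpenCV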
Binary file not shown.