mirror of https://github.com/kritiksoman/GIMP-ML
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
127 lines
5.2 KiB
Python
127 lines
5.2 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import math
|
|
import cv2
|
|
import torch.nn.functional as F
|
|
|
|
class VGG16(nn.Module):
|
|
def __init__(self, args):
|
|
super(VGG16, self).__init__()
|
|
self.stage = args.stage
|
|
|
|
self.conv1_1 = nn.Conv2d(4, 64, kernel_size=3,stride = 1, padding=1,bias=True)
|
|
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3,stride = 1, padding=1,bias=True)
|
|
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1,bias=True)
|
|
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1,bias=True)
|
|
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1,bias=True)
|
|
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1,bias=True)
|
|
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1,bias=True)
|
|
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1,bias=True)
|
|
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
|
|
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
|
|
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
|
|
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
|
|
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
|
|
|
|
# model released before 2019.09.09 should use kernel_size=1 & padding=0
|
|
#self.conv6_1 = nn.Conv2d(512, 512, kernel_size=1, padding=0,bias=True)
|
|
self.conv6_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1,bias=True)
|
|
|
|
self.deconv6_1 = nn.Conv2d(512, 512, kernel_size=1,bias=True)
|
|
self.deconv5_1 = nn.Conv2d(512, 512, kernel_size=5, padding=2,bias=True)
|
|
self.deconv4_1 = nn.Conv2d(512, 256, kernel_size=5, padding=2,bias=True)
|
|
self.deconv3_1 = nn.Conv2d(256, 128, kernel_size=5, padding=2,bias=True)
|
|
self.deconv2_1 = nn.Conv2d(128, 64, kernel_size=5, padding=2,bias=True)
|
|
self.deconv1_1 = nn.Conv2d(64, 64, kernel_size=5, padding=2,bias=True)
|
|
|
|
self.deconv1 = nn.Conv2d(64, 1, kernel_size=5, padding=2,bias=True)
|
|
|
|
if args.stage == 2:
|
|
# for stage2 training
|
|
for p in self.parameters():
|
|
p.requires_grad=False
|
|
|
|
if self.stage == 2 or self.stage == 3:
|
|
self.refine_conv1 = nn.Conv2d(4, 64, kernel_size=3, padding=1, bias=True)
|
|
self.refine_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
|
|
self.refine_conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
|
|
self.refine_pred = nn.Conv2d(64, 1, kernel_size=3, padding=1, bias=True)
|
|
|
|
def forward(self, x):
|
|
# Stage 1
|
|
x11 = F.relu(self.conv1_1(x))
|
|
x12 = F.relu(self.conv1_2(x11))
|
|
x1p, id1 = F.max_pool2d(x12,kernel_size=(2,2), stride=(2,2),return_indices=True)
|
|
|
|
# Stage 2
|
|
x21 = F.relu(self.conv2_1(x1p))
|
|
x22 = F.relu(self.conv2_2(x21))
|
|
x2p, id2 = F.max_pool2d(x22,kernel_size=(2,2), stride=(2,2),return_indices=True)
|
|
|
|
# Stage 3
|
|
x31 = F.relu(self.conv3_1(x2p))
|
|
x32 = F.relu(self.conv3_2(x31))
|
|
x33 = F.relu(self.conv3_3(x32))
|
|
x3p, id3 = F.max_pool2d(x33,kernel_size=(2,2), stride=(2,2),return_indices=True)
|
|
|
|
# Stage 4
|
|
x41 = F.relu(self.conv4_1(x3p))
|
|
x42 = F.relu(self.conv4_2(x41))
|
|
x43 = F.relu(self.conv4_3(x42))
|
|
x4p, id4 = F.max_pool2d(x43,kernel_size=(2,2), stride=(2,2),return_indices=True)
|
|
|
|
# Stage 5
|
|
x51 = F.relu(self.conv5_1(x4p))
|
|
x52 = F.relu(self.conv5_2(x51))
|
|
x53 = F.relu(self.conv5_3(x52))
|
|
x5p, id5 = F.max_pool2d(x53,kernel_size=(2,2), stride=(2,2),return_indices=True)
|
|
|
|
# Stage 6
|
|
x61 = F.relu(self.conv6_1(x5p))
|
|
|
|
# Stage 6d
|
|
x61d = F.relu(self.deconv6_1(x61))
|
|
|
|
# Stage 5d
|
|
x5d = F.max_unpool2d(x61d,id5, kernel_size=2, stride=2)
|
|
x51d = F.relu(self.deconv5_1(x5d))
|
|
|
|
# Stage 4d
|
|
x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
|
|
x41d = F.relu(self.deconv4_1(x4d))
|
|
|
|
# Stage 3d
|
|
x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
|
|
x31d = F.relu(self.deconv3_1(x3d))
|
|
|
|
# Stage 2d
|
|
x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
|
|
x21d = F.relu(self.deconv2_1(x2d))
|
|
|
|
# Stage 1d
|
|
x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
|
|
x12d = F.relu(self.deconv1_1(x1d))
|
|
|
|
# Should add sigmoid? github repo add so.
|
|
raw_alpha = self.deconv1(x12d)
|
|
pred_mattes = F.sigmoid(raw_alpha)
|
|
|
|
if self.stage <= 1:
|
|
return pred_mattes, 0
|
|
|
|
# Stage2 refine conv1
|
|
refine0 = torch.cat((x[:, :3, :, :], pred_mattes), 1)
|
|
refine1 = F.relu(self.refine_conv1(refine0))
|
|
refine2 = F.relu(self.refine_conv2(refine1))
|
|
refine3 = F.relu(self.refine_conv3(refine2))
|
|
# Should add sigmoid?
|
|
# sigmoid lead to refine result all converge to 0...
|
|
#pred_refine = F.sigmoid(self.refine_pred(refine3))
|
|
pred_refine = self.refine_pred(refine3)
|
|
|
|
pred_alpha = F.sigmoid(raw_alpha + pred_refine)
|
|
|
|
#print(pred_mattes.mean(), pred_alpha.mean(), pred_refine.sum())
|
|
|
|
return pred_mattes, pred_alpha
|