"""Context/fusion networks and a training wrapper built on top of IFNet.

The module takes two images and a flow field, warps both images with the
flow, and fuses them (blend mask + residual) into a single predicted image.
``Model`` bundles the flow network, the two networks defined here, the
optimizer and the losses.
"""
# NOTE: the import order below is deliberately kept as in the original file.
# The star imports may export names that collide with the explicit ones, so
# reordering could change which definition wins.
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet import *
import torch.nn.functional as F
from model.loss import *


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a ``Conv2d`` (with bias) followed by a per-channel ``PReLU``."""
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """Return a ``ConvTranspose2d`` (with bias) followed by a per-channel ``PReLU``.

    With the defaults (kernel 4, stride 2, padding 1) the layer doubles the
    spatial resolution of its input.
    """
    return nn.Sequential(
        # The original hard-coded 4/2/1 here and silently ignored the
        # arguments; forwarding them keeps the default behaviour identical.
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
                                 kernel_size=kernel_size, stride=stride,
                                 padding=padding, bias=True),
        nn.PReLU(out_planes)
    )


def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a ``Conv2d`` (with bias) wrapped in ``Sequential``, without activation."""
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
    )


class ResBlock(nn.Module):
    """Residual block with a channel-wise gate on the main branch.

    The shortcut is an identity when shape is preserved, otherwise a 3x3
    convolution with the block's stride. The main branch is two 3x3
    convolutions; its output is rescaled per channel by a sigmoid gate computed
    from its own spatial mean, added to the shortcut and passed through PReLU.
    """

    def __init__(self, in_planes, out_planes, stride=2):
        super(ResBlock, self).__init__()
        if in_planes == out_planes and stride == 1:
            self.conv0 = nn.Identity()
        else:
            # Shortcut must match the main branch in channels and resolution.
            self.conv0 = nn.Conv2d(in_planes, out_planes,
                                   3, stride, 1, bias=False)
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv_woact(out_planes, out_planes, 3, 1, 1)
        self.relu1 = nn.PReLU(1)
        self.relu2 = nn.PReLU(out_planes)
        # 1x1 convolutions forming the gate: out_planes -> 16 -> out_planes.
        self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply the block to a 4-D tensor ``x`` (dims 2 and 3 are averaged for the gate)."""
        y = self.conv0(x)
        x = self.conv1(x)
        x = self.conv2(x)
        # Mean over the last two dimensions, keeping them as size 1.
        w = x.mean(3, True).mean(2, True)
        w = self.relu1(self.fc1(w))
        w = torch.sigmoid(self.fc2(w))
        x = self.relu2(x * w + y)
        return x


# Base channel width shared by ContextNet and FusionNet.
c = 16


class ContextNet(nn.Module):
    """Four stride-2 ``ResBlock`` stages whose outputs are each warped by the flow.

    ``cFlag`` is stored and forwarded unchanged to every ``warp`` call.
    """

    def __init__(self, cFlag):
        super(ContextNet, self).__init__()
        self.conv1 = ResBlock(3, c)
        self.conv2 = ResBlock(c, 2 * c)
        self.conv3 = ResBlock(2 * c, 4 * c)
        self.conv4 = ResBlock(4 * c, 8 * c)
        self.cFlag = cFlag

    def forward(self, x, flow):
        """Return ``[f1, f2, f3, f4]``: warped features of the four stages.

        ``x`` has 3 channels. ``flow`` is used as-is on the first stage's
        output, so it presumably has to be at that (halved) resolution --
        TODO confirm against ``warp``. For each later stage the flow is
        bilinearly downsampled by 2 and its values are halved.
        """
        x = self.conv1(x)
        f1 = warp(x, flow, self.cFlag)
        x = self.conv2(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f2 = warp(x, flow, self.cFlag)
        x = self.conv3(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f3 = warp(x, flow, self.cFlag)
        x = self.conv4(x)
        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
                             align_corners=False) * 0.5
        f4 = warp(x, flow, self.cFlag)
        return [f1, f2, f3, f4]


class FusionNet(nn.Module):
    """Encoder/decoder with skip connections producing a 4-channel refinement map.

    The encoder input is the two warped images concatenated with the flow
    (8 channels in total); each later encoder stage also receives the matching
    context features of both images.
    """

    def __init__(self, cFlag):
        super(FusionNet, self).__init__()
        self.down0 = ResBlock(8, 2 * c)
        self.down1 = ResBlock(4 * c, 4 * c)
        self.down2 = ResBlock(8 * c, 8 * c)
        self.down3 = ResBlock(16 * c, 16 * c)
        self.up0 = deconv(32 * c, 8 * c)
        self.up1 = deconv(16 * c, 4 * c)
        self.up2 = deconv(8 * c, 2 * c)
        self.up3 = deconv(4 * c, c)
        self.conv = nn.Conv2d(c, 4, 3, 1, 1)
        self.cFlag = cFlag

    def forward(self, img0, img1, flow, c0, c1, flow_gt):
        """Warp both images and compute the refinement map.

        Args:
            img0, img1: the two input images (3 channels each, given the
                8-channel input of ``down0`` together with ``flow``).
            flow: flow field; ``img0`` is warped with ``flow`` and ``img1``
                with ``-flow``.
            c0, c1: four-element feature lists for ``img0`` / ``img1``
                (indexed 0..3), as returned by ``ContextNet``.
            flow_gt: optional reference flow; channels ``[:2]`` warp ``img0``
                and channels ``[2:4]`` warp ``img1``. ``None`` disables this.

        Returns:
            ``(x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt)``
            where ``x`` has 4 channels and the ``*_gt`` entries are ``None``
            when ``flow_gt`` is ``None``.
        """
        warped_img0 = warp(img0, flow, self.cFlag)
        warped_img1 = warp(img1, -flow, self.cFlag)
        # Identity test: ``== None`` on a tensor relies on tensor comparison
        # semantics instead of checking for the missing argument.
        if flow_gt is None:
            warped_img0_gt, warped_img1_gt = None, None
        else:
            # Pass the same flag as every other warp call in this file; the
            # original omitted it here only.
            warped_img0_gt = warp(img0, flow_gt[:, :2], self.cFlag)
            warped_img1_gt = warp(img1, flow_gt[:, 2:4], self.cFlag)
        s0 = self.down0(torch.cat((warped_img0, warped_img1, flow), 1))
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.conv(x)
        return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt


class Model:
    """Bundle of flow, context and fusion networks with optimizer and losses.

    Args:
        c_flag: when truthy the networks are kept on the CPU even if CUDA is
            available (see ``device``); it is also forwarded to ``IFNet``,
            ``ContextNet``, ``FusionNet``, ``Ternary`` and ``SOBEL``.
        local_rank: when not ``-1`` the three networks are wrapped in
            ``DistributedDataParallel`` on that device.
    """

    def __init__(self, c_flag, local_rank=-1):
        self.flownet = IFNet(c_flag)
        self.contextnet = ContextNet(c_flag)
        self.fusionnet = FusionNet(c_flag)
        self.device(c_flag)
        self.optimG = AdamW(itertools.chain(
            self.flownet.parameters(),
            self.contextnet.parameters(),
            self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
        self.schedulerG = optim.lr_scheduler.CyclicLR(
            self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
        self.epe = EPE()
        self.ter = Ternary(c_flag)
        self.sobel = SOBEL(c_flag)
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[
                               local_rank], output_device=local_rank)
            self.contextnet = DDP(self.contextnet, device_ids=[
                                  local_rank], output_device=local_rank)
            self.fusionnet = DDP(self.fusionnet, device_ids=[
                                 local_rank], output_device=local_rank)

    def train(self):
        """Put all three networks into training mode."""
        self.flownet.train()
        self.contextnet.train()
        self.fusionnet.train()

    def eval(self):
        """Put all three networks into evaluation mode."""
        self.flownet.eval()
        self.contextnet.eval()
        self.fusionnet.eval()

    def device(self, c_flag):
        """Move the networks to CUDA if available and ``c_flag`` is falsy, else to CPU."""
        if torch.cuda.is_available() and not c_flag:
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        self.flownet.to(device)
        self.contextnet.to(device)
        self.fusionnet.to(device)

    def load_model(self, path, rank=0):
        """Load ``flownet.pkl``, ``contextnet.pkl`` and ``unet.pkl`` from ``path``.

        Nothing is loaded unless ``rank == 0``. The ``"module."`` prefix is
        removed from parameter names so that checkpoints saved from
        DDP-wrapped networks load into plain ones.

        NOTE(security): ``torch.load`` unpickles the file; only load
        checkpoints from a trusted source.
        """
        def convert(param):
            # Keep every entry. The original kept only keys containing
            # "module.", which turned a checkpoint saved from non-DDP
            # networks (as ``save_model`` writes when local_rank == -1) into
            # an empty dict.
            return {
                k.replace("module.", ""): v
                for k, v in param.items()
            }
        if rank == 0:
            self.flownet.load_state_dict(
                convert(torch.load('{}/flownet.pkl'.format(path), map_location=torch.device("cpu"))))
            self.contextnet.load_state_dict(
                convert(torch.load('{}/contextnet.pkl'.format(path), map_location=torch.device("cpu"))))
            self.fusionnet.load_state_dict(
                convert(torch.load('{}/unet.pkl'.format(path), map_location=torch.device("cpu"))))

    def save_model(self, path, rank=0):
        """Save the three state dicts under ``path`` (only when ``rank == 0``)."""
        if rank == 0:
            torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
            torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
            torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))

    def predict(self, imgs, flow, training=True, flow_gt=None):
        """Fuse the two images stacked in ``imgs`` using ``flow``.

        Args:
            imgs: tensor whose channels ``[:3]`` are the first image and
                ``[3:]`` the second.
            flow: flow passed to the context network as-is, then upsampled by
                2 (values doubled) before the fusion network.
            training: selects the return shape, see below.
            flow_gt: optional reference flow forwarded to the fusion network.

        Returns:
            If ``training``: ``(pred, mask, merged_img, warped_img0,
            warped_img1, warped_img0_gt, warped_img1_gt)``; otherwise only
            ``pred``. ``pred`` is clamped to ``[0, 1]``.
        """
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        c0 = self.contextnet(img0, flow)
        c1 = self.contextnet(img1, -flow)
        flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
                             align_corners=False) * 2.0
        refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
            img0, img1, flow, c0, c1, flow_gt)
        # First three channels: residual in (-1, 1); fourth: blend mask in (0, 1).
        res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
        mask = torch.sigmoid(refine_output[:, 3:4])
        merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
        pred = merged_img + res
        pred = torch.clamp(pred, 0, 1)
        if training:
            return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
        else:
            return pred

    def inference(self, img0, img1):
        """Return the detached prediction for the image pair ``img0``, ``img1``."""
        imgs = torch.cat((img0, img1), 1)
        flow, _ = self.flownet(imgs)
        return self.predict(imgs, flow, training=False).detach()

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one forward pass and, when ``training``, one optimizer step.

        Args:
            imgs: the two input images stacked along the channel dimension.
            gt: target image the prediction is compared with.
            learning_rate: written into every optimizer parameter group before
                the step (this overrides whatever the scheduler set).
            mul: unused; kept for interface compatibility.
            training: ``True`` trains (requires ``flow_gt``), ``False`` only
                evaluates.
            flow_gt: reference flow with at least 4 channels. Must not be
                ``None`` when ``training`` is ``True``.

        Returns:
            ``(pred, merged_img, flow, loss_l1, loss_flow, loss_cons,
            loss_ter, loss_mask)``. In evaluation ``loss_cons`` is a zero
            tensor and ``loss_mask`` is the integer ``1``.
        """
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        flow, flow_list = self.flownet(imgs)
        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
            imgs, flow, flow_gt=flow_gt)
        loss_ter = self.ter(pred, gt).mean()
        if training:
            # Monitoring values and the rescaled reference flow need no gradient.
            with torch.no_grad():
                loss_flow = torch.abs(warped_img0_gt - gt).mean()
                loss_mask = torch.abs(
                    merged_img - gt).sum(1, True).float().detach()
                loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
                                          align_corners=False).detach()
                # Half resolution, half magnitude, to compare with flow_list.
                flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
                                         align_corners=False) * 0.5).detach()
            # Flow consistency: the first three flow estimates against the
            # reference, in both directions.
            loss_cons = 0
            for i in range(3):
                loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1)
                loss_cons += self.epe(-flow_list[i], flow_gt[:, 2:4], 1)
            loss_cons = loss_cons.mean() * 0.01
        else:
            loss_cons = torch.tensor([0])
            loss_flow = torch.abs(warped_img0 - gt).mean()
            loss_mask = 1
        # Smooth L1-like reconstruction loss: sqrt(diff^2 + eps).
        loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_ter
            loss_G.backward()
            self.optimG.step()
        return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask


if __name__ == '__main__':
    # Smoke test: one inference pass on dummy 256x256 image pairs.
    # The original referenced an undefined ``device``, built ``Model()``
    # without its required flag and called ``inference`` with one argument.
    c_flag = not torch.cuda.is_available()
    device = torch.device("cpu" if c_flag else "cuda")
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    model = Model(c_flag)
    model.eval()
    print(model.inference(img0, img1).shape)