You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

123 lines
4.0 KiB
Python

import torch
import torch.nn as nn
from functools import reduce
from torch.autograd import Variable
class shave_block(nn.Module):
    """Crop ``s`` elements from both ends of dimensions 2 and 3 of the input.

    Args:
        s: number of elements to remove from each border of the last two
            indexed dimensions. ``s == 0`` leaves the input uncropped.
    """

    def __init__(self, s):
        super(shave_block, self).__init__()
        self.s = s

    def forward(self, x):
        """Return ``x`` with ``self.s`` elements trimmed from each side of dims 2 and 3."""
        # ``-0`` is just ``0``, so a plain ``s:-s`` slice would be ``0:0`` (empty)
        # when s == 0. An open-ended stop keeps the whole axis in that case.
        stop = -self.s if self.s else None
        return x[:, :, self.s:stop, self.s:stop]
class LambdaBase(nn.Sequential):
    """Container that stores a callable alongside its child modules.

    Args:
        fn: callable kept as ``lambda_func`` for subclasses to apply.
        *args: child modules, registered through ``nn.Sequential``.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Run every child module on ``input`` and collect the results.

        Returns:
            A list with one result per child module, or ``input`` itself
            when there are no child modules.
        """
        results = [branch(input) for branch in self._modules.values()]
        if not results:
            # No children registered: pass the input through untouched.
            return input
        return results
class Lambda(LambdaBase):
    """Apply ``lambda_func`` once to the result of ``forward_prepare``."""

    def forward(self, input):
        """Return ``lambda_func`` applied to the prepared input."""
        prepared = self.forward_prepare(input)
        return self.lambda_func(prepared)
class LambdaMap(LambdaBase):
    """Apply ``lambda_func`` to each item produced by ``forward_prepare``."""

    def forward(self, input):
        """Return a list holding ``lambda_func(item)`` for every prepared item."""
        fn = self.lambda_func
        return [fn(item) for item in self.forward_prepare(input)]
class LambdaReduce(LambdaBase):
    """Fold the items produced by ``forward_prepare`` with ``lambda_func``."""

    def forward(self, input):
        """Return the left-to-right reduction of the prepared items."""
        combine = self.lambda_func
        prepared = self.forward_prepare(input)
        return reduce(combine, prepared)
def _residual_block():
    """Build one residual unit: a two-conv branch summed with a cropped skip branch."""
    return nn.Sequential(
        LambdaMap(
            lambda x: x,  # ConcatTable: both branches receive the same input
            nn.Sequential(
                nn.Conv2d(128, 128, (3, 3)),
                nn.BatchNorm2d(128),
                nn.ReLU(),
                nn.Conv2d(128, 128, (3, 3)),
                nn.BatchNorm2d(128),
            ),
            # The two unpadded 3x3 convolutions trim 2 px from every border,
            # so the skip branch is cropped by the same amount before the sum.
            shave_block(2),
        ),
        LambdaReduce(lambda x, y: x + y),  # CAddTable: add the two branches
    )


def generator():
    """Build the generator network as an ``nn.Sequential``.

    The network takes a 1-channel input and produces a 2-channel output
    squashed by ``tanh``. Layout: reflection padding, one 9x9 convolution,
    two stride-2 downsampling convolutions, five residual blocks, two
    stride-2 transposed convolutions, and a final 9x9 convolution.

    The layer order and nesting match the original hand-unrolled definition,
    so ``state_dict`` keys are unchanged.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    num_residual_blocks = 5

    layers = [
        # 40 px of padding offsets the shrinkage in the residual blocks:
        # 5 blocks x 2 px per side x downsampling factor 4 = 40 px per side.
        nn.ReflectionPad2d((40, 40, 40, 40)),
        nn.Conv2d(1, 32, (9, 9), (1, 1), (4, 4)),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.Conv2d(32, 64, (3, 3), (2, 2), (1, 1)),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.Conv2d(64, 128, (3, 3), (2, 2), (1, 1)),
        nn.BatchNorm2d(128),
        nn.ReLU(),
    ]
    # Each call creates a fresh block with its own parameters.
    layers.extend(_residual_block() for _ in range(num_residual_blocks))
    layers.extend([
        nn.ConvTranspose2d(128, 64, (3, 3), (2, 2), (1, 1), (1, 1)),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.ConvTranspose2d(64, 32, (3, 3), (2, 2), (1, 1), (1, 1)),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.Conv2d(32, 2, (9, 9), (1, 1), (4, 4)),
        nn.Tanh(),
    ])

    G = nn.Sequential(*layers)
    return G