mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
d5a276584b
Fixes conda package. Fixes #317
40 lines
1.2 KiB
Python
40 lines
1.2 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
|
|
# Upsample + BatchNorm
|
|
class UpSampleBN(nn.Module):
    """Upsample a feature map to a skip tensor's resolution and fuse the two.

    In ``forward``, the input ``x`` is bilinearly resized (``align_corners=True``)
    to the spatial size (dims 2 and 3) of ``concat_with``. The two tensors are
    concatenated along dim 1, and the result goes through two
    ``Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> LeakyReLU`` stages.

    Args:
        skip_input: input channel count of the first convolution. It must equal
            ``x.size(1) + concat_with.size(1)``, the channel count after
            concatenation.
        output_features: output channel count of both convolutions.
    """

    def __init__(self, skip_input, output_features):
        super().__init__()
        # Shared "same"-size 3x3 convolution settings for both stages.
        conv_opts = {"kernel_size": 3, "stride": 1, "padding": 1}
        # NOTE: the attribute name and layer order define the state_dict keys
        # (``_net.0.weight``, ``_net.1.running_mean``, ...). Keep them stable so
        # pretrained checkpoints still load.
        self._net = nn.Sequential(
            nn.Conv2d(skip_input, output_features, **conv_opts),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
            nn.Conv2d(output_features, output_features, **conv_opts),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
        )

    def forward(self, x, concat_with):
        """Resize ``x`` to ``concat_with``'s spatial size, concatenate on dim 1, convolve."""
        target_h, target_w = concat_with.size(2), concat_with.size(3)
        resized = F.interpolate(
            x, size=[target_h, target_w], mode="bilinear", align_corners=True
        )
        fused = torch.cat([resized, concat_with], dim=1)
        return self._net(fused)
|
|
|
|
|
|
def norm_normalize(norm_out):
    """Unit-normalize a 3-component vector and make the 4th channel positive.

    ``norm_out`` is split along dim 1 into single-channel slices, and exactly
    four are expected (otherwise unpacking raises ``ValueError``). The first
    three are treated as vector components (x, y, z) and divided by their
    Euclidean length. The fourth (``kappa``; presumably a concentration or
    confidence value, TODO confirm against the model head) is mapped through
    ``elu(k) + 1 + 0.01``. Because ``elu`` is bounded below by -1, the result
    is always greater than 0.01.

    Returns:
        A tensor with the same shape as ``norm_out``: three unit-normalized
        components followed by the rectified ``kappa`` channel.
    """
    kappa_floor = 0.01
    # Added after the sqrt so an all-zero vector yields zeros instead of 0/0.
    eps = 1e-10

    comp_x, comp_y, comp_z, raw_kappa = torch.split(norm_out, 1, dim=1)
    length = torch.sqrt(comp_x**2.0 + comp_y**2.0 + comp_z**2.0) + eps
    unit_parts = [component / length for component in (comp_x, comp_y, comp_z)]
    positive_kappa = F.elu(raw_kappa) + 1.0 + kappa_floor
    return torch.cat(unit_parts + [positive_kappa], dim=1)
|