|
|
|
@ -7,6 +7,7 @@ import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
from torch import nn
|
|
|
|
|
from torch.nn.functional import silu
|
|
|
|
|
|
|
|
|
|
from imaginairy.modules.attention import MemoryEfficientCrossAttention
|
|
|
|
|
from imaginairy.utils import get_device
|
|
|
|
@ -46,18 +47,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
|
|
|
|
return emb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# context manager that prints execution time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def nonlinearity(x):
|
|
|
|
|
# swish
|
|
|
|
|
t = torch.sigmoid(x)
|
|
|
|
|
x *= t
|
|
|
|
|
del t
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def Normalize(in_channels, num_groups=32):
|
|
|
|
|
return torch.nn.GroupNorm(
|
|
|
|
|
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
|
|
|
@ -142,19 +131,19 @@ class ResnetBlock(nn.Module):
|
|
|
|
|
h2 = self.norm1(h1)
|
|
|
|
|
del h1
|
|
|
|
|
|
|
|
|
|
h3 = nonlinearity(h2)
|
|
|
|
|
h3 = silu(h2)
|
|
|
|
|
del h2
|
|
|
|
|
|
|
|
|
|
h4 = self.conv1(h3)
|
|
|
|
|
del h3
|
|
|
|
|
|
|
|
|
|
if temb is not None:
|
|
|
|
|
h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
|
|
|
|
h4 = h4 + self.temb_proj(silu(temb))[:, :, None, None]
|
|
|
|
|
|
|
|
|
|
h5 = self.norm2(h4)
|
|
|
|
|
del h4
|
|
|
|
|
|
|
|
|
|
h6 = nonlinearity(h5)
|
|
|
|
|
h6 = silu(h5)
|
|
|
|
|
del h5
|
|
|
|
|
|
|
|
|
|
h7 = self.dropout(h6)
|
|
|
|
@ -514,7 +503,7 @@ class Model(nn.Module):
|
|
|
|
|
assert t is not None
|
|
|
|
|
temb = get_timestep_embedding(t, self.ch)
|
|
|
|
|
temb = self.temb.dense[0](temb)
|
|
|
|
|
temb = nonlinearity(temb)
|
|
|
|
|
temb = silu(temb)
|
|
|
|
|
temb = self.temb.dense[1](temb)
|
|
|
|
|
else:
|
|
|
|
|
temb = None
|
|
|
|
@ -549,7 +538,7 @@ class Model(nn.Module):
|
|
|
|
|
|
|
|
|
|
# end
|
|
|
|
|
h = self.norm_out(h)
|
|
|
|
|
h = nonlinearity(h)
|
|
|
|
|
h = silu(h)
|
|
|
|
|
h = self.conv_out(h)
|
|
|
|
|
return h
|
|
|
|
|
|
|
|
|
@ -669,7 +658,7 @@ class Encoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
# end
|
|
|
|
|
h = self.norm_out(h)
|
|
|
|
|
h = nonlinearity(h)
|
|
|
|
|
h = silu(h)
|
|
|
|
|
h = self.conv_out(h)
|
|
|
|
|
return h
|
|
|
|
|
|
|
|
|
@ -815,7 +804,7 @@ class Decoder(nn.Module):
|
|
|
|
|
h1 = self.norm_out(h)
|
|
|
|
|
del h
|
|
|
|
|
|
|
|
|
|
h2 = nonlinearity(h1)
|
|
|
|
|
h2 = silu(h1)
|
|
|
|
|
del h1
|
|
|
|
|
|
|
|
|
|
h = self.conv_out(h2)
|
|
|
|
@ -870,7 +859,7 @@ class SimpleDecoder(nn.Module):
|
|
|
|
|
x = layer(x)
|
|
|
|
|
|
|
|
|
|
h = self.norm_out(x)
|
|
|
|
|
h = nonlinearity(h)
|
|
|
|
|
h = silu(h)
|
|
|
|
|
x = self.conv_out(h)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
@ -928,7 +917,7 @@ class UpsampleDecoder(nn.Module):
|
|
|
|
|
if i_level != self.num_resolutions - 1:
|
|
|
|
|
h = self.upsample_blocks[k](h)
|
|
|
|
|
h = self.norm_out(h)
|
|
|
|
|
h = nonlinearity(h)
|
|
|
|
|
h = silu(h)
|
|
|
|
|
h = self.conv_out(h)
|
|
|
|
|
return h
|
|
|
|
|
|
|
|
|
|