perf: use silu instead of nonlinearity for speedup

commit b93b6a4d7c (pull/256/head)
parent 68e7fd73c5
Author: Bryce, committed by Bryce Drennan (2 years ago)

@@ -298,6 +298,7 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
 ## ChangeLog
+- perf: use SiLU for a performance improvement over `nonlinearity`
 - perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost.
 - perf: sliced attention now runs on MacOS. A typo prevented that from happening previously.
 - perf: sliced latent decoding - now possible to make much bigger images. 3310x3310 on 11 GB GPU.

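For context on the SiLU changelog entry above: the `nonlinearity` helper removed in this commit computes the swish activation x * sigmoid(x) by hand, allocating an intermediate tensor and launching two elementwise kernels, while `torch.nn.functional.silu` does the same math as a single op. A minimal equivalence check (a sketch; the tensor shape and tolerance are illustrative, not taken from the commit):

```python
import torch
from torch.nn.functional import silu


def nonlinearity(x):
    # hand-rolled swish, as removed by this commit: x * sigmoid(x)
    t = torch.sigmoid(x)
    x *= t
    del t
    return x


x = torch.randn(1, 512, 64, 64)
# silu(x) computes the same activation as one fused op
assert torch.allclose(nonlinearity(x.clone()), silu(x), atol=1e-6)
```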
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 from einops import rearrange
 from torch import nn
+from torch.nn.functional import silu
 from imaginairy.modules.attention import MemoryEfficientCrossAttention
 from imaginairy.utils import get_device
@@ -46,18 +47,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
     return emb
-# context manager that prints execution time
-def nonlinearity(x):
-    # swish
-    t = torch.sigmoid(x)
-    x *= t
-    del t
-    return x
 def Normalize(in_channels, num_groups=32):
     return torch.nn.GroupNorm(
         num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
@@ -142,19 +131,19 @@ class ResnetBlock(nn.Module):
         h2 = self.norm1(h1)
         del h1
-        h3 = nonlinearity(h2)
+        h3 = silu(h2)
         del h2
         h4 = self.conv1(h3)
         del h3
         if temb is not None:
-            h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+            h4 = h4 + self.temb_proj(silu(temb))[:, :, None, None]
         h5 = self.norm2(h4)
         del h4
-        h6 = nonlinearity(h5)
+        h6 = silu(h5)
         del h5
         h7 = self.dropout(h6)
@@ -514,7 +503,7 @@ class Model(nn.Module):
             assert t is not None
             temb = get_timestep_embedding(t, self.ch)
             temb = self.temb.dense[0](temb)
-            temb = nonlinearity(temb)
+            temb = silu(temb)
             temb = self.temb.dense[1](temb)
         else:
             temb = None
@@ -549,7 +538,7 @@ class Model(nn.Module):
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
@@ -669,7 +658,7 @@ class Encoder(nn.Module):
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
@@ -815,7 +804,7 @@ class Decoder(nn.Module):
         h1 = self.norm_out(h)
         del h
-        h2 = nonlinearity(h1)
+        h2 = silu(h1)
         del h1
         h = self.conv_out(h2)
@@ -870,7 +859,7 @@ class SimpleDecoder(nn.Module):
                 x = layer(x)
         h = self.norm_out(x)
-        h = nonlinearity(h)
+        h = silu(h)
         x = self.conv_out(h)
         return x
@@ -928,7 +917,7 @@ class UpsampleDecoder(nn.Module):
             if i_level != self.num_resolutions - 1:
                 h = self.upsample_blocks[k](h)
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h

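Every model.py hunk above follows the same pattern: a `norm -> activation -> conv` step whose activation now calls `silu` directly instead of the module-level helper. A condensed, self-contained sketch of that pattern (a simplified stand-in with made-up channel counts, not the actual `ResnetBlock`):

```python
import torch
from torch import nn
from torch.nn.functional import silu


class TinyResBlock(nn.Module):
    """Simplified stand-in for the norm -> silu -> conv pattern in model.py."""

    def __init__(self, channels=64):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        h = self.norm(x)
        h = silu(h)   # was: h = nonlinearity(h)
        h = self.conv(h)
        return x + h  # residual connection


out = TinyResBlock()(torch.randn(1, 64, 32, 32))
```

Because the functional form registers no new submodules, the state_dict is unchanged and existing checkpoints load as before.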
@@ -245,12 +245,6 @@ def normalization(channels):
     return GroupNorm32(32, channels)
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
-    def forward(self, x):
-        return x * torch.sigmoid(x)
 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):  # noqa
         return super().forward(x.float()).type(x.dtype)

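The hunk above removes a PyTorch 1.5 compatibility shim. On PyTorch >= 1.7 the built-in `nn.SiLU` module and `torch.nn.functional.silu` cover the module and functional forms; whether any remaining callers of the shim were switched to `nn.SiLU` is not shown in this diff. A quick sanity check (illustrative shapes):

```python
import torch
from torch import nn
from torch.nn.functional import silu

x = torch.randn(4, 8)
# both built-in forms match the removed shim's math: x * sigmoid(x)
assert torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x), atol=1e-6)
assert torch.allclose(silu(x), x * torch.sigmoid(x), atol=1e-6)
```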
@@ -2,7 +2,6 @@ import time
 import torch
-from imaginairy.modules.diffusion.model import nonlinearity
 from imaginairy.utils import get_device
@@ -22,24 +21,26 @@ class Timer:
 def test_nonlinearity():
+    from torch.nn.functional import silu
+
     # mps before changes: 1021.54ms
     with Timer("nonlinearity"):
-        for _ in range(10):
+        for _ in range(11):
             t = torch.randn(1, 512, 64, 64, device=get_device())
-            nonlinearity(t)
+            silu(t)
         for _ in range(7):
             t = torch.randn(1, 512, 128, 128, device=get_device())
-            nonlinearity(t)
+            silu(t)
         for _ in range(1):
             t = torch.randn(1, 512, 256, 256, device=get_device())
-            nonlinearity(t)
+            silu(t)
         for _ in range(5):
             t = torch.randn(1, 256, 256, 256, device=get_device())
-            nonlinearity(t)
+            silu(t)
         for _ in range(1):
             t = torch.randn(1, 256, 512, 512, device=get_device())
-            nonlinearity(t)
+            silu(t)
         for _ in range(6):
             t = torch.randn(1, 128, 512, 512, device=get_device())
-            nonlinearity(t)
+            silu(t)

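The benchmark above times the activation at tensor shapes that occur during sampling at various resolutions, using a `Timer` context manager defined earlier in the test file but not shown in this diff. A minimal stand-in for such a timer (illustrative only; the real implementation may differ):

```python
import time

import torch


class Timer:
    """Illustrative stand-in: prints wall-clock time for the enclosed block."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # On CUDA, kernels launch asynchronously; synchronize so the reading
        # reflects execution time rather than just enqueue time.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - self.start) * 1000
        print(f"{self.name} took {elapsed_ms:.2f}ms")
```

The `# mps before changes: 1021.54ms` comment suggests the reference number was collected on Apple's MPS backend, where the same asynchronous-launch caveat applies.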