perf: use silu instead of nonlinearity for speedup

pull/256/head
Bryce 2 years ago committed by Bryce Drennan
parent 68e7fd73c5
commit b93b6a4d7c

@ -298,6 +298,7 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog ## ChangeLog
- perf: use Silu for performance improvement over nonlinearity
- perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost. - perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost.
- perf: sliced attention now runs on MacOS. A typo prevented that from happening previously. - perf: sliced attention now runs on MacOS. A typo prevented that from happening previously.
- perf: sliced latent decoding - now possible to make much bigger images. 3310x3310 on 11 GB GPU. - perf: sliced latent decoding - now possible to make much bigger images. 3310x3310 on 11 GB GPU.

@ -7,6 +7,7 @@ import numpy as np
import torch import torch
from einops import rearrange from einops import rearrange
from torch import nn from torch import nn
from torch.nn.functional import silu
from imaginairy.modules.attention import MemoryEfficientCrossAttention from imaginairy.modules.attention import MemoryEfficientCrossAttention
from imaginairy.utils import get_device from imaginairy.utils import get_device
@ -46,18 +47,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
return emb return emb
# context manager that prints execution time
def nonlinearity(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def Normalize(in_channels, num_groups=32): def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm( return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
@ -142,19 +131,19 @@ class ResnetBlock(nn.Module):
h2 = self.norm1(h1) h2 = self.norm1(h1)
del h1 del h1
h3 = nonlinearity(h2) h3 = silu(h2)
del h2 del h2
h4 = self.conv1(h3) h4 = self.conv1(h3)
del h3 del h3
if temb is not None: if temb is not None:
h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None] h4 = h4 + self.temb_proj(silu(temb))[:, :, None, None]
h5 = self.norm2(h4) h5 = self.norm2(h4)
del h4 del h4
h6 = nonlinearity(h5) h6 = silu(h5)
del h5 del h5
h7 = self.dropout(h6) h7 = self.dropout(h6)
@ -514,7 +503,7 @@ class Model(nn.Module):
assert t is not None assert t is not None
temb = get_timestep_embedding(t, self.ch) temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb) temb = self.temb.dense[0](temb)
temb = nonlinearity(temb) temb = silu(temb)
temb = self.temb.dense[1](temb) temb = self.temb.dense[1](temb)
else: else:
temb = None temb = None
@ -549,7 +538,7 @@ class Model(nn.Module):
# end # end
h = self.norm_out(h) h = self.norm_out(h)
h = nonlinearity(h) h = silu(h)
h = self.conv_out(h) h = self.conv_out(h)
return h return h
@ -669,7 +658,7 @@ class Encoder(nn.Module):
# end # end
h = self.norm_out(h) h = self.norm_out(h)
h = nonlinearity(h) h = silu(h)
h = self.conv_out(h) h = self.conv_out(h)
return h return h
@ -815,7 +804,7 @@ class Decoder(nn.Module):
h1 = self.norm_out(h) h1 = self.norm_out(h)
del h del h
h2 = nonlinearity(h1) h2 = silu(h1)
del h1 del h1
h = self.conv_out(h2) h = self.conv_out(h2)
@ -870,7 +859,7 @@ class SimpleDecoder(nn.Module):
x = layer(x) x = layer(x)
h = self.norm_out(x) h = self.norm_out(x)
h = nonlinearity(h) h = silu(h)
x = self.conv_out(h) x = self.conv_out(h)
return x return x
@ -928,7 +917,7 @@ class UpsampleDecoder(nn.Module):
if i_level != self.num_resolutions - 1: if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h) h = self.upsample_blocks[k](h)
h = self.norm_out(h) h = self.norm_out(h)
h = nonlinearity(h) h = silu(h)
h = self.conv_out(h) h = self.conv_out(h)
return h return h

@ -245,12 +245,6 @@ def normalization(channels):
return GroupNorm32(32, channels) return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm): class GroupNorm32(nn.GroupNorm):
def forward(self, x): # noqa def forward(self, x): # noqa
return super().forward(x.float()).type(x.dtype) return super().forward(x.float()).type(x.dtype)

@ -2,7 +2,6 @@ import time
import torch import torch
from imaginairy.modules.diffusion.model import nonlinearity
from imaginairy.utils import get_device from imaginairy.utils import get_device
@ -22,24 +21,26 @@ class Timer:
def test_nonlinearity(): def test_nonlinearity():
from torch.nn.functional import silu
# mps before changes: 1021.54ms # mps before changes: 1021.54ms
with Timer("nonlinearity"): with Timer("nonlinearity"):
for _ in range(10): for _ in range(10):
for _ in range(11): for _ in range(11):
t = torch.randn(1, 512, 64, 64, device=get_device()) t = torch.randn(1, 512, 64, 64, device=get_device())
nonlinearity(t) silu(t)
for _ in range(7): for _ in range(7):
t = torch.randn(1, 512, 128, 128, device=get_device()) t = torch.randn(1, 512, 128, 128, device=get_device())
nonlinearity(t) silu(t)
for _ in range(1): for _ in range(1):
t = torch.randn(1, 512, 256, 256, device=get_device()) t = torch.randn(1, 512, 256, 256, device=get_device())
nonlinearity(t) silu(t)
for _ in range(5): for _ in range(5):
t = torch.randn(1, 256, 256, 256, device=get_device()) t = torch.randn(1, 256, 256, 256, device=get_device())
nonlinearity(t) silu(t)
for _ in range(1): for _ in range(1):
t = torch.randn(1, 256, 512, 512, device=get_device()) t = torch.randn(1, 256, 512, 512, device=get_device())
nonlinearity(t) silu(t)
for _ in range(6): for _ in range(6):
t = torch.randn(1, 128, 512, 512, device=get_device()) t = torch.randn(1, 128, 512, 512, device=get_device())
nonlinearity(t) silu(t)

Loading…
Cancel
Save