From b93b6a4d7ccaf839d199a4a5178b0600b92e52fb Mon Sep 17 00:00:00 2001
From: Bryce
Date: Wed, 15 Feb 2023 14:05:29 -0800
Subject: [PATCH] perf: use silu instead of nonlinearity for speedup

---
 README.md                             |  1 +
 imaginairy/modules/diffusion/model.py | 31 +++++++++------------------
 imaginairy/modules/diffusion/util.py  |  6 ------
 tests/modules/diffusion/test_model.py | 15 +++++++------
 4 files changed, 19 insertions(+), 34 deletions(-)

diff --git a/README.md b/README.md
index 1045841..d17b2b2 100644
--- a/README.md
+++ b/README.md
@@ -298,6 +298,7 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
 
 ## ChangeLog
 
+- perf: use Silu for performance improvement over nonlinearity
 - perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost.
 - perf: sliced attention now runs on MacOS. A typo prevented that from happening previously.
 - perf: sliced latent decoding - now possible to make much bigger images. 3310x3310 on 11 GB GPU.
diff --git a/imaginairy/modules/diffusion/model.py b/imaginairy/modules/diffusion/model.py
index 81653d2..8bf02b8 100644
--- a/imaginairy/modules/diffusion/model.py
+++ b/imaginairy/modules/diffusion/model.py
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 from einops import rearrange
 from torch import nn
+from torch.nn.functional import silu
 
 from imaginairy.modules.attention import MemoryEfficientCrossAttention
 from imaginairy.utils import get_device
@@ -46,18 +47,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
     return emb
 
 
-# context manager that prints execution time
-
-
-def nonlinearity(x):
-    # swish
-    t = torch.sigmoid(x)
-    x *= t
-    del t
-
-    return x
-
-
 def Normalize(in_channels, num_groups=32):
     return torch.nn.GroupNorm(
         num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
@@ -142,19 +131,19 @@ class ResnetBlock(nn.Module):
         h2 = self.norm1(h1)
         del h1
 
-        h3 = nonlinearity(h2)
+        h3 = silu(h2)
         del h2
 
         h4 = self.conv1(h3)
         del h3
 
         if temb is not None:
-            h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+            h4 = h4 + self.temb_proj(silu(temb))[:, :, None, None]
 
         h5 = self.norm2(h4)
         del h4
 
-        h6 = nonlinearity(h5)
+        h6 = silu(h5)
         del h5
 
         h7 = self.dropout(h6)
@@ -514,7 +503,7 @@ class Model(nn.Module):
             assert t is not None
             temb = get_timestep_embedding(t, self.ch)
             temb = self.temb.dense[0](temb)
-            temb = nonlinearity(temb)
+            temb = silu(temb)
             temb = self.temb.dense[1](temb)
         else:
             temb = None
@@ -549,7 +538,7 @@ class Model(nn.Module):
 
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
 
@@ -669,7 +658,7 @@ class Encoder(nn.Module):
 
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
 
@@ -815,7 +804,7 @@ class Decoder(nn.Module):
         h1 = self.norm_out(h)
         del h
 
-        h2 = nonlinearity(h1)
+        h2 = silu(h1)
         del h1
 
         h = self.conv_out(h2)
@@ -870,7 +859,7 @@ class SimpleDecoder(nn.Module):
             x = layer(x)
 
         h = self.norm_out(x)
-        h = nonlinearity(h)
+        h = silu(h)
         x = self.conv_out(h)
         return x
 
@@ -928,7 +917,7 @@ class UpsampleDecoder(nn.Module):
                 if i_level != self.num_resolutions - 1:
                     h = self.upsample_blocks[k](h)
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
 
diff --git a/imaginairy/modules/diffusion/util.py b/imaginairy/modules/diffusion/util.py
index a973540..378d7e8 100644
--- a/imaginairy/modules/diffusion/util.py
+++ b/imaginairy/modules/diffusion/util.py
@@ -245,12 +245,6 @@ def normalization(channels):
     return GroupNorm32(32, channels)
 
 
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
-    def forward(self, x):
-        return x * torch.sigmoid(x)
-
-
 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):  # noqa
         return super().forward(x.float()).type(x.dtype)
diff --git a/tests/modules/diffusion/test_model.py b/tests/modules/diffusion/test_model.py
index 02462f9..0303a48 100644
--- a/tests/modules/diffusion/test_model.py
+++ b/tests/modules/diffusion/test_model.py
@@ -2,7 +2,6 @@ import time
 
 import torch
 
-from imaginairy.modules.diffusion.model import nonlinearity
 from imaginairy.utils import get_device
 
 
@@ -22,24 +21,26 @@ class Timer:
 
 
 def test_nonlinearity():
+    from torch.nn.functional import silu
+
     # mps before changes: 1021.54ms
     with Timer("nonlinearity"):
         for _ in range(10):
             for _ in range(11):
                 t = torch.randn(1, 512, 64, 64, device=get_device())
-                nonlinearity(t)
+                silu(t)
             for _ in range(7):
                 t = torch.randn(1, 512, 128, 128, device=get_device())
-                nonlinearity(t)
+                silu(t)
             for _ in range(1):
                 t = torch.randn(1, 512, 256, 256, device=get_device())
-                nonlinearity(t)
+                silu(t)
             for _ in range(5):
                 t = torch.randn(1, 256, 256, 256, device=get_device())
-                nonlinearity(t)
+                silu(t)
             for _ in range(1):
                 t = torch.randn(1, 256, 512, 512, device=get_device())
-                nonlinearity(t)
+                silu(t)
             for _ in range(6):
                 t = torch.randn(1, 128, 512, 512, device=get_device())
-                nonlinearity(t)
+                silu(t)
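
Reviewer note on the substitution: torch.nn.functional.silu computes exactly the swish function that the deleted nonlinearity helper implemented, x * sigmoid(x), but dispatches to a single built-in op rather than a separate sigmoid and multiply. A minimal sketch to sanity-check the equivalence; the tensor shape mirrors one from the test above and is illustrative only:

    import torch
    from torch.nn.functional import silu

    # silu(x) is defined as x * sigmoid(x) (the swish activation), so the
    # built-in op should agree numerically with the removed hand-rolled helper.
    x = torch.randn(1, 512, 64, 64)
    assert torch.allclose(silu(x), x * torch.sigmoid(x), atol=1e-6)

The SiLU shim removed from util.py existed only because PyTorch 1.5 lacked a built-in implementation; torch.nn.SiLU and torch.nn.functional.silu have shipped since PyTorch 1.7, so the shim can go once pre-1.7 versions are no longer supported.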