Mirror of https://github.com/brycedrennan/imaginAIry (synced 2024-11-05 12:00:15 +00:00)
perf: use silu instead of nonlinearity for speedup

commit b93b6a4d7c (parent 68e7fd73c5)
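
For context, a minimal sketch (not part of the commit itself, assuming a PyTorch version >= 1.7 where torch.nn.functional.silu is available) showing that the built-in silu computes the same swish activation, x * sigmoid(x), that the hand-written nonlinearity() helper implemented, which is why the swap below is a drop-in change:

import torch
from torch.nn.functional import silu

x = torch.randn(1, 512, 64, 64)
manual = x * torch.sigmoid(x)  # what the removed nonlinearity() computed
assert torch.allclose(silu(x), manual, atol=1e-6)
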
@@ -298,6 +298,7 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
 
 ## ChangeLog
 
+- perf: use Silu for performance improvement over nonlinearity
 - perf: `xformers` added as a dependency for linux and windows. Gives a nice speed boost.
 - perf: sliced attention now runs on MacOS. A typo prevented that from happening previously.
 - perf: sliced latent decoding - now possible to make much bigger images. 3310x3310 on 11 GB GPU.
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 from einops import rearrange
 from torch import nn
+from torch.nn.functional import silu
 
 from imaginairy.modules.attention import MemoryEfficientCrossAttention
 from imaginairy.utils import get_device
@@ -46,18 +47,6 @@ def get_timestep_embedding(timesteps, embedding_dim):
     return emb
 
 
-# context manager that prints execution time
-
-
-def nonlinearity(x):
-    # swish
-    t = torch.sigmoid(x)
-    x *= t
-    del t
-
-    return x
-
-
 def Normalize(in_channels, num_groups=32):
     return torch.nn.GroupNorm(
         num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
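
A rough micro-benchmark sketch of the intent behind removing this helper (illustrative only; the tensor shape and iteration count here are arbitrary, and actual timings depend on device and PyTorch build): the hand-written version materializes a separate sigmoid tensor and then multiplies in place, while torch.nn.functional.silu runs as a single activation op.

import time
import torch
from torch.nn.functional import silu

x = torch.randn(1, 512, 256, 256)

start = time.perf_counter()
for _ in range(20):
    _ = x * torch.sigmoid(x)  # old path: separate sigmoid + multiply
print(f"manual swish: {(time.perf_counter() - start) * 1000:.1f}ms")

start = time.perf_counter()
for _ in range(20):
    _ = silu(x)  # new path: single fused activation
print(f"F.silu:       {(time.perf_counter() - start) * 1000:.1f}ms")
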
@@ -142,19 +131,19 @@ class ResnetBlock(nn.Module):
         h2 = self.norm1(h1)
         del h1
 
-        h3 = nonlinearity(h2)
+        h3 = silu(h2)
         del h2
 
         h4 = self.conv1(h3)
         del h3
 
         if temb is not None:
-            h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+            h4 = h4 + self.temb_proj(silu(temb))[:, :, None, None]
 
         h5 = self.norm2(h4)
         del h4
 
-        h6 = nonlinearity(h5)
+        h6 = silu(h5)
         del h5
 
         h7 = self.dropout(h6)
@@ -514,7 +503,7 @@ class Model(nn.Module):
             assert t is not None
             temb = get_timestep_embedding(t, self.ch)
             temb = self.temb.dense[0](temb)
-            temb = nonlinearity(temb)
+            temb = silu(temb)
             temb = self.temb.dense[1](temb)
         else:
             temb = None
@@ -549,7 +538,7 @@ class Model(nn.Module):
 
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
 
@@ -669,7 +658,7 @@ class Encoder(nn.Module):
 
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
 
@@ -815,7 +804,7 @@ class Decoder(nn.Module):
         h1 = self.norm_out(h)
         del h
 
-        h2 = nonlinearity(h1)
+        h2 = silu(h1)
         del h1
 
         h = self.conv_out(h2)
@@ -870,7 +859,7 @@ class SimpleDecoder(nn.Module):
                 x = layer(x)
 
         h = self.norm_out(x)
-        h = nonlinearity(h)
+        h = silu(h)
         x = self.conv_out(h)
         return x
 
@@ -928,7 +917,7 @@ class UpsampleDecoder(nn.Module):
             if i_level != self.num_resolutions - 1:
                 h = self.upsample_blocks[k](h)
         h = self.norm_out(h)
-        h = nonlinearity(h)
+        h = silu(h)
         h = self.conv_out(h)
         return h
 
@@ -245,12 +245,6 @@ def normalization(channels):
     return GroupNorm32(32, channels)
 
 
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
-    def forward(self, x):
-        return x * torch.sigmoid(x)
-
-
 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):  # noqa
         return super().forward(x.float()).type(x.dtype)
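
The removed SiLU class was only a compatibility shim for older PyTorch, per its own comment; a minimal sketch, assuming the project now targets a PyTorch version that ships the activation natively, of the built-in module that serves the same role:

import torch
from torch import nn

act = nn.SiLU()  # built-in module equivalent of the removed custom SiLU class
x = torch.randn(4, 8)
assert torch.allclose(act(x), x * torch.sigmoid(x))
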
@@ -2,7 +2,6 @@ import time
 
 import torch
 
-from imaginairy.modules.diffusion.model import nonlinearity
 from imaginairy.utils import get_device
 
 
@@ -22,24 +21,26 @@ class Timer:
 
 
 def test_nonlinearity():
+    from torch.nn.functional import silu
+
     # mps before changes: 1021.54ms
     with Timer("nonlinearity"):
         for _ in range(10):
             for _ in range(11):
                 t = torch.randn(1, 512, 64, 64, device=get_device())
-                nonlinearity(t)
+                silu(t)
             for _ in range(7):
                 t = torch.randn(1, 512, 128, 128, device=get_device())
-                nonlinearity(t)
+                silu(t)
             for _ in range(1):
                 t = torch.randn(1, 512, 256, 256, device=get_device())
-                nonlinearity(t)
+                silu(t)
             for _ in range(5):
                 t = torch.randn(1, 256, 256, 256, device=get_device())
-                nonlinearity(t)
+                silu(t)
             for _ in range(1):
                 t = torch.randn(1, 256, 512, 512, device=get_device())
-                nonlinearity(t)
+                silu(t)
             for _ in range(6):
                 t = torch.randn(1, 128, 512, 512, device=get_device())
-                nonlinearity(t)
+                silu(t)
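
The benchmark above relies on a Timer context manager (the hunk header shows a class Timer earlier in the same test file), whose implementation is not part of this diff. A hypothetical sketch of what such a timing helper might look like:

import time

class Timer:
    """Hypothetical sketch; the real Timer lives in the test module, not in this diff."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        elapsed_ms = (time.perf_counter() - self.start) * 1000
        print(f"{self.name} took {elapsed_ms:.2f}ms")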