perf: add back memory efficiency improvements

Removed these in a rush to get SD-2.0 out.
pull/125/head
Bryce 2 years ago committed by Bryce Drennan
parent ca126b0760
commit 257752887d

@ -1,4 +1,5 @@
# pytorch_diffusion + derived encoder decoder
import gc
import math
from typing import Any, Optional
@ -8,6 +9,7 @@ from einops import rearrange
from torch import nn
from imaginairy.modules.attention import MemoryEfficientCrossAttention
from imaginairy.utils import get_device
try:
import xformers # noqa
@ -42,9 +44,16 @@ def get_timestep_embedding(timesteps, embedding_dim):
return emb
# context manager that prints execution time
def nonlinearity(x):
    """Apply the swish (SiLU) activation, ``x * sigmoid(x)``, element-wise.

    Args:
        x: input tensor.

    Returns:
        A new tensor with the same shape as ``x``. ``x`` itself is left
        unmodified, so callers may keep using it (and autograd leaf tensors
        are accepted).
    """
    # swish
    # A single library call avoids holding a separate ``sigmoid(x)``
    # temporary at Python level, without writing into the caller's tensor
    # the way an in-place ``x *= sigmoid(x)`` would.
    return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):
@ -127,18 +136,30 @@ class ResnetBlock(nn.Module):
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h1 = x
h2 = self.norm1(h1)
del h1
h3 = nonlinearity(h2)
del h2
h4 = self.conv1(h3)
del h3
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
h5 = self.norm2(h4)
del h4
h6 = nonlinearity(h5)
del h5
h7 = self.dropout(h6)
del h6
h8 = self.conv2(h7)
del h7
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
@ -146,7 +167,7 @@ class ResnetBlock(nn.Module):
else:
x = self.nin_shortcut(x)
return x + h
return x + h8
class AttnBlock(nn.Module):
@ -169,6 +190,9 @@ class AttnBlock(nn.Module):
)
def forward(self, x):
if get_device() == "cuda":
return self.forward_cuda(x)
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
@ -194,6 +218,71 @@ class AttnBlock(nn.Module):
return x + h_
def forward_cuda(self, x):
    """Attention forward pass that slices the score matrix to fit CUDA memory.

    The query positions are processed in equally sized slices so that the
    (positions x positions) attention-score tensor never has to exist in
    full when the estimated requirement exceeds the free device memory.

    Args:
        x: input tensor; the projections of its normalised form are
            unpacked as ``(b, c, h, w)`` below.

    Returns:
        ``self.proj_out(attention_output) + x``.
    """
    h_ = self.norm(x)
    q1 = self.q(h_)
    k1 = self.k(h_)
    v = self.v(h_)

    # compute attention
    b, c, h, w = q1.shape

    q2 = q1.reshape(b, c, h * w)
    del q1

    q = q2.permute(0, 2, 1)  # b,hw,c
    del q2

    k = k1.reshape(b, c, h * w)  # b,c,hw
    del k1

    # Loop-invariant: flatten the values once instead of on every slice.
    v = v.reshape(b, c, h * w)

    # Output buffer, filled slice by slice.
    h_ = torch.zeros_like(k, device=q.device)

    # Free memory = free on the device + reserved-but-inactive in torch's
    # allocator. Both figures are taken for the tensor's own device so they
    # stay consistent when it is not the current device.
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats["active_bytes.all.current"]
    mem_reserved = stats["reserved_bytes.all.current"]
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    # Clamp to at least one byte so the ratio below cannot divide by zero.
    mem_free_total = max(mem_free_cuda + mem_free_torch, 1)

    # Bytes needed for the full b x hw x hw score tensor.
    tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    # NOTE(review): 2.5 is a headroom factor, presumably covering the
    # scaled/softmax temporaries; its derivation is not shown here.
    mem_required = tensor_size * 2.5
    steps = 1

    if mem_required > mem_free_total:
        # Round the number of slices up to the next power of two. log2 is
        # exact for powers of two, unlike log(x, 2), so ceil cannot
        # overshoot by one.
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))

    # If the position count is not divisible by ``steps`` the whole range is
    # processed as a single slice.
    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
    scale = int(c) ** (-0.5)

    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size

        w1 = torch.bmm(q[:, i:end], k)  # b,slice,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w2 = w1 * scale
        del w1
        w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
        del w2

        # attend to values
        w4 = w3.permute(0, 2, 1)  # b,hw,slice (first hw of k, second of q)
        del w3

        h_[:, :, i:end] = torch.bmm(
            v, w4
        )  # b, c,slice (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        del w4

    h2 = h_.reshape(b, c, h, w)
    del h_

    h3 = self.proj_out(h2)
    del h2

    # Residual connection, added in place on the freshly produced tensor.
    h3 += x

    return h3
class MemoryEfficientAttnBlock(nn.Module):
"""
@ -689,31 +778,52 @@ class Decoder(nn.Module):
temb = None
# z to block_in
h = self.conv_in(z)
h1 = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
h2 = self.mid.block_1(h1, temb)
del h1
h3 = self.mid.attn_1(h2)
del h2
h = self.mid.block_2(h3, temb)
del h3
# prepare for up sampling
gc.collect()
torch.cuda.empty_cache()
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
t = h
h = self.up[i_level].attn[i_block](t)
del t
if i_level != 0:
h = self.up[i_level].upsample(h)
t = h
h = self.up[i_level].upsample(t)
del t
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h1 = self.norm_out(h)
del h
h2 = nonlinearity(h1)
del h1
h = self.conv_out(h2)
del h2
if self.tanh_out:
h = torch.tanh(h)
t = h
h = torch.tanh(t)
del t
return h

@ -0,0 +1,45 @@
import time
import torch
from imaginairy.modules.diffusion.model import nonlinearity
from imaginairy.utils import get_device
class Timer:
    """Context manager that prints how long its ``with`` body took.

    On exit it prints ``"<name> took <elapsed> ms"`` with the elapsed
    wall-clock time (``time.perf_counter``) in milliseconds, two decimals.
    """

    def __init__(self, name):
        # ``start`` stays None until the context is entered.
        self.start = None
        self.name = name

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        finished_at = time.perf_counter()
        elapsed = finished_at - self.start
        print(f"{self.name} took {elapsed*1000:.2f} ms")
def test_nonlinearity():
    """Time repeated ``nonlinearity`` calls on random tensors of several sizes.

    Makes no assertions; the elapsed time is printed by ``Timer``.
    """
    # mps before changes: 1021.54ms
    # (number of calls, tensor shape), executed in this order on every pass.
    workload = (
        (11, (1, 512, 64, 64)),
        (7, (1, 512, 128, 128)),
        (1, (1, 512, 256, 256)),
        (5, (1, 256, 256, 256)),
        (1, (1, 256, 512, 512)),
        (6, (1, 128, 512, 512)),
    )
    with Timer("nonlinearity"):
        for _ in range(10):
            for calls, shape in workload:
                for _ in range(calls):
                    t = torch.randn(*shape, device=get_device())
                    nonlinearity(t)
Loading…
Cancel
Save