From 0bb5b6b34534bc0af5a9e7b0c49e9b7f5a0c006a Mon Sep 17 00:00:00 2001
From: Bryce
Date: Sun, 11 Sep 2022 03:08:51 -0700
Subject: [PATCH] perf: performance optimizations from Doggettx

https://github.com/CompVis/stable-diffusion/compare/main...Doggettx:stable-diffusion:autocast-improvements#
https://www.reddit.com/r/StableDiffusion/comments/xalaws/test_update_for_less_memory_usage_and_higher/
---
 README.md                             |   8 +-
 imaginairy/api.py                     |  22 +++-
 imaginairy/modules/attention.py       |  63 +++++++++++
 imaginairy/modules/diffusion/model.py | 150 ++++++++++++++++++++++----
 4 files changed, 217 insertions(+), 26 deletions(-)

diff --git a/README.md b/README.md
index 28b2300..59ee0a0 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 AI imagined images.
 
-Tested on Linux and OSX(M1).
+"just works" on Linux and OSX(M1).
 
 ```bash
 >> pip install imaginairy
 ```
@@ -68,7 +68,7 @@ OR
  - img2img actually does # of steps you specify
 
 # Models Used
- - CLIP
+ - CLIP - https://openai.com/blog/clip/
  - LDM - Latent Diffusion
  - Stable Diffusion
    - https://github.com/CompVis/stable-diffusion
@@ -89,14 +89,16 @@ OR
  - Image Generation Features
    - upscaling
    - face improvements
-   - image describe feature
+   - image describe feature - https://replicate.com/methexis-inc/img2prompt
    - outpainting
    - inpainting
    - cross-attention control:
     - https://github.com/bloc97/CrossAttentionControl/blob/main/CrossAttention_Release_NoImages.ipynb
+   - guided generation https://colab.research.google.com/drive/1dlgggNa5Mz8sEAGU0wFCHhGLFooW_pf1#scrollTo=UDeXQKbPTdZI
    - tiling
    - output show-work videos
    - image variations https://github.com/lstein/stable-diffusion/blob/main/VARIATIONS.md
+   - textual inversion https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb#scrollTo=50JuJUM8EG1h
    - zooming videos? a la disco diffusion
\ No newline at end of file
diff --git a/imaginairy/api.py b/imaginairy/api.py
index 25efee3..dd380dd 100755
--- a/imaginairy/api.py
+++ b/imaginairy/api.py
@@ -8,6 +8,7 @@ from functools import lru_cache
 import PIL
 import numpy as np
 import torch
+import torch.nn
 from PIL import Image
 from einops import rearrange
 from omegaconf import OmegaConf
@@ -69,8 +70,22 @@ def load_img(path, max_height=512, max_width=512):
     return 2.0 * image - 1.0, w, h
 
 
+def patch_conv(**patch):
+    cls = torch.nn.Conv2d
+    init = cls.__init__
+
+    def __init__(self, *args, **kwargs):
+        return init(self, *args, **kwargs, **patch)
+
+    cls.__init__ = __init__
+
+
 @lru_cache()
-def load_model():
+def load_model(tile_mode=False):
+    if tile_mode:
+        # generated images are tileable
+        patch_conv(padding_mode="circular")
+
     config = "configs/stable-diffusion-v1.yaml"
     config = OmegaConf.load(f"{LIB_PATH}/{config}")
     model = load_model_from_config(config)
@@ -143,6 +158,7 @@ def imagine_images(
     img_callback=None,
 ):
     model = load_model()
+    # model = model.half()
     prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
     prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
     _img_callback = None
@@ -156,6 +172,10 @@ def imagine_images(
     for prompt in prompts:
         logger.info(f"Generating {prompt.prompt_description()}")
         seed_everything(prompt.seed)
+
+        # needed when model is in half mode, remove if not using half mode
+        # torch.set_default_tensor_type(torch.HalfTensor)
+
         uc = None
         if prompt.prompt_strength != 1.0:
             uc = model.get_learned_conditioning(1 * [""])
diff --git a/imaginairy/modules/attention.py b/imaginairy/modules/attention.py
index eb062a8..624a3bb 100644
--- a/imaginairy/modules/attention.py
+++ b/imaginairy/modules/attention.py
@@ -6,6 +6,7 @@ from torch import nn, einsum
 from einops import rearrange, repeat
 
 from imaginairy.modules.diffusion.util import checkpoint
+from imaginairy.utils import get_device_name, get_device
 
 
 def exists(val):
@@ -164,6 +165,9 @@ class CrossAttention(nn.Module):
         )
 
     def forward(self, x, context=None, mask=None):
+        if get_device() == "cuda":
+            return self.forward_cuda(x, context=context, mask=mask)
+
         h = self.heads
 
         q = self.to_q(x)
@@ -188,6 +192,65 @@ class CrossAttention(nn.Module):
         out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)
 
+    def forward_cuda(self, x, context=None, mask=None):
+        h = self.heads
+
+        q_in = self.to_q(x)
+        context = default(context, x)
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+        del context, x
+
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q_in, k_in, v_in)
+        )
+        del q_in, k_in, v_in
+
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+
+        stats = torch.cuda.memory_stats(q.device)
+        mem_active = stats["active_bytes.all.current"]
+        mem_reserved = stats["reserved_bytes.all.current"]
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+        gb = 1024**3
+        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+        modifier = 3 if q.element_size() == 2 else 2.5
+        mem_required = tensor_size * modifier
+        steps = 1
+
+        if mem_required > mem_free_total:
+            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+        if steps > 64:
+            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+            raise RuntimeError(
+                f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). "
+                f"Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free"
+            )
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k) * self.scale
+
+            s2 = s1.softmax(dim=-1, dtype=q.dtype)
+            del s1
+
+            r1[:, i:end] = einsum("b i j, b j d -> b i d", s2, v)
+            del s2
+
+        del q, k, v
+
+        r2 = rearrange(r1, "(b h) n d -> b n (h d)", h=h)
+        del r1
+
+        return self.to_out(r2)
+
 
 class BasicTransformerBlock(nn.Module):
     def __init__(
diff --git a/imaginairy/modules/diffusion/model.py b/imaginairy/modules/diffusion/model.py
index 7570d34..179d3e0 100644
--- a/imaginairy/modules/diffusion/model.py
+++ b/imaginairy/modules/diffusion/model.py
@@ -1,4 +1,5 @@
 # pytorch_diffusion + derived encoder decoder
+import gc
 import logging
 import math
 
@@ -9,7 +10,7 @@ from einops import rearrange
 
 from imaginairy.modules.attention import LinearAttention
 from imaginairy.modules.distributions import DiagonalGaussianDistribution
-from imaginairy.utils import instantiate_from_config
+from imaginairy.utils import instantiate_from_config, get_device
 
 logger = logging.getLogger(__name__)
 
@@ -37,7 +38,11 @@ def get_timestep_embedding(timesteps, embedding_dim):
 
 def nonlinearity(x):
     # swish
-    return x * torch.sigmoid(x)
+    t = torch.sigmoid(x)
+    x *= t
+    del t
+
+    return x
 
 
 def Normalize(in_channels, num_groups=32):
@@ -120,18 +125,30 @@ class ResnetBlock(nn.Module):
                 )
 
     def forward(self, x, temb):
-        h = x
-        h = self.norm1(h)
-        h = nonlinearity(h)
-        h = self.conv1(h)
+        h1 = x
+        h2 = self.norm1(h1)
+        del h1
+
+        h3 = nonlinearity(h2)
+        del h2
+
+        h4 = self.conv1(h3)
+        del h3
 
         if temb is not None:
-            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+            h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None]
 
-        h = self.norm2(h)
-        h = nonlinearity(h)
-        h = self.dropout(h)
-        h = self.conv2(h)
+        h5 = self.norm2(h4)
+        del h4
+
+        h6 = nonlinearity(h5)
+        del h5
+
+        h7 = self.dropout(h6)
+        del h6
+
+        h8 = self.conv2(h7)
+        del h7
 
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
@@ -139,7 +156,7 @@ class ResnetBlock(nn.Module):
             else:
                 x = self.nin_shortcut(x)
 
-        return x + h
+        return x + h8
 
 
 class LinAttnBlock(LinearAttention):
@@ -169,6 +186,8 @@ class AttnBlock(nn.Module):
         )
 
     def forward(self, x):
+        if get_device() == "cuda":
+            return self.forward_cuda(x)
         h_ = x
         h_ = self.norm(h_)
         q = self.q(h_)
@@ -194,6 +213,71 @@ class AttnBlock(nn.Module):
 
         return x + h_
 
+    def forward_cuda(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q1 = self.q(h_)
+        k1 = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q1.shape
+
+        q2 = q1.reshape(b, c, h * w)
+        del q1
+
+        q = q2.permute(0, 2, 1)  # b,hw,c
+        del q2
+
+        k = k1.reshape(b, c, h * w)  # b,c,hw
+        del k1
+
+        h_ = torch.zeros_like(k, device=q.device)
+
+        stats = torch.cuda.memory_stats(q.device)
+        mem_active = stats["active_bytes.all.current"]
+        mem_reserved = stats["reserved_bytes.all.current"]
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+        mem_required = tensor_size * 2.5
+        steps = 1
+
+        if mem_required > mem_free_total:
+            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+
+            w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+            w2 = w1 * (int(c) ** (-0.5))
+            del w1
+            w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
+            del w2
+
+            # attend to values
+            v1 = v.reshape(b, c, h * w)
+            w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+            del w3
+
+            h_[:, :, i:end] = torch.bmm(
+                v1, w4
+            )  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+            del v1, w4
+
+        h2 = h_.reshape(b, c, h, w)
+        del h_
+
+        h3 = self.proj_out(h2)
+        del h2
+
+        h3 += x
+
+        return h3
+
 
 def make_attn(in_channels, attn_type="vanilla"):
     assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
@@ -427,31 +511,53 @@ class Decoder(nn.Module):
         temb = None
 
         # z to block_in
-        h = self.conv_in(z)
+        h1 = self.conv_in(z)
 
         # middle
-        h = self.mid.block_1(h, temb)
-        h = self.mid.attn_1(h)
-        h = self.mid.block_2(h, temb)
+        h2 = self.mid.block_1(h1, temb)
+        del h1
+
+        h3 = self.mid.attn_1(h2)
+        del h2
+
+        h = self.mid.block_2(h3, temb)
+        del h3
+
+        # prepare for up sampling
+        gc.collect()
+        torch.cuda.empty_cache()
 
         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(self.num_res_blocks + 1):
                 h = self.up[i_level].block[i_block](h, temb)
                 if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h)
+                    t = h
+                    h = self.up[i_level].attn[i_block](t)
+                    del t
+
             if i_level != 0:
-                h = self.up[i_level].upsample(h)
+                t = h
+                h = self.up[i_level].upsample(t)
+                del t
 
         # end
         if self.give_pre_end:
             return h
 
-        h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
+        h1 = self.norm_out(h)
+        del h
+
+        h2 = nonlinearity(h1)
+        del h1
+
+        h = self.conv_out(h2)
+        del h2
+
         if self.tanh_out:
-            h = torch.tanh(h)
+            t = h
+            h = torch.tanh(t)
+            del t
         return h
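The notes below are appended commentary, not part of the diff.

Why `patch_conv(padding_mode="circular")` makes generations tileable: every `torch.nn.Conv2d` constructed after the monkey patch pads by wrapping around the image borders instead of padding with zeros, so opposite edges are computed against each other and line up when the output is repeated. A minimal, standalone sketch of the difference (arbitrary weights and input, independent of imaginairy):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(1, 1, 4, 4)

conv_zeros = nn.Conv2d(1, 1, kernel_size=3, padding=1, padding_mode="zeros", bias=False)
conv_circular = nn.Conv2d(1, 1, kernel_size=3, padding=1, padding_mode="circular", bias=False)
conv_circular.load_state_dict(conv_zeros.state_dict())  # same weights, different border handling

# First output row: computed against a border of zeros vs. against the wrapped-around
# bottom row of the input. The wrap-around is what keeps a tiled texture seamless.
print(conv_zeros(x)[0, 0, 0])
print(conv_circular(x)[0, 0, 0])
```

With the patch applied, `load_model(tile_mode=True)` switches every convolution the model constructs afterwards to this padding mode.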
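Both `forward_cuda` methods size their work against a free-memory estimate that combines what the CUDA driver reports as free with memory PyTorch has reserved in its caching allocator but is not actively using. A sketch of that bookkeeping pulled into a standalone helper (the function name is illustrative and requires a CUDA device):

```python
import torch

def cuda_free_memory_estimate(device=None):
    """Approximate bytes available for new tensors on a CUDA device."""
    device = torch.cuda.current_device() if device is None else device
    stats = torch.cuda.memory_stats(device)
    mem_active = stats["active_bytes.all.current"]       # held by live tensors
    mem_reserved = stats["reserved_bytes.all.current"]    # held by the caching allocator
    mem_free_cuda, _ = torch.cuda.mem_get_info(device)    # reported free by the driver
    mem_free_torch = mem_reserved - mem_active             # cached, reusable without the driver
    return mem_free_cuda + mem_free_torch
```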
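That estimate feeds the slicing decision: if the full `softmax(QK^T)V` intermediate (times a safety factor) will not fit, the query dimension is split into a power-of-two number of slices and attention is computed one slice at a time into a preallocated output buffer. A simplified, CPU-runnable sketch of the heuristic (helper name and example numbers are illustrative, not from the patch):

```python
import math

def attention_slice_plan(batch_heads, q_tokens, k_tokens, element_size, mem_free_total):
    """Decide how many query slices are needed for the attention matrix to fit."""
    tensor_size = batch_heads * q_tokens * k_tokens * element_size
    modifier = 3 if element_size == 2 else 2.5  # fp16 gets extra headroom for upcasting
    mem_required = tensor_size * modifier

    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))

    # Same fallback as the patch: if the query length is not divisible by steps,
    # give up on slicing and do a single pass.
    slice_size = q_tokens // steps if q_tokens % steps == 0 else q_tokens
    return steps, slice_size

# A 512x512 image becomes a 64x64 latent -> 4096 tokens; 8 heads, fp32 (4 bytes),
# pretending only 1 GiB is free: the ~1.25 GiB requirement forces 2 slices of 2048.
print(attention_slice_plan(8, 4096, 4096, 4, mem_free_total=1024**3))  # (2, 2048)
```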