Bryce 2 years ago
parent 910b7b4180
commit 0bb5b6b345

@@ -2,7 +2,7 @@
 AI imagined images.
-Tested on Linux and OSX(M1).
+"just works" on Linux and OSX(M1).
 ```bash
 >> pip install imaginairy
@@ -68,7 +68,7 @@ OR
 - img2img actually does # of steps you specify
 # Models Used
-- CLIP
+- CLIP - https://openai.com/blog/clip/
 - LDM - Latent Diffusion
 - Stable Diffusion
 - https://github.com/CompVis/stable-diffusion
@@ -89,14 +89,16 @@ OR
 - Image Generation Features
 - upscaling
 - face improvements
-- image describe feature
+- image describe feature - https://replicate.com/methexis-inc/img2prompt
 - outpainting
 - inpainting
 - cross-attention control:
 - https://github.com/bloc97/CrossAttentionControl/blob/main/CrossAttention_Release_NoImages.ipynb
+- guided generation https://colab.research.google.com/drive/1dlgggNa5Mz8sEAGU0wFCHhGLFooW_pf1#scrollTo=UDeXQKbPTdZI
 - tiling
 - output show-work videos
 - image variations https://github.com/lstein/stable-diffusion/blob/main/VARIATIONS.md
+- textual inversion https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb#scrollTo=50JuJUM8EG1h
 - zooming videos? a la disco diffusion

@@ -8,6 +8,7 @@ from functools import lru_cache
 import PIL
 import numpy as np
 import torch
+import torch.nn
 from PIL import Image
 from einops import rearrange
 from omegaconf import OmegaConf
@@ -69,8 +70,22 @@ def load_img(path, max_height=512, max_width=512):
     return 2.0 * image - 1.0, w, h


+def patch_conv(**patch):
+    cls = torch.nn.Conv2d
+    init = cls.__init__
+
+    def __init__(self, *args, **kwargs):
+        return init(self, *args, **kwargs, **patch)
+
+    cls.__init__ = __init__
+
+
 @lru_cache()
-def load_model():
+def load_model(tile_mode=False):
+    if tile_mode:
+        # generated images are tileable
+        patch_conv(padding_mode="circular")
     config = "configs/stable-diffusion-v1.yaml"
     config = OmegaConf.load(f"{LIB_PATH}/{config}")
     model = load_model_from_config(config)
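The new `tile_mode` flag works by monkey-patching `torch.nn.Conv2d` before the model is built, so every convolution in the network is constructed with `padding_mode="circular"`; border pixels then wrap to the opposite edge and the generated image tiles seamlessly. A minimal sketch of the same trick, separate from the library code (the layer sizes and tensor shapes here are illustrative only):

```python
import torch

# Mirror patch_conv(padding_mode="circular"): wrap Conv2d.__init__ so every
# convolution created after this point pads circularly instead of with zeros.
_orig_init = torch.nn.Conv2d.__init__

def _circular_init(self, *args, **kwargs):
    return _orig_init(self, *args, **kwargs, padding_mode="circular")

torch.nn.Conv2d.__init__ = _circular_init

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 64, 64)
y = conv(x)  # edge pixels now "see" the opposite edge, so features wrap around
```

Because the patch mutates the class itself, it has to run before the layers are instantiated, which is why it sits at the top of `load_model`. Note also that nothing un-patches the class, so once `tile_mode` has been used, any `Conv2d` created later in the same process will also pad circularly.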
@@ -143,6 +158,7 @@ def imagine_images(
     img_callback=None,
 ):
     model = load_model()
+    # model = model.half()
     prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
     prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
     _img_callback = None
@@ -156,6 +172,10 @@ def imagine_images(
         for prompt in prompts:
             logger.info(f"Generating {prompt.prompt_description()}")
             seed_everything(prompt.seed)
+            # needed when model is in half mode, remove if not using half mode
+            # torch.set_default_tensor_type(torch.HalfTensor)
             uc = None
             if prompt.prompt_strength != 1.0:
                 uc = model.get_learned_conditioning(1 * [""])
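The commented-out lines in this file (`# model = model.half()` above and the `HalfTensor` default here) sketch an optional half-precision mode: fp16 weights roughly halve VRAM use, but every tensor that feeds the model then has to be fp16 too, which is what the `torch.set_default_tensor_type` hint is about. A hedged sketch of how that might be wired up on a CUDA machine; the tiny `Sequential` model is a stand-in, not the library's model:

```python
import torch

# Illustrative stand-in for the Stable Diffusion model returned by load_model();
# the point is only the fp16 pattern hinted at by the commented-out lines above.
model = torch.nn.Sequential(torch.nn.Conv2d(4, 4, 3, padding=1), torch.nn.SiLU())

if torch.cuda.is_available():
    model = model.half().cuda()  # store weights in fp16 to save VRAM
    torch.set_default_tensor_type(torch.cuda.HalfTensor)  # new tensors default to fp16
    x = torch.randn(1, 4, 64, 64)  # created as fp16 on the GPU now
    with torch.no_grad():
        out = model(x)  # dtypes match, so no fp16/fp32 mixing errors
```

`torch.autocast("cuda")` is the usual alternative when you want mixed precision without changing global tensor defaults.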

@@ -6,6 +6,7 @@ from torch import nn, einsum
 from einops import rearrange, repeat

 from imaginairy.modules.diffusion.util import checkpoint
+from imaginairy.utils import get_device_name, get_device


 def exists(val):
@@ -164,6 +165,9 @@ class CrossAttention(nn.Module):
         )

     def forward(self, x, context=None, mask=None):
+        if get_device() == "cuda":
+            return self.forward_cuda(x, context=context, mask=mask)
+
         h = self.heads

         q = self.to_q(x)
@@ -188,6 +192,65 @@ class CrossAttention(nn.Module):
         out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)
+    def forward_cuda(self, x, context=None, mask=None):
+        h = self.heads
+
+        q_in = self.to_q(x)
+        context = default(context, x)
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+        del context, x
+
+        q, k, v = map(
+            lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q_in, k_in, v_in)
+        )
+        del q_in, k_in, v_in
+
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
+
+        stats = torch.cuda.memory_stats(q.device)
+        mem_active = stats["active_bytes.all.current"]
+        mem_reserved = stats["reserved_bytes.all.current"]
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+        gb = 1024**3
+        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+        modifier = 3 if q.element_size() == 2 else 2.5
+        mem_required = tensor_size * modifier
+        steps = 1
+
+        if mem_required > mem_free_total:
+            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+        if steps > 64:
+            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+            raise RuntimeError(
+                f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). "
+                f"Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free"
+            )
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k) * self.scale
+
+            s2 = s1.softmax(dim=-1, dtype=q.dtype)
+            del s1
+
+            r1[:, i:end] = einsum("b i j, b j d -> b i d", s2, v)
+            del s2
+
+        del q, k, v
+
+        r2 = rearrange(r1, "(b h) n d -> b n (h d)", h=h)
+        del r1
+
+        return self.to_out(r2)
 class BasicTransformerBlock(nn.Module):
     def __init__(
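`forward_cuda` is the memory-frugal attention path: it asks the allocator how much memory is really available (`torch.cuda.memory_stats` plus `torch.cuda.mem_get_info`), picks a power-of-two number of steps, and then computes softmax(QK^T)V one query-chunk at a time, `del`-ing every intermediate as soon as it is consumed. Stripped of the memory probing, the core slicing idea looks roughly like this (toy shapes, not the library's API):

```python
import torch
from torch import einsum

def sliced_attention(q, k, v, scale, slice_size):
    # Compute softmax(q @ k^T) @ v in chunks along the query dimension so the
    # full (n_q, n_k) attention matrix never has to exist all at once.
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        attn = einsum("b i d, b j d -> b i j", q[:, i:end], k) * scale
        attn = attn.softmax(dim=-1)
        out[:, i:end] = einsum("b i j, b j d -> b i d", attn, v)
    return out

# Toy shapes: (batch * heads, tokens, head_dim); 4 slices of 1024 query tokens each.
q, k, v = (torch.randn(2, 4096, 64) for _ in range(3))
out = sliced_attention(q, k, v, scale=64 ** -0.5, slice_size=1024)
```

Each iteration only materializes a (slice_size, n_k) block of attention scores, so peak memory scales with the slice size instead of the full sequence length.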

@@ -1,4 +1,5 @@
 # pytorch_diffusion + derived encoder decoder
+import gc
 import logging
 import math
@@ -9,7 +10,7 @@ from einops import rearrange

 from imaginairy.modules.attention import LinearAttention
 from imaginairy.modules.distributions import DiagonalGaussianDistribution
-from imaginairy.utils import instantiate_from_config
+from imaginairy.utils import instantiate_from_config, get_device

 logger = logging.getLogger(__name__)
@@ -37,7 +38,11 @@ def get_timestep_embedding(timesteps, embedding_dim):

 def nonlinearity(x):
     # swish
-    return x * torch.sigmoid(x)
+    t = torch.sigmoid(x)
+    x *= t
+    del t
+    return x


 def Normalize(in_channels, num_groups=32):
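The rewritten `nonlinearity` computes the same swish activation but multiplies into `x` in place, so it allocates one temporary (the sigmoid) instead of two full-size tensors. A quick check of the equivalence; the assumption being made is that mutating the input is acceptable because this module runs under no-grad sampling, where in-place updates do not upset autograd:

```python
import torch

def swish_inplace(x):
    # Same computation as the patched nonlinearity() above; writing the product
    # back into x avoids a second full-size temporary, at the cost of mutating
    # the input (assumed fine on a torch.no_grad() inference path).
    t = torch.sigmoid(x)
    x *= t
    del t
    return x

x = torch.randn(2, 4, 8, 8)
assert torch.allclose(swish_inplace(x.clone()), x * torch.sigmoid(x))
```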
@@ -120,18 +125,30 @@ class ResnetBlock(nn.Module):
         )

     def forward(self, x, temb):
-        h = x
-        h = self.norm1(h)
-        h = nonlinearity(h)
-        h = self.conv1(h)
+        h1 = x
+        h2 = self.norm1(h1)
+        del h1
+
+        h3 = nonlinearity(h2)
+        del h2
+
+        h4 = self.conv1(h3)
+        del h3

         if temb is not None:
-            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+            h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None]

-        h = self.norm2(h)
-        h = nonlinearity(h)
-        h = self.dropout(h)
-        h = self.conv2(h)
+        h5 = self.norm2(h4)
+        del h4
+
+        h6 = nonlinearity(h5)
+        del h5
+
+        h7 = self.dropout(h6)
+        del h6
+
+        h8 = self.conv2(h7)
+        del h7

         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
@@ -139,7 +156,7 @@ class ResnetBlock(nn.Module):
             else:
                 x = self.nin_shortcut(x)

-        return x + h
+        return x + h8
 class LinAttnBlock(LinearAttention):
@@ -169,6 +186,8 @@ class AttnBlock(nn.Module):
         )

     def forward(self, x):
+        if get_device() == "cuda":
+            return self.forward_cuda(x)
         h_ = x
         h_ = self.norm(h_)
         q = self.q(h_)
@@ -194,6 +213,71 @@ class AttnBlock(nn.Module):

         return x + h_
+    def forward_cuda(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q1 = self.q(h_)
+        k1 = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q1.shape
+        q2 = q1.reshape(b, c, h * w)
+        del q1
+
+        q = q2.permute(0, 2, 1)  # b,hw,c
+        del q2
+
+        k = k1.reshape(b, c, h * w)  # b,c,hw
+        del k1
+
+        h_ = torch.zeros_like(k, device=q.device)
+
+        stats = torch.cuda.memory_stats(q.device)
+        mem_active = stats["active_bytes.all.current"]
+        mem_reserved = stats["reserved_bytes.all.current"]
+        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_cuda + mem_free_torch
+
+        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+        mem_required = tensor_size * 2.5
+        steps = 1
+
+        if mem_required > mem_free_total:
+            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+
+            w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+            w2 = w1 * (int(c) ** (-0.5))
+            del w1
+            w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
+            del w2
+
+            # attend to values
+            v1 = v.reshape(b, c, h * w)
+            w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+            del w3
+
+            h_[:, :, i:end] = torch.bmm(
+                v1, w4
+            )  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+            del v1, w4
+
+        h2 = h_.reshape(b, c, h, w)
+        del h_
+
+        h3 = self.proj_out(h2)
+        del h2
+
+        h3 += x
+
+        return h3
 def make_attn(in_channels, attn_type="vanilla"):
     assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
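Both CUDA paths size their slices the same way: estimate the bytes the full attention matrix would occupy, compare that against the free memory PyTorch can see, and round the ratio up to the next power of two. A compact sketch of that estimate using the same `torch.cuda` calls as the code above (the 2.5 overhead factor is the commit's own fudge constant, and this only runs on a CUDA device):

```python
import math
import torch

def attention_steps(q, k, overhead=2.5):
    """How many query slices are needed so softmax(QK^T) fits in free CUDA memory.
    Mirrors the estimate in forward_cuda: q is (b, n_q, c) and k is (b, c, n_k)."""
    attn_bytes = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
    mem_required = attn_bytes * overhead

    stats = torch.cuda.memory_stats(q.device)
    free_torch = stats["reserved_bytes.all.current"] - stats["active_bytes.all.current"]
    free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    free_total = free_cuda + free_torch

    if mem_required <= free_total:
        return 1
    # round the ratio up to the next power of two so slices divide evenly
    return 2 ** math.ceil(math.log2(mem_required / free_total))
```

For scale: decoding a 512x512 image puts this attention on a roughly 64x64 feature map, so n = 4096 and the full fp32 score matrix is about 4096 * 4096 * 4 B ≈ 64 MiB; with the 2.5x overhead that is ~160 MiB, comfortably a single step, so the slicing only starts to matter at much larger resolutions or tighter memory budgets.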
@@ -427,31 +511,53 @@ class Decoder(nn.Module):
         temb = None

         # z to block_in
-        h = self.conv_in(z)
+        h1 = self.conv_in(z)

         # middle
-        h = self.mid.block_1(h, temb)
-        h = self.mid.attn_1(h)
-        h = self.mid.block_2(h, temb)
+        h2 = self.mid.block_1(h1, temb)
+        del h1
+
+        h3 = self.mid.attn_1(h2)
+        del h2
+
+        h = self.mid.block_2(h3, temb)
+        del h3
+
+        # prepare for up sampling
+        gc.collect()
+        torch.cuda.empty_cache()

         # upsampling
         for i_level in reversed(range(self.num_resolutions)):
             for i_block in range(self.num_res_blocks + 1):
                 h = self.up[i_level].block[i_block](h, temb)
                 if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h)
+                    t = h
+                    h = self.up[i_level].attn[i_block](t)
+                    del t
             if i_level != 0:
-                h = self.up[i_level].upsample(h)
+                t = h
+                h = self.up[i_level].upsample(t)
+                del t

         # end
         if self.give_pre_end:
             return h

-        h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
+        h1 = self.norm_out(h)
+        del h
+
+        h2 = nonlinearity(h1)
+        del h1
+
+        h = self.conv_out(h2)
+        del h2
+
         if self.tanh_out:
-            h = torch.tanh(h)
+            t = h
+            h = torch.tanh(t)
+            del t

         return h
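Most of this commit is the same memory discipline applied over and over: give each intermediate its own name, `del` it the moment it has been consumed, and at one pinch point (just before upsampling) hand cached blocks back with `gc.collect()` and `torch.cuda.empty_cache()`. A stripped-down sketch of the pattern with stand-in layers rather than the real decoder:

```python
import gc
import torch

def forward_memory_conscious(x, norm, conv):
    """Stand-in for the rewritten forward passes above: each intermediate is
    deleted as soon as the next one exists, so only one large activation is
    alive at a time."""
    h1 = norm(x)
    h2 = conv(h1)
    del h1  # h1's storage can now be reused by the allocator

    gc.collect()  # drop any lingering Python-level references
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached, unused blocks to the driver
    return h2

norm = torch.nn.GroupNorm(8, 64)
conv = torch.nn.Conv2d(64, 64, 3, padding=1)
out = forward_memory_conscious(torch.randn(1, 64, 32, 32), norm, conv)
```

Renaming `h` to `h1`, `h2`, ... is what makes the explicit `del` possible; reusing a single `h` would keep the previous activation alive until the reassignment completes.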
