Bryce 2 years ago
parent 910b7b4180
commit 0bb5b6b345

@@ -2,7 +2,7 @@
AI imagined images.
Tested on Linux and OSX(M1).
"just works" on Linux and OSX(M1).
```bash
>> pip install imaginairy
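# usage sketch, not part of this diff: assumes the package installs an
# `imagine` console command as its entry point
>> imagine "a scenic landscape"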
@@ -68,7 +68,7 @@ OR
- img2img actually does # of steps you specify
# Models Used
- CLIP
- CLIP - https://openai.com/blog/clip/
- LDM - Latent Diffusion
- Stable Diffusion
- https://github.com/CompVis/stable-diffusion
@@ -89,14 +89,16 @@ OR
- Image Generation Features
- upscaling
- face improvements
- image describe feature
- image describe feature - https://replicate.com/methexis-inc/img2prompt
- outpainting
- inpainting
- cross-attention control:
- https://github.com/bloc97/CrossAttentionControl/blob/main/CrossAttention_Release_NoImages.ipynb
- guided generation https://colab.research.google.com/drive/1dlgggNa5Mz8sEAGU0wFCHhGLFooW_pf1#scrollTo=UDeXQKbPTdZI
- tiling
- output show-work videos
- image variations https://github.com/lstein/stable-diffusion/blob/main/VARIATIONS.md
- textual inversion https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb#scrollTo=50JuJUM8EG1h
- zooming videos? a la disco diffusion

@@ -8,6 +8,7 @@ from functools import lru_cache
import PIL
import numpy as np
import torch
import torch.nn
from PIL import Image
from einops import rearrange
from omegaconf import OmegaConf
@@ -69,8 +70,22 @@ def load_img(path, max_height=512, max_width=512):
return 2.0 * image - 1.0, w, h
def patch_conv(**patch):
cls = torch.nn.Conv2d
init = cls.__init__
def __init__(self, *args, **kwargs):
return init(self, *args, **kwargs, **patch)
cls.__init__ = __init__
@lru_cache()
def load_model():
def load_model(tile_mode=False):
if tile_mode:
# generated images are tileable
patch_conv(padding_mode="circular")
config = "configs/stable-diffusion-v1.yaml"
config = OmegaConf.load(f"{LIB_PATH}/{config}")
model = load_model_from_config(config)
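The new `tile_mode` flag works by monkey-patching `torch.nn.Conv2d` so every convolution built afterwards uses circular padding, which makes the generated image wrap at its edges. A self-contained sketch of that effect, with toy tensors and layer sizes that are illustrative only:

```python
import torch

# Two identical 3x3 convolutions; only the padding mode differs.
conv_zero = torch.nn.Conv2d(1, 1, 3, padding=1, padding_mode="zeros", bias=False)
conv_circ = torch.nn.Conv2d(1, 1, 3, padding=1, padding_mode="circular", bias=False)
conv_circ.load_state_dict(conv_zero.state_dict())

x = torch.arange(16.0).reshape(1, 1, 4, 4)

# With circular padding the border pixels are computed as if the image
# repeated beyond its edges, which is what makes the output tileable.
print(conv_zero(x)[0, 0, :, 0])  # left edge computed against zero padding
print(conv_circ(x)[0, 0, :, 0])  # left edge computed against the wrapped right edge
```

Note that `patch_conv` mutates `Conv2d.__init__` globally, and `load_model` is `lru_cache`d, so calling `load_model(tile_mode=True)` affects every convolution constructed afterwards in the process.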
@@ -143,6 +158,7 @@ def imagine_images(
img_callback=None,
):
model = load_model()
# model = model.half()
prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
_img_callback = None
@@ -156,6 +172,10 @@ def imagine_images(
for prompt in prompts:
logger.info(f"Generating {prompt.prompt_description()}")
seed_everything(prompt.seed)
# needed when model is in half mode, remove if not using half mode
# torch.set_default_tensor_type(torch.HalfTensor)
uc = None
if prompt.prompt_strength != 1.0:
uc = model.get_learned_conditioning(1 * [""])
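The two commented-out half-precision lines above go together: if the model weights are converted with `.half()`, newly created tensors must also default to fp16 or dtypes will mismatch. A hedged, self-contained sketch of that pairing (the `Linear` module is only a stand-in for the diffusion model):

```python
import torch

net = torch.nn.Linear(4, 4).half()               # stand-in for model.half()
torch.set_default_tensor_type(torch.HalfTensor)  # new tensors now default to fp16
x = torch.randn(1, 4)                            # created as fp16 via the default above
print(x.dtype, net.weight.dtype)                 # torch.float16 torch.float16
```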

@@ -6,6 +6,7 @@ from torch import nn, einsum
from einops import rearrange, repeat
from imaginairy.modules.diffusion.util import checkpoint
from imaginairy.utils import get_device_name, get_device
def exists(val):
@@ -164,6 +165,9 @@ class CrossAttention(nn.Module):
)
def forward(self, x, context=None, mask=None):
if get_device() == "cuda":
return self.forward_cuda(x, context=context, mask=mask)
h = self.heads
q = self.to_q(x)
@@ -188,6 +192,65 @@ class CrossAttention(nn.Module):
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
def forward_cuda(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
v_in = self.to_v(context)
del context, x
q, k, v = map(
lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q_in, k_in, v_in)
)
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats["active_bytes.all.current"]
mem_reserved = stats["reserved_bytes.all.current"]
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024**3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(
f"Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). "
f"Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free"
)
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum("b i d, b j d -> b i j", q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum("b i j, b j d -> b i d", s2, v)
del s2
del q, k, v
r2 = rearrange(r1, "(b h) n d -> b n (h d)", h=h)
del r1
return self.to_out(r2)
class BasicTransformerBlock(nn.Module):
def __init__(

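The `forward_cuda` path added above trades speed for memory: it estimates free CUDA memory, picks a power-of-two number of steps, and computes attention for one slice of query rows at a time instead of materialising the full attention matrix. A minimal self-contained sketch of the slicing idea (shapes and the slice size here are toy values, not the model's):

```python
import torch
from torch import einsum

b_h, n, d = 2, 8, 4            # (batch * heads, sequence length, head dim)
q, k, v = (torch.randn(b_h, n, d) for _ in range(3))
scale = d ** -0.5

out = torch.zeros(b_h, n, d)
slice_size = 4                 # stands in for q.shape[1] // steps in the real code
for i in range(0, n, slice_size):
    end = i + slice_size
    # attention scores for this slice of query rows only: (b_h, slice, n)
    attn = einsum("b i d, b j d -> b i j", q[:, i:end], k) * scale
    out[:, i:end] = einsum("b i j, b j d -> b i d", attn.softmax(dim=-1), v)

# the sliced result matches unsliced attention, because softmax is per query row
full = (einsum("b i d, b j d -> b i j", q, k) * scale).softmax(dim=-1)
assert torch.allclose(out, einsum("b i j, b j d -> b i d", full, v), atol=1e-5)
```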
@@ -1,4 +1,5 @@
# pytorch_diffusion + derived encoder decoder
import gc
import logging
import math
@@ -9,7 +10,7 @@ from einops import rearrange
from imaginairy.modules.attention import LinearAttention
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.utils import instantiate_from_config
from imaginairy.utils import instantiate_from_config, get_device
logger = logging.getLogger(__name__)
@@ -37,7 +38,11 @@ def get_timestep_embedding(timesteps, embedding_dim):
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
t = torch.sigmoid(x)
x *= t
del t
return x
def Normalize(in_channels, num_groups=32):
@@ -120,18 +125,30 @@ class ResnetBlock(nn.Module):
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h1 = x
h2 = self.norm1(h1)
del h1
h3 = nonlinearity(h2)
del h2
h4 = self.conv1(h3)
del h3
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
h5 = self.norm2(h4)
del h4
h6 = nonlinearity(h5)
del h5
h7 = self.dropout(h6)
del h6
h8 = self.conv2(h7)
del h7
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
@@ -139,7 +156,7 @@ class ResnetBlock(nn.Module):
else:
x = self.nin_shortcut(x)
return x + h
return x + h8
class LinAttnBlock(LinearAttention):
@@ -169,6 +186,8 @@ class AttnBlock(nn.Module):
)
def forward(self, x):
if get_device() == "cuda":
return self.forward_cuda(x)
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
@@ -194,6 +213,71 @@ class AttnBlock(nn.Module):
return x + h_
def forward_cuda(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h * w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h * w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats["active_bytes.all.current"]
mem_reserved = stats["reserved_bytes.all.current"]
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c) ** (-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h * w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(
v1, w4
) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
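Both `forward_cuda` methods size their slices with the same arithmetic: the estimated size of the attention tensor is compared against available memory (free CUDA memory plus memory PyTorch has reserved but is not actively using), and the number of slicing steps is rounded up to the next power of two. A worked example with illustrative numbers:

```python
import math

gb = 1024 ** 3
mem_required = 6.0 * gb    # estimated size of the full attention tensor
mem_free_total = 2.0 * gb  # free CUDA memory + reserved-but-inactive torch memory

steps = 1
if mem_required > mem_free_total:
    steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))

print(steps)  # 4 -> each slice needs roughly mem_required / 4, which now fits
```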
@@ -427,31 +511,53 @@ class Decoder(nn.Module):
temb = None
# z to block_in
h = self.conv_in(z)
h1 = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
h2 = self.mid.block_1(h1, temb)
del h1
h3 = self.mid.attn_1(h2)
del h2
h = self.mid.block_2(h3, temb)
del h3
# prepare for up sampling
gc.collect()
torch.cuda.empty_cache()
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
t = h
h = self.up[i_level].attn[i_block](t)
del t
if i_level != 0:
h = self.up[i_level].upsample(h)
t = h
h = self.up[i_level].upsample(t)
del t
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h1 = self.norm_out(h)
del h
h2 = nonlinearity(h1)
del h1
h = self.conv_out(h2)
del h2
if self.tanh_out:
h = torch.tanh(h)
t = h
h = torch.tanh(t)
del t
return h
