feature: xformers support

add more upscaling code (that doesn't yet work)
pull/117/head
Bryce 2 years ago committed by Bryce Drennan
parent 8e9e119052
commit 4610d7f01d

@ -231,6 +231,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog ## ChangeLog
- feature: xformers support
**6.1.0** **6.1.0**
- feature: use different default steps and image sizes depending on sampler and model selceted - feature: use different default steps and image sizes depending on sampler and model selceted
- fix: #110 use proper version in image metadata - fix: #110 use proper version in image metadata

@ -49,6 +49,12 @@ MODEL_CONFIGS = [
weights_url="https://huggingface.co/stabilityai/stable-diffusion-2/resolve/main/768-v-ema.ckpt", weights_url="https://huggingface.co/stabilityai/stable-diffusion-2/resolve/main/768-v-ema.ckpt",
default_image_size=768, default_image_size=768,
), ),
ModelConfig(
short_name="SD-2.0-upscale",
config_path="configs/stable-diffusion-v2-upscaling.yaml",
weights_url="https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/resolve/main/x4-upscaler-ema.ckpt",
default_image_size=512,
),
] ]
MODEL_CONFIG_SHORTCUTS = {m.short_name: m for m in MODEL_CONFIGS} MODEL_CONFIG_SHORTCUTS = {m.short_name: m for m in MODEL_CONFIGS}

@ -1,6 +1,6 @@
model: model:
base_learning_rate: 1.0e-04 base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion target: imaginairy.modules.diffusion.ddpm.LatentUpscaleDiffusion
params: params:
parameterization: "v" parameterization: "v"
low_scale_key: "lr" low_scale_key: "lr"

@ -171,7 +171,8 @@ class CrossAttention(nn.Module):
# mask = _global_mask_hack.to(torch.bool) # mask = _global_mask_hack.to(torch.bool)
if get_device() == "cuda" or "mps" in get_device(): if get_device() == "cuda" or "mps" in get_device():
return self.forward_splitmem(x, context=context, mask=mask) if not XFORMERS_IS_AVAILBLE:
return self.forward_splitmem(x, context=context, mask=mask)
h = self.heads h = self.heads

@ -7,7 +7,7 @@ https://github.com/CompVis/taming-transformers
""" """
import itertools import itertools
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager, nullcontext
from functools import partial from functools import partial
import numpy as np import numpy as np
@ -39,6 +39,18 @@ def disabled_train(self):
return self return self
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def uniform_on_device(r1, r2, shape, device): def uniform_on_device(r1, r2, shape, device):
return (r1 - r2) * torch.rand(*shape, device=device) + r2 return (r1 - r2) * torch.rand(*shape, device=device) + r2
@ -1324,6 +1336,188 @@ class DiffusionWrapper(pl.LightningModule):
return out return out
class LatentFinetuneDiffusion(LatentDiffusion):
"""
Basis for different finetunas, such as inpainting or depth2image
To disable finetuning mode, set finetune_keys to None
"""
def __init__(
self,
concat_keys: tuple,
finetune_keys=(
"model.diffusion_model.input_blocks.0.0.weight",
"model_ema.diffusion_modelinput_blocks00weight",
),
keep_finetune_dims=4,
# if model was trained without concat mode before and we would like to keep these channels
c_concat_log_start=None, # to log reconstruction of c_concat codes
c_concat_log_end=None,
**kwargs,
):
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(**kwargs)
self.finetune_keys = finetune_keys
self.concat_keys = concat_keys
self.keep_dims = keep_finetune_dims
self.c_concat_log_start = c_concat_log_start
self.c_concat_log_end = c_concat_log_end
if self.finetune_keys is not None:
assert ckpt_path is not None, "can only finetune from a given checkpoint"
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
def init_from_ckpt(self, path, ignore_keys=tuple(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {k} from state_dict.")
del sd[k]
# make it explicit, finetune by including extra input channels
if self.finetune_keys is not None and k in self.finetune_keys:
new_entry = None
for name, param in self.named_parameters():
if name in self.finetune_keys:
print(
f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
)
new_entry = torch.zeros_like(param) # zero init
assert (
new_entry is not None
), "did not find matching parameter to modify"
new_entry[:, : self.keep_dims, ...] = sd[k]
sd[k] = new_entry
missing, unexpected = (
self.load_state_dict(sd, strict=False)
if not only_model
else self.model.load_state_dict(sd, strict=False)
)
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
@torch.no_grad()
def log_images( # noqa
self,
batch,
N=8,
n_row=4,
sample=True,
ddim_steps=200,
ddim_eta=1.0,
return_keys=None,
quantize_denoised=True,
inpaint=True,
plot_denoise_rows=False,
plot_progressive_rows=True,
plot_diffusion_rows=True,
unconditional_guidance_scale=1.0,
unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs,
):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = {}
z, c, x, xrec, xc = self.get_input(
batch, self.first_stage_key, bs=N, return_first_stage_outputs=True
)
c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
# xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
log["conditioning"] = xc
elif self.cond_stage_key in ["class_label", "cls"]:
# xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
log["conditioning"] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
log["c_concat_decoded"] = self.decode_first_stage(
c_cat[:, self.c_concat_log_start : self.c_concat_log_end] # noqa
)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(
cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_cross = self.get_unconditional_conditioning(
N, unconditional_guidance_label
)
uc_cat = c_cat
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(
cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_full,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[
f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
] = x_samples_cfg
return log
class LatentInpaintDiffusion(LatentDiffusion): class LatentInpaintDiffusion(LatentDiffusion):
def __init__( # noqa def __init__( # noqa
self, self,
@ -1377,3 +1571,152 @@ class LatentInpaintDiffusion(LatentDiffusion):
if return_first_stage_outputs: if return_first_stage_outputs:
return z, all_conds, x, xrec, xc return z, all_conds, x, xrec, xc
return z, all_conds return z, all_conds
class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
"""
condition on monocular depth estimation
"""
def __init__(self, depth_stage_config, concat_keys=("midas_in",), **kwargs):
super().__init__(concat_keys=concat_keys, **kwargs)
self.depth_model = instantiate_from_config(depth_stage_config)
self.depth_stage_key = concat_keys[0]
@torch.no_grad()
def get_input(
self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
):
# note: restricted to non-trainable encoders currently
assert (
not self.cond_stage_trainable
), "trainable cond stages not yet supported for depth2img"
z, c, x, xrec, xc = super().get_input(
batch,
self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=bs,
)
assert self.concat_keys is not None
assert len(self.concat_keys) == 1
c_cat = []
for ck in self.concat_keys:
cc = batch[ck]
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
cc = self.depth_model(cc)
cc = torch.nn.functional.interpolate(
cc,
size=z.shape[2:],
mode="bicubic",
align_corners=False,
)
depth_min, depth_max = torch.amin(
cc, dim=[1, 2, 3], keepdim=True
), torch.amax(cc, dim=[1, 2, 3], keepdim=True)
cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds
@torch.no_grad()
def log_images(self, *args, **kwargs):
log = super().log_images(*args, **kwargs)
depth = self.depth_model(args[0][self.depth_stage_key])
depth_min, depth_max = torch.amin(
depth, dim=[1, 2, 3], keepdim=True
), torch.amax(depth, dim=[1, 2, 3], keepdim=True)
log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0
return log
class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
"""
condition on low-res image (and optionally on some spatial noise augmentation)
"""
def __init__(
self,
concat_keys=("lr",),
reshuffle_patch_size=None,
low_scale_config=None,
low_scale_key=None,
**kwargs,
):
super().__init__(concat_keys=concat_keys, **kwargs)
self.reshuffle_patch_size = reshuffle_patch_size
self.low_scale_model = None
if low_scale_config is not None:
print("Initializing a low-scale model")
assert low_scale_key is not None
self.instantiate_low_stage(low_scale_config)
self.low_scale_key = low_scale_key
def instantiate_low_stage(self, config):
model = instantiate_from_config(config)
self.low_scale_model = model.eval()
self.low_scale_model.train = disabled_train
for param in self.low_scale_model.parameters():
param.requires_grad = False
@torch.no_grad()
def get_input(
self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
):
# note: restricted to non-trainable encoders currently
assert (
not self.cond_stage_trainable
), "trainable cond stages not yet supported for upscaling-ft"
z, c, x, xrec, xc = super().get_input(
batch,
self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=bs,
)
assert self.concat_keys is not None
assert len(self.concat_keys) == 1
# optionally make spatial noise_level here
c_cat = []
noise_level = None
for ck in self.concat_keys:
cc = batch[ck]
cc = rearrange(cc, "b h w c -> b c h w")
if self.reshuffle_patch_size is not None:
assert isinstance(self.reshuffle_patch_size, int)
cc = rearrange(
cc,
"b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
p1=self.reshuffle_patch_size,
p2=self.reshuffle_patch_size,
)
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
if self.low_scale_model is not None and ck == self.low_scale_key:
cc, noise_level = self.low_scale_model(cc)
c_cat.append(cc)
c_cat = torch.cat(c_cat, dim=1)
if noise_level is not None:
all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
else:
all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
if return_first_stage_outputs:
return z, all_conds, x, xrec, xc
return z, all_conds
@torch.no_grad()
def log_images(self, *args, **kwargs):
log = super().log_images(*args, **kwargs)
log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w")
return log

@ -0,0 +1,105 @@
from functools import partial
import numpy as np
import torch
from torch import nn
from imaginairy.modules.diffusion.util import extract_into_tensor, make_beta_schedule
class AbstractLowScaleModel(nn.Module):
# for concatenating a downsampled image to the latent representation
def __init__(self, noise_schedule_config=None):
super().__init__()
if noise_schedule_config is not None:
self.register_schedule(**noise_schedule_config)
def register_schedule(
self,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert (
alphas_cumprod.shape[0] == self.num_timesteps
), "alphas have to be defined for each timestep"
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer("betas", to_torch(betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
)
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
def forward(self, x):
return x, None
def decode(self, x):
return x
class SimpleImageConcat(AbstractLowScaleModel):
# no noise level conditioning
def __init__(self):
super().__init__(noise_schedule_config=None)
self.max_noise_level = 0
def forward(self, x):
# fix to constant noise level
return x, torch.zeros(x.shape[0], device=x.device).long()
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level
def forward(self, x, noise_level=None):
if noise_level is None:
noise_level = torch.randint(
0, self.max_noise_level, (x.shape[0],), device=x.device
).long()
else:
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level)
return z, noise_level
Loading…
Cancel
Save