From 4610d7f01d112fb6775fa7a712e86a04575c5ba3 Mon Sep 17 00:00:00 2001
From: Bryce
Date: Fri, 25 Nov 2022 08:00:30 -0800
Subject: [PATCH] feature: xformers support

add more upscaling code (that doesn't yet work)
---
 README.md                                  |   2 +
 imaginairy/config.py                       |   6 +
 .../stable-diffusion-x4-upscaling.yaml     |   2 +-
 imaginairy/modules/attention.py            |   3 +-
 imaginairy/modules/diffusion/ddpm.py       | 345 +++++++++++++++++-
 imaginairy/modules/diffusion/upscaling.py  | 105 ++++++
 6 files changed, 460 insertions(+), 3 deletions(-)
 create mode 100644 imaginairy/modules/diffusion/upscaling.py

diff --git a/README.md b/README.md
index bbac9f0..7353071 100644
--- a/README.md
+++ b/README.md
@@ -231,6 +231,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
 
 ## ChangeLog
 
+- feature: xformers support
+
 **6.1.0**
 - feature: use different default steps and image sizes depending on sampler and model selceted
 - fix: #110 use proper version in image metadata
diff --git a/imaginairy/config.py b/imaginairy/config.py
index d6d2c7c..e9167f0 100644
--- a/imaginairy/config.py
+++ b/imaginairy/config.py
@@ -49,6 +49,12 @@ MODEL_CONFIGS = [
         weights_url="https://huggingface.co/stabilityai/stable-diffusion-2/resolve/main/768-v-ema.ckpt",
         default_image_size=768,
     ),
+    ModelConfig(
+        short_name="SD-2.0-upscale",
+        config_path="configs/stable-diffusion-v2-upscaling.yaml",
+        weights_url="https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/resolve/main/x4-upscaler-ema.ckpt",
+        default_image_size=512,
+    ),
 ]
 
 MODEL_CONFIG_SHORTCUTS = {m.short_name: m for m in MODEL_CONFIGS}
diff --git a/imaginairy/configs/stable-diffusion-x4-upscaling.yaml b/imaginairy/configs/stable-diffusion-x4-upscaling.yaml
index 2db0964..1de3ac5 100644
--- a/imaginairy/configs/stable-diffusion-x4-upscaling.yaml
+++ b/imaginairy/configs/stable-diffusion-x4-upscaling.yaml
@@ -1,6 +1,6 @@
 model:
   base_learning_rate: 1.0e-04
-  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
+  target: imaginairy.modules.diffusion.ddpm.LatentUpscaleDiffusion
   params:
     parameterization: "v"
     low_scale_key: "lr"
diff --git a/imaginairy/modules/attention.py b/imaginairy/modules/attention.py
index e065299..e3123ec 100644
--- a/imaginairy/modules/attention.py
+++ b/imaginairy/modules/attention.py
@@ -171,7 +171,8 @@ class CrossAttention(nn.Module):
         # mask = _global_mask_hack.to(torch.bool)
 
         if get_device() == "cuda" or "mps" in get_device():
-            return self.forward_splitmem(x, context=context, mask=mask)
+            if not XFORMERS_IS_AVAILBLE:
+                return self.forward_splitmem(x, context=context, mask=mask)
 
         h = self.heads
 
diff --git a/imaginairy/modules/diffusion/ddpm.py b/imaginairy/modules/diffusion/ddpm.py
index c4eb623..95133e5 100644
--- a/imaginairy/modules/diffusion/ddpm.py
+++ b/imaginairy/modules/diffusion/ddpm.py
@@ -7,7 +7,7 @@ https://github.com/CompVis/taming-transformers
 """
 import itertools
 import logging
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import partial
 
 import numpy as np
@@ -39,6 +39,18 @@ def disabled_train(self):
     return self
 
 
+def ismap(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+    if not isinstance(x, torch.Tensor):
+        return False
+    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
 def uniform_on_device(r1, r2, shape, device):
     return (r1 - r2) * torch.rand(*shape, device=device) + r2
 
@@ -1324,6 +1336,188 @@ class DiffusionWrapper(pl.LightningModule):
         return out
 
 
+class LatentFinetuneDiffusion(LatentDiffusion):
+    """
+    Basis for different finetunes, such as inpainting or depth2image
+    To disable finetuning mode, set finetune_keys to None
+    """
+
+    def __init__(
+        self,
+        concat_keys: tuple,
+        finetune_keys=(
+            "model.diffusion_model.input_blocks.0.0.weight",
+            "model_ema.diffusion_modelinput_blocks00weight",
+        ),
+        keep_finetune_dims=4,
+        # if model was trained without concat mode before and we would like to keep these channels
+        c_concat_log_start=None,  # to log reconstruction of c_concat codes
+        c_concat_log_end=None,
+        **kwargs,
+    ):
+        ckpt_path = kwargs.pop("ckpt_path", None)
+        ignore_keys = kwargs.pop("ignore_keys", [])
+        super().__init__(**kwargs)
+        self.finetune_keys = finetune_keys
+        self.concat_keys = concat_keys
+        self.keep_dims = keep_finetune_dims
+        self.c_concat_log_start = c_concat_log_start
+        self.c_concat_log_end = c_concat_log_end
+        if self.finetune_keys is not None:
+            assert ckpt_path is not None, "can only finetune from a given checkpoint"
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path, ignore_keys)
+
+    def init_from_ckpt(self, path, ignore_keys=tuple(), only_model=False):
+        sd = torch.load(path, map_location="cpu")
+        if "state_dict" in list(sd.keys()):
+            sd = sd["state_dict"]
+        keys = list(sd.keys())
+        for k in keys:
+            for ik in ignore_keys:
+                if k.startswith(ik):
+                    print(f"Deleting key {k} from state_dict.")
+                    del sd[k]
+
+            # make it explicit, finetune by including extra input channels
+            if self.finetune_keys is not None and k in self.finetune_keys:
+                new_entry = None
+                for name, param in self.named_parameters():
+                    if name in self.finetune_keys:
+                        print(
+                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
+                        )
+                        new_entry = torch.zeros_like(param)  # zero init
+                assert (
+                    new_entry is not None
+                ), "did not find matching parameter to modify"
+                new_entry[:, : self.keep_dims, ...] = sd[k]
+                sd[k] = new_entry
+
+        missing, unexpected = (
+            self.load_state_dict(sd, strict=False)
+            if not only_model
+            else self.model.load_state_dict(sd, strict=False)
+        )
+        print(
+            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+        )
+        if len(missing) > 0:
+            print(f"Missing Keys: {missing}")
+        if len(unexpected) > 0:
+            print(f"Unexpected Keys: {unexpected}")
+
+    @torch.no_grad()
+    def log_images(  # noqa
+        self,
+        batch,
+        N=8,
+        n_row=4,
+        sample=True,
+        ddim_steps=200,
+        ddim_eta=1.0,
+        return_keys=None,
+        quantize_denoised=True,
+        inpaint=True,
+        plot_denoise_rows=False,
+        plot_progressive_rows=True,
+        plot_diffusion_rows=True,
+        unconditional_guidance_scale=1.0,
+        unconditional_guidance_label=None,
+        use_ema_scope=True,
+        **kwargs,
+    ):
+        ema_scope = self.ema_scope if use_ema_scope else nullcontext
+        use_ddim = ddim_steps is not None
+
+        log = {}
+        z, c, x, xrec, xc = self.get_input(
+            batch, self.first_stage_key, bs=N, return_first_stage_outputs=True
+        )
+        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+        N = min(x.shape[0], N)
+        n_row = min(x.shape[0], n_row)
+        log["inputs"] = x
+        log["reconstruction"] = xrec
+        if self.model.conditioning_key is not None:
+            if hasattr(self.cond_stage_model, "decode"):
+                xc = self.cond_stage_model.decode(c)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["caption", "txt"]:
+                # xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif self.cond_stage_key in ["class_label", "cls"]:
+                # xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+                log["conditioning"] = xc
+            elif isimage(xc):
+                log["conditioning"] = xc
+            if ismap(xc):
+                log["original_conditioning"] = self.to_rgb(xc)
+
+        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+            log["c_concat_decoded"] = self.decode_first_stage(
+                c_cat[:, self.c_concat_log_start : self.c_concat_log_end]  # noqa
+            )
+
+        if plot_diffusion_rows:
+            # get diffusion row
+            diffusion_row = []
+            z_start = z[:n_row]
+            for t in range(self.num_timesteps):
+                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+                    t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
+                    t = t.to(self.device).long()
+                    noise = torch.randn_like(z_start)
+                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+                    diffusion_row.append(self.decode_first_stage(z_noisy))
+
+            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
+            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
+            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
+            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+            log["diffusion_row"] = diffusion_grid
+
+        if sample:
+            # get denoise row
+            with ema_scope("Sampling"):
+                samples, z_denoise_row = self.sample_log(
+                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
+                    batch_size=N,
+                    ddim=use_ddim,
+                    ddim_steps=ddim_steps,
+                    eta=ddim_eta,
+                )
+                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+            x_samples = self.decode_first_stage(samples)
+            log["samples"] = x_samples
+            if plot_denoise_rows:
+                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+                log["denoise_row"] = denoise_grid
+
+        if unconditional_guidance_scale > 1.0:
+            uc_cross = self.get_unconditional_conditioning(
+                N, unconditional_guidance_label
+            )
+            uc_cat = c_cat
+            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+            with ema_scope("Sampling with classifier-free guidance"):
guidance"): + samples_cfg, _ = self.sample_log( + cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[ + f"samples_cfg_scale_{unconditional_guidance_scale:.2f}" + ] = x_samples_cfg + + return log + + class LatentInpaintDiffusion(LatentDiffusion): def __init__( # noqa self, @@ -1377,3 +1571,152 @@ class LatentInpaintDiffusion(LatentDiffusion): if return_first_stage_outputs: return z, all_conds, x, xrec, xc return z, all_conds + + +class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): + """ + condition on monocular depth estimation + """ + + def __init__(self, depth_stage_config, concat_keys=("midas_in",), **kwargs): + super().__init__(concat_keys=concat_keys, **kwargs) + self.depth_model = instantiate_from_config(depth_stage_config) + self.depth_stage_key = concat_keys[0] + + @torch.no_grad() + def get_input( + self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False + ): + # note: restricted to non-trainable encoders currently + assert ( + not self.cond_stage_trainable + ), "trainable cond stages not yet supported for depth2img" + z, c, x, xrec, xc = super().get_input( + batch, + self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=bs, + ) + + assert self.concat_keys is not None + assert len(self.concat_keys) == 1 + c_cat = [] + for ck in self.concat_keys: + cc = batch[ck] + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + cc = self.depth_model(cc) + cc = torch.nn.functional.interpolate( + cc, + size=z.shape[2:], + mode="bicubic", + align_corners=False, + ) + + depth_min, depth_max = torch.amin( + cc, dim=[1, 2, 3], keepdim=True + ), torch.amax(cc, dim=[1, 2, 3], keepdim=True) + cc = 2.0 * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.0 + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + depth = self.depth_model(args[0][self.depth_stage_key]) + depth_min, depth_max = torch.amin( + depth, dim=[1, 2, 3], keepdim=True + ), torch.amax(depth, dim=[1, 2, 3], keepdim=True) + log["depth"] = 2.0 * (depth - depth_min) / (depth_max - depth_min) - 1.0 + return log + + +class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): + """ + condition on low-res image (and optionally on some spatial noise augmentation) + """ + + def __init__( + self, + concat_keys=("lr",), + reshuffle_patch_size=None, + low_scale_config=None, + low_scale_key=None, + **kwargs, + ): + super().__init__(concat_keys=concat_keys, **kwargs) + self.reshuffle_patch_size = reshuffle_patch_size + self.low_scale_model = None + if low_scale_config is not None: + print("Initializing a low-scale model") + assert low_scale_key is not None + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input( + self, batch, k, cond_key=None, bs=None, 
+    ):
+        # note: restricted to non-trainable encoders currently
+        assert (
+            not self.cond_stage_trainable
+        ), "trainable cond stages not yet supported for upscaling-ft"
+        z, c, x, xrec, xc = super().get_input(
+            batch,
+            self.first_stage_key,
+            return_first_stage_outputs=True,
+            force_c_encode=True,
+            return_original_cond=True,
+            bs=bs,
+        )
+
+        assert self.concat_keys is not None
+        assert len(self.concat_keys) == 1
+        # optionally make spatial noise_level here
+        c_cat = []
+        noise_level = None
+        for ck in self.concat_keys:
+            cc = batch[ck]
+            cc = rearrange(cc, "b h w c -> b c h w")
+            if self.reshuffle_patch_size is not None:
+                assert isinstance(self.reshuffle_patch_size, int)
+                cc = rearrange(
+                    cc,
+                    "b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
+                    p1=self.reshuffle_patch_size,
+                    p2=self.reshuffle_patch_size,
+                )
+            if bs is not None:
+                cc = cc[:bs]
+            cc = cc.to(self.device)
+            if self.low_scale_model is not None and ck == self.low_scale_key:
+                cc, noise_level = self.low_scale_model(cc)
+            c_cat.append(cc)
+        c_cat = torch.cat(c_cat, dim=1)
+        if noise_level is not None:
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+        else:
+            all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+        if return_first_stage_outputs:
+            return z, all_conds, x, xrec, xc
+        return z, all_conds
+
+    @torch.no_grad()
+    def log_images(self, *args, **kwargs):
+        log = super().log_images(*args, **kwargs)
+        log["lr"] = rearrange(args[0]["lr"], "b h w c -> b c h w")
+        return log
diff --git a/imaginairy/modules/diffusion/upscaling.py b/imaginairy/modules/diffusion/upscaling.py
new file mode 100644
index 0000000..06103c0
--- /dev/null
+++ b/imaginairy/modules/diffusion/upscaling.py
@@ -0,0 +1,105 @@
+from functools import partial
+
+import numpy as np
+import torch
+from torch import nn
+
+from imaginairy.modules.diffusion.util import extract_into_tensor, make_beta_schedule
+
+
+class AbstractLowScaleModel(nn.Module):
+    # for concatenating a downsampled image to the latent representation
+    def __init__(self, noise_schedule_config=None):
+        super().__init__()
+        if noise_schedule_config is not None:
+            self.register_schedule(**noise_schedule_config)
+
+    def register_schedule(
+        self,
+        beta_schedule="linear",
+        timesteps=1000,
+        linear_start=1e-4,
+        linear_end=2e-2,
+        cosine_s=8e-3,
+    ):
+        betas = make_beta_schedule(
+            beta_schedule,
+            timesteps,
+            linear_start=linear_start,
+            linear_end=linear_end,
+            cosine_s=cosine_s,
+        )
+        alphas = 1.0 - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+        (timesteps,) = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+        assert (
+            alphas_cumprod.shape[0] == self.num_timesteps
+        ), "alphas have to be defined for each timestep"
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer("betas", to_torch(betas))
+        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer(
+            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+        )
+        self.register_buffer(
+            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+        )
+
+    def q_sample(self, x_start, t, noise=None):
+        if noise is None:
+            noise = torch.randn_like(x_start)
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+            * noise
+        )
+
+    def forward(self, x):
+        return x, None
+
+    def decode(self, x):
+        return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+    # no noise level conditioning
+    def __init__(self):
+        super().__init__(noise_schedule_config=None)
+        self.max_noise_level = 0
+
+    def forward(self, x):
+        # fix to constant noise level
+        return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+    def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+        super().__init__(noise_schedule_config=noise_schedule_config)
+        self.max_noise_level = max_noise_level
+
+    def forward(self, x, noise_level=None):
+        if noise_level is None:
+            noise_level = torch.randint(
+                0, self.max_noise_level, (x.shape[0],), device=x.device
+            ).long()
+        else:
+            assert isinstance(noise_level, torch.Tensor)
+        z = self.q_sample(x, noise_level)
+        return z, noise_level
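
Note on the new upscaling.py module: ImageConcatWithNoiseAugmentation uses the registered beta schedule to noise the low-resolution conditioning image via q_sample, returning both the augmented image and the noise level used; LatentUpscaleFinetuneDiffusion.get_input then forwards that level as the "c_adm" conditioning. Below is a minimal usage sketch, assuming this patch is applied and imaginairy is importable; the schedule values and max_noise_level are illustrative, not taken from a shipped config file.

```python
import torch

# Module added by this patch.
from imaginairy.modules.diffusion.upscaling import ImageConcatWithNoiseAugmentation

# Linear beta schedule matching register_schedule()'s defaults (illustrative values).
low_scale_model = ImageConcatWithNoiseAugmentation(
    noise_schedule_config={"linear_start": 1e-4, "linear_end": 2e-2},
    max_noise_level=350,  # assumed for illustration; the real value comes from the model config
)

# A batch of 4 low-res RGB images already scaled to [-1, 1], channels-first.
lr_images = torch.rand(4, 3, 128, 128) * 2 - 1

# Returns the noise-augmented low-res image and the sampled noise level;
# the level is what get_input() later passes along as "c_adm".
z_lr, noise_level = low_scale_model(lr_images)
print(z_lr.shape, noise_level)
```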