diff --git a/README.md b/README.md index e65a962..6c933e6 100644 --- a/README.md +++ b/README.md @@ -224,6 +224,13 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface - ## ChangeLog +**6.0.0a** +- feature: 🎉🎉🎉 Stable Diffusion 2.0 + - Tested on MacOS and Linux + - All samplers working for new 512x512 model + - New inpainting model working + - 768x768 model working for DDIM sampler only + **5.1.0** - feature: add progress image callback diff --git a/imaginairy/configs/stable-diffusion-v2-inference-v.yaml b/imaginairy/configs/stable-diffusion-v2-inference-v.yaml new file mode 100644 index 0000000..ee71794 --- /dev/null +++ b/imaginairy/configs/stable-diffusion-v2-inference-v.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 1.0e-4 + target: imaginairy.modules.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: imaginairy.modules.diffusion.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: False + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: imaginairy.modules.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: imaginairy.modules.encoders.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/imaginairy/configs/stable-diffusion-v2-inference.yaml b/imaginairy/configs/stable-diffusion-v2-inference.yaml new file mode 100644 index 0000000..4acf817 --- /dev/null +++ b/imaginairy/configs/stable-diffusion-v2-inference.yaml @@ -0,0 +1,67 @@ +model: + base_learning_rate: 1.0e-4 + target: imaginairy.modules.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: imaginairy.modules.diffusion.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: False + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: 
imaginairy.modules.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: imaginairy.modules.encoders.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/imaginairy/configs/stable-diffusion-v2-inpainting-inference.yaml b/imaginairy/configs/stable-diffusion-v2-inpainting-inference.yaml new file mode 100644 index 0000000..9873f3c --- /dev/null +++ b/imaginairy/configs/stable-diffusion-v2-inpainting-inference.yaml @@ -0,0 +1,158 @@ +model: + base_learning_rate: 5.0e-05 + target: imaginairy.modules.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + unet_config: + target: imaginairy.modules.diffusion.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 9 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: imaginairy.modules.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: imaginairy.modules.encoders.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + +data: + target: ldm.data.laion.WebDataModuleFromConfig + params: + tar_base: null # for concat as in LAION-A + p_unsafe_threshold: 0.1 + filter_word_list: "data/filters.yaml" + max_pwatermark: 0.45 + batch_size: 8 + num_workers: 6 + multinode: True + min_size: 512 + train: + shards: + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -" + - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar" + shuffle: 10000 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 512 + interpolation: 3 + - target: torchvision.transforms.RandomCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + # NOTE use enough shards to avoid empty validation loops in workers + validation: + shards: + - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - " + shuffle: 0 + image_key: jpg + image_transforms: + - target: torchvision.transforms.Resize + params: + size: 
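The inpainting config above sets `in_channels: 9` on the UNet with `conditioning_key: hybrid`: the 4-channel noisy latent is concatenated with a 1-channel mask and the 4-channel latent of the masked source image. A minimal channel-bookkeeping sketch of that idea follows; the concatenation order shown is an assumption for illustration, not taken from this diff.

```python
import torch

# Illustrative channel bookkeeping for the 9-channel inpainting UNet configured above.
# The exact ordering of the concatenated conditioning tensors is an assumption.
b = 1
z_noisy = torch.randn(b, 4, 64, 64)        # noisy image latent (4 channels)
mask = torch.randn(b, 1, 64, 64)           # inpainting mask downsampled to latent resolution
masked_latent = torch.randn(b, 4, 64, 64)  # latent of the source image with the masked region removed
unet_input = torch.cat([z_noisy, mask, masked_latent], dim=1)
assert unet_input.shape[1] == 9            # matches in_channels: 9 in the unet_config
```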
512 + interpolation: 3 + - target: torchvision.transforms.CenterCrop + params: + size: 512 + postprocess: + target: ldm.data.laion.AddMask + params: + mode: "512train-large" + p_drop: 0.25 + +lightning: + find_unused_parameters: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 10000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + disabled: False + batch_frequency: 1000 + max_images: 4 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 4 + unconditional_guidance_scale: 5.0 + unconditional_guidance_label: [""] + ddim_steps: 50 # todo check these out for depth2img, + ddim_eta: 0.0 # todo check these out for depth2img, + + trainer: + benchmark: True + val_check_interval: 5000000 + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 diff --git a/imaginairy/configs/stable-diffusion-v2-midas-inference.yaml b/imaginairy/configs/stable-diffusion-v2-midas-inference.yaml new file mode 100644 index 0000000..f20c30f --- /dev/null +++ b/imaginairy/configs/stable-diffusion-v2-midas-inference.yaml @@ -0,0 +1,74 @@ +model: + base_learning_rate: 5.0e-07 + target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + depth_stage_config: + target: ldm.modules.midas.api.MiDaSInference + params: + model_type: "dpt_hybrid" + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 5 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + diff --git a/imaginairy/configs/stable-diffusion-x4-upscaling.yaml b/imaginairy/configs/stable-diffusion-x4-upscaling.yaml new file mode 100644 index 0000000..2db0964 --- /dev/null +++ b/imaginairy/configs/stable-diffusion-x4-upscaling.yaml @@ -0,0 +1,76 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion + params: + parameterization: "v" + low_scale_key: "lr" + linear_start: 0.0001 + linear_end: 0.02 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 128 + channels: 4 + cond_stage_trainable: false + conditioning_key: "hybrid-adm" + monitor: val/loss_simple_ema + scale_factor: 0.08333 + use_ema: False + + 
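Each of these YAML files follows the same `target:`/`params:` convention: `target` is a dotted import path and `params` are the constructor kwargs. A rough sketch of how such a block can be resolved into a model object, assuming an OmegaConf-style loader; the helper below mirrors the `instantiate_from_config` utility imported elsewhere in this diff, but its exact implementation here is an assumption.

```python
import importlib
from omegaconf import OmegaConf

def instantiate_from_config(config):
    # "target" is a dotted import path, "params" its constructor kwargs
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))

cfg = OmegaConf.load("imaginairy/configs/stable-diffusion-v2-inference.yaml")
model = instantiate_from_config(cfg.model)  # a LatentDiffusion instance; weights are loaded separately
```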
low_scale_config: + target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation + params: + noise_schedule_config: # image space + linear_start: 0.0001 + linear_end: 0.02 + max_noise_level: 350 + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + num_classes: 1000 # timesteps for noise conditioning (here constant, just need one) + image_size: 128 + in_channels: 7 + out_channels: 4 + model_channels: 256 + attention_resolutions: [ 2,4,8] + num_res_blocks: 2 + channel_mult: [ 1, 2, 2, 4] + disable_self_attentions: [True, True, True, False] + disable_middle_self_attn: False + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + use_linear_in_transformer: True + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though) + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + diff --git a/imaginairy/model_manager.py b/imaginairy/model_manager.py index 11e0c17..0b24b5d 100644 --- a/imaginairy/model_manager.py +++ b/imaginairy/model_manager.py @@ -28,8 +28,20 @@ MODEL_SHORTCUTS = { "configs/stable-diffusion-v1-inpaint.yaml", "https://huggingface.co/julienacquaviva/inpainting/resolve/2155ff7fe38b55f4c0d99c2f1ab9b561f8311ca7/sd-v1-5-inpainting.ckpt", ), + "SD-2.0": ( + "configs/stable-diffusion-v2-inference.yaml", + "https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt", + ), + "SD-2.0-inpaint": ( + "configs/stable-diffusion-v2-inpainting-inference.yaml", + "https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/resolve/main/512-inpainting-ema.ckpt", + ), + "SD-2.0-v": ( + "configs/stable-diffusion-v2-inference-v.yaml", + "https://huggingface.co/stabilityai/stable-diffusion-2/resolve/main/768-v-ema.ckpt", + ), } -DEFAULT_MODEL = "SD-1.5" +DEFAULT_MODEL = "SD-2.0" LOADED_MODELS = {} MOST_RECENTLY_LOADED_MODEL = None diff --git a/imaginairy/modules/attention.py b/imaginairy/modules/attention.py index d58bba0..d67ff72 100644 --- a/imaginairy/modules/attention.py +++ b/imaginairy/modules/attention.py @@ -9,8 +9,15 @@ from torch import einsum, nn from imaginairy.modules.diffusion.util import checkpoint from imaginairy.utils import get_device +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + -# feedforward class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() @@ -245,7 +252,67 @@ class CrossAttention(nn.Module): return self.to_out(r2) +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + # print( + # f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + # f"{heads} heads." 
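With the shortcuts registered in `model_manager.py` above, the new weights can be selected by name. A hypothetical usage sketch follows; the `model` keyword on `ImaginePrompt` is an assumption for illustration and is not confirmed by this diff.

```python
from imaginairy import ImaginePrompt, imagine_image_files

# "SD-2.0", "SD-2.0-v" and "SD-2.0-inpaint" are the shortcut names added above;
# passing one via a `model` argument is assumed here for illustration.
prompt = ImaginePrompt("a cozy cabin in a snowy forest", model="SD-2.0")
imagine_image_files([prompt], outdir="./outputs")
```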
+ # ) + inner_dim = dim_head * heads + context_dim = context_dim if context_dim is not None else query_dim + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.attention_op = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = context if context is not None else x + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + if mask is not None: + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention, + } + def __init__( self, dim, @@ -254,14 +321,23 @@ class BasicTransformerBlock(nn.Module): dropout=0.0, context_dim=None, gated_ff=True, - checkpoint=True, # noqa + checkpoint=True, + disable_self_attn=False, ): super().__init__() - self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention + attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( + self.attn2 = attn_cls( query_dim=dim, context_dim=context_dim, heads=n_heads, @@ -280,7 +356,12 @@ class BasicTransformerBlock(nn.Module): def _forward(self, x, context=None): x = x.contiguous() if x.device.type == "mps" else x - x = self.attn1(self.norm1(x)) + x + x = ( + self.attn1( + self.norm1(x), context=context if self.disable_self_attn else None + ) + + x + ) x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x return x @@ -293,42 +374,73 @@ class SpatialTransformer(nn.Module): and reshape to b, t, d. Then apply standard transformer action. 
Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs """ def __init__( - self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, ): super().__init__() + if context_dim is not None and not isinstance(context_dim, list): + context_dim = [context_dim] self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) - - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( - inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, ) for d in range(depth) ] ) - - self.proj_out = zero_module( - nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention - b, c, h, w = x.shape # noqa + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape x_in = x x = self.norm(x) - x = self.proj_in(x) - x = rearrange(x, "b c h w -> b (h w) c") - for block in self.transformer_blocks: - x = block(x, context=context) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) - x = self.proj_out(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) return x + x_in diff --git a/imaginairy/modules/autoencoder.py b/imaginairy/modules/autoencoder.py index 6d03e82..96d6d9a 100644 --- a/imaginairy/modules/autoencoder.py +++ b/imaginairy/modules/autoencoder.py @@ -1,10 +1,12 @@ import logging +from contextlib import contextmanager import pytorch_lightning as pl import torch from imaginairy.modules.diffusion.model import Decoder, Encoder from imaginairy.modules.distributions import DiagonalGaussianDistribution +from imaginairy.modules.ema import LitEma from imaginairy.utils import instantiate_from_config logger = logging.getLogger(__name__) @@ -17,13 +19,15 @@ class AutoencoderKL(pl.LightningModule): lossconfig, embed_dim, ckpt_path=None, - ignore_keys=None, + ignore_keys=[], image_key="image", colorize_nlabels=None, monitor=None, + ema_decay=None, + learn_logvar=False, ): super().__init__() - ignore_keys = [] if ignore_keys is None else ignore_keys + self.learn_logvar = learn_logvar self.image_key = image_key self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) @@ -33,24 +37,50 @@ class AutoencoderKL(pl.LightningModule): self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim if 
colorize_nlabels is not None: - assert isinstance(colorize_nlabels, int) + assert type(colorize_nlabels) == int self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0.0 < ema_decay < 1.0 + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - def init_from_ckpt(self, path, ignore_keys=None): - ignore_keys = [] if ignore_keys is None else ignore_keys + def init_from_ckpt(self, path, ignore_keys=list()): sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): - logger.info(f"Deleting key {k} from state_dict.") + print("Deleting key {} from state_dict.".format(k)) del sd[k] self.load_state_dict(sd, strict=False) - logger.info(f"Restored from {path}") + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) def encode(self, x): h = self.encoder(x) @@ -63,7 +93,7 @@ class AutoencoderKL(pl.LightningModule): dec = self.decoder(z) return dec - def forward(self, input, sample_posterior=True): # noqa + def forward(self, input, sample_posterior=True): posterior = self.encode(input) if sample_posterior: z = posterior.sample() @@ -78,3 +108,166 @@ class AutoencoderKL(pl.LightningModule): x = x[..., None] x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + + self.log( + "discloss", + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, 
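`LitEma` (used above and again in `ddpm.py` below) keeps an exponential moving average of the model weights so that `ema_scope()` can temporarily swap them in for evaluation and restore the training weights afterwards. The update rule is `w_ema <- decay * w_ema + (1 - decay) * w`; a minimal, self-contained sketch of that idea, not the actual `LitEma` implementation:

```python
import torch

class SimpleEma:
    """Minimal EMA of a module's parameters (illustrative only, not LitEma itself)."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.shadow = {k: p.detach().clone() for k, p in model.named_parameters()}

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        for k, p in model.named_parameters():
            # w_ema <- decay * w_ema + (1 - decay) * w
            self.shadow[k].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    @torch.no_grad()
    def copy_to(self, model: torch.nn.Module):
        for k, p in model.named_parameters():
            p.copy_(self.shadow[k])
```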
+ self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = ( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()) + ) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) + ) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode( + torch.randn_like(posterior_ema.sample()) + ) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + # def to_rgb(self, x): + # assert self.image_key == "segmentation" + # if not hasattr(self, "colorize"): + # self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + # x = F.conv2d(x, weight=self.colorize) + # x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + # return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/imaginairy/modules/diffusion/ddpm.py b/imaginairy/modules/diffusion/ddpm.py index 6907791..94d4afe 100644 --- a/imaginairy/modules/diffusion/ddpm.py +++ b/imaginairy/modules/diffusion/ddpm.py @@ -5,24 +5,26 @@ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bb https://github.com/CompVis/taming-transformers -- merci """ +import itertools import logging +from contextlib import contextmanager from functools import partial import numpy as np import pytorch_lightning as pl import torch -from einops import rearrange +from einops import rearrange, repeat from torch import nn from torchvision.utils import make_grid from tqdm import tqdm -from imaginairy.log_utils import log_latent from imaginairy.modules.diffusion.util import ( extract_into_tensor, make_beta_schedule, noise_like, ) from imaginairy.modules.distributions import DiagonalGaussianDistribution +from imaginairy.modules.ema import 
LitEma from imaginairy.utils import instantiate_from_config logger = logging.getLogger(__name__) @@ -42,12 +44,7 @@ def uniform_on_device(r1, r2, shape, device): class DDPM(pl.LightningModule): - """ - classic DDPM with Gaussian diffusion, in image space - - Denoising diffusion probabilistic models - """ - + # classic DDPM with Gaussian diffusion, in image space def __init__( self, unet_config, @@ -55,9 +52,10 @@ class DDPM(pl.LightningModule): beta_schedule="linear", loss_type="l2", ckpt_path=None, - ignore_keys=None, + ignore_keys=[], load_only_unet=False, monitor="val/loss", + use_ema=True, first_stage_key="image", image_size=256, channels=3, @@ -76,18 +74,21 @@ class DDPM(pl.LightningModule): use_positional_encodings=False, learn_logvar=False, logvar_init=0.0, + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, ): super().__init__() - ignore_keys = [] if ignore_keys is None else ignore_keys - assert parameterization in [ "eps", "x0", - ], 'currently only supporting "eps" and "x0"' + "v", + ], 'currently only supporting "eps" and "x0" and "v"' self.parameterization = parameterization - logger.debug( - f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" - ) + # print( + # f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" + # ) self.cond_stage_model = None self.clip_denoised = clip_denoised self.log_every_t = log_every_t @@ -96,6 +97,11 @@ class DDPM(pl.LightningModule): self.channels = channels self.use_positional_encodings = use_positional_encodings self.model = DiffusionWrapper(unet_config, conditioning_key) + # count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") self.use_scheduler = scheduler_config is not None if self.use_scheduler: @@ -107,10 +113,25 @@ class DDPM(pl.LightningModule): if monitor is not None: self.monitor = monitor + self.make_it_fit = make_it_fit + if reset_ema: + assert ckpt_path is not None if ckpt_path is not None: self.init_from_ckpt( ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet ) + if reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint." 
+ ) + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print( + " +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ " + ) + assert self.use_ema + self.model_ema.reset_num_updates() self.register_schedule( given_betas=given_betas, @@ -128,6 +149,10 @@ class DDPM(pl.LightningModule): if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + def register_schedule( self, given_betas=None, @@ -170,6 +195,15 @@ class DDPM(pl.LightningModule): self.register_buffer( "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = (1 - self.v_posterior) * betas * ( @@ -206,13 +240,404 @@ class DDPM(pl.LightningModule): * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) ) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like( + self.betas**2 + / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + ) else: raise NotImplementedError("mu not supported") - # TODO how to choose this term lvlb_weights[0] = lvlb_weights[1] self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) assert not torch.isnan(self.lvlb_weights).all() + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + if self.make_it_fit: + n_params = len( + [ + name + for name, _ in itertools.chain( + self.named_parameters(), self.named_buffers() + ) + ] + ) + for name, param in tqdm( + itertools.chain(self.named_parameters(), self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params, + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape) == len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[ + i % old_shape[0], j % old_shape[1] + ] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + 
for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = ( + self.load_state_dict(sd, strict=False) + if not only_model + else self.model.load_state_dict(sd, strict=False) + ) + # print( + # f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + # ) + # if len(missing) > 0: + # print(f"Missing Keys:\n {missing}") + # if len(unexpected) > 0: + # print(f"\nUnexpected Keys:\n {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised + ) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def 
p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm( + reversed(range(0, self.num_timesteps)), + desc="Sampling t", + total=self.num_timesteps, + ): + img = self.p_sample( + img, + torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised, + ) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop( + (batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates, + ) + + def q_sample(self, x_start, t, noise=None): + if noise is None: + noise = torch.randn_like(x_start) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == "l1": + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == "l2": + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction="none") + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + if noise is None: + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError( + f"Paramterization {self.parameterization} not yet supported" + ) + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = "train" if self.training else "val" + + loss_dict.update({f"{log_prefix}/loss_simple": loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f"{log_prefix}/loss_vlb": loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f"{log_prefix}/loss": loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint( + 0, self.num_timesteps, (x.shape[0],), device=self.device + ).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, "b h w c -> b c h w") + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if 
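The `"v"` parameterization above trains the model to predict `v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x0` instead of the noise `eps`, and `predict_start_from_z_and_v` / `predict_eps_from_z_and_v` invert it exactly because `alpha_bar_t + (1 - alpha_bar_t) = 1`. A quick numeric check of those identities with scalar coefficients (buffer indexing and broadcasting omitted):

```python
import torch

alpha_bar = torch.tensor(0.37)                 # some alpha_bar_t in (0, 1)
a, b = alpha_bar.sqrt(), (1 - alpha_bar).sqrt()

x0 = torch.randn(4, 64, 64)
eps = torch.randn_like(x0)

z_t = a * x0 + b * eps        # q_sample: forward diffusion to step t
v = a * eps - b * x0          # get_v: the regression target

x0_hat = a * z_t - b * v      # predict_start_from_z_and_v
eps_hat = a * v + b * z_t     # predict_eps_from_z_and_v

assert torch.allclose(x0_hat, x0, atol=1e-5)
assert torch.allclose(eps_hat, eps, atol=1e-5)
```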
val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1 - p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict( + loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True + ) + + self.log( + "global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]["lr"] + self.log( + "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + "_ema": loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict( + loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True + ) + self.log_dict( + loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True + ) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, "n b c h w -> b n c h w") + denoise_grid = rearrange(denoise_grid, "b n c h w -> (b n) c h w") + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), "1 -> b", b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample( + batch_size=N, return_intermediates=True + ) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + class LatentDiffusion(DDPM): """main class""" @@ -762,9 +1187,9 @@ class LatentDiffusion(DDPM): else: x_recon = self.model(x_noisy, t, **cond) + if isinstance(x_recon, tuple) and not return_ids: return x_recon[0] - log_latent(x_recon, "predicted noise") return x_recon diff --git a/imaginairy/modules/diffusion/model.py b/imaginairy/modules/diffusion/model.py index 7cfd5c0..76a0a9a 100644 --- a/imaginairy/modules/diffusion/model.py +++ b/imaginairy/modules/diffusion/model.py @@ -1,24 +1,27 @@ -# pylama:ignore=W0613,W0612 # pytorch_diffusion + derived encoder decoder -import gc -import logging import math +from typing import Any, Optional import numpy as np import torch +import torch.nn as nn from einops import rearrange -from torch import nn -from imaginairy.modules.attention 
import LinearAttention -from imaginairy.modules.distributions import DiagonalGaussianDistribution -from imaginairy.utils import get_device, instantiate_from_config +from imaginairy.modules.attention import MemoryEfficientCrossAttention -logger = logging.getLogger(__name__) +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + # print("No module 'xformers'. Proceeding without it.") def get_timestep_embedding(timesteps, embedding_dim): """ - Matches the implementation in Denoising Diffusion Probabilistic Models: + This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly @@ -39,11 +42,7 @@ def get_timestep_embedding(timesteps, embedding_dim): def nonlinearity(x): # swish - t = torch.sigmoid(x) - x *= t - del t - - return x + return x * torch.sigmoid(x) def Normalize(in_channels, num_groups=32): @@ -126,30 +125,18 @@ class ResnetBlock(nn.Module): ) def forward(self, x, temb): - h1 = x - h2 = self.norm1(h1) - del h1 - - h3 = nonlinearity(h2) - del h2 - - h4 = self.conv1(h3) - del h3 + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) if temb is not None: - h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h5 = self.norm2(h4) - del h4 + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - h6 = nonlinearity(h5) - del h5 - - h7 = self.dropout(h6) - del h6 - - h8 = self.conv2(h7) - del h7 + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -157,14 +144,7 @@ class ResnetBlock(nn.Module): else: x = self.nin_shortcut(x) - return x + h8 - - -class LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - - def __init__(self, in_channels): - super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + return x + h class AttnBlock(nn.Module): @@ -187,8 +167,6 @@ class AttnBlock(nn.Module): ) def forward(self, x): - if get_device() == "cuda": - return self.forward_cuda(x) h_ = x h_ = self.norm(h_) q = self.q(h_) @@ -214,83 +192,276 @@ class AttnBlock(nn.Module): return x + h_ - def forward_cuda(self, x): + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.attention_op: Optional[Any] = None + + def forward(self, x): h_ = x h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) + q = self.q(h_) + k = self.k(h_) v = self.v(h_) # compute attention - b, c, h, w = q1.shape - - q2 = q1.reshape(b, c, h * w) - del q1 - - q = q2.permute(0, 2, 1) # b,hw,c - del q2 - - k = k1.reshape(b, c, h * w) # b,c,hw - del k1 + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) + 
+ q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) - h_ = torch.zeros_like(k, device=q.device) + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x + out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, "b c h w -> b (h w) c") + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + # print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() - stats = torch.cuda.memory_stats(q.device) - mem_active = stats["active_bytes.all.current"] - mem_reserved = stats["reserved_bytes.all.current"] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - mem_required = tensor_size * 2.5 - steps = 1 +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c) ** (-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) - del w2 + curr_res = 
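`MemoryEfficientAttnBlock` above computes the same single-head self-attention over spatial positions as the vanilla `AttnBlock`, only routed through `xformers.ops.memory_efficient_attention`. A plain-PyTorch reference for that computation (no xformers needed), useful for sanity-checking the two code paths against each other:

```python
import torch

def reference_spatial_self_attention(q, k, v):
    """Single-head attention over h*w positions; mirrors the AttnBlock einsum math above."""
    b, c, h, w = q.shape
    q = q.reshape(b, c, h * w).permute(0, 2, 1)                  # b, hw, c
    k = k.reshape(b, c, h * w)                                   # b, c, hw
    attn = torch.softmax(torch.bmm(q, k) * (c ** -0.5), dim=2)   # b, hw, hw
    v = v.reshape(b, c, h * w)                                   # b, c, hw
    out = torch.bmm(v, attn.permute(0, 2, 1))                    # b, c, hw
    return out.reshape(b, c, h, w)
```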
resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) - # attend to values - v1 = v.reshape(b, c, h * w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) - h_[:, :, i:end] = torch.bmm( - v1, w4 - ) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order - h2 = h_.reshape(b, c, h, w) - del h_ + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) - h3 = self.proj_out(h2) - del h2 + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None - h3 += x + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) - return h3 + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb 
+ ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) -def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" - logger.debug( - f"making attention of type '{attn_type}' with {in_channels} in_channels" - ) - if attn_type == "vanilla": - return AttnBlock(in_channels) - if attn_type == "none": - return nn.Identity(in_channels) + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h - return LinAttnBlock(in_channels) + def get_last_layer(self): + return self.conv_out.weight class Encoder(nn.Module): @@ -447,9 +618,11 @@ class Decoder(nn.Module): block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - logger.debug( - f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions." - ) + # print( + # "Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape) + # ) + # ) # z to block_in self.conv_in = torch.nn.Conv2d( @@ -503,7 +676,6 @@ class Decoder(nn.Module): self.conv_out = torch.nn.Conv2d( block_in, out_ch, kernel_size=3, stride=1, padding=1 ) - self.last_z_shape = None def forward(self, z): # assert z.shape[1:] == self.z_shape[1:] @@ -513,53 +685,136 @@ class Decoder(nn.Module): temb = None # z to block_in - h1 = self.conv_in(z) + h = self.conv_in(z) # middle - h2 = self.mid.block_1(h1, temb) - del h1 - - h3 = self.mid.attn_1(h2) - del h2 - - h = self.mid.block_2(h3, temb) - del h3 - - # prepare for up sampling - gc.collect() - torch.cuda.empty_cache() + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - t = h - h = self.up[i_level].attn[i_block](t) - del t - + h = self.up[i_level].attn[i_block](h) if i_level != 0: - t = h - h = self.up[i_level].upsample(t) - del t + h = self.up[i_level].upsample(h) # end if self.give_pre_end: return h - h1 = self.norm_out(h) - del h + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h - h2 = nonlinearity(h1) - del h1 - h = self.conv_out(h2) - del h2 +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) - if self.tanh_out: - t = h - h = torch.tanh(t) - del t + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): 
+ def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) return h @@ -619,15 +874,102 @@ class LatentRescaler(nn.Module): return x +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + class Upsampler(nn.Module): def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): super().__init__() assert out_size >= in_size num_blocks = int(np.log2(out_size // in_size)) + 1 factor_up = 1.0 + (out_size % in_size) - logger.debug( - f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size 
{out_size} and factor {factor_up}" - ) + # print( + # f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + # ) self.rescaler = LatentRescaler( factor=factor_up, in_channels=in_channels, @@ -657,98 +999,21 @@ class Resize(nn.Module): self.with_conv = learned self.mode = mode if self.with_conv: - logger.info( - f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" # noqa - ) - raise NotImplementedError() - # assert in_channels is not None - # # no asymmetric padding in torch conv, must do it ourselves - # self.conv = torch.nn.Conv2d( - # in_channels, in_channels, kernel_size=4, stride=2, padding=1 + # print( + # f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" # ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) def forward(self, x, scale_factor=1.0): if scale_factor == 1.0: return x - - x = torch.nn.functional.interpolate( - x, mode=self.mode, align_corners=False, scale_factor=scale_factor - ) - return x - - -class FirstStagePostProcessor(nn.Module): - def __init__( - self, - ch_mult: list, - in_channels, - pretrained_model: nn.Module = None, - reshape=False, - n_channels=None, - dropout=0.0, - pretrained_config=None, - ): - super().__init__() - if pretrained_config is None: - assert ( - pretrained_model is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' - self.pretrained_model = pretrained_model else: - assert ( - pretrained_config is not None - ), 'Either "pretrained_model" or "pretrained_config" must not be None' - self.instantiate_pretrained(pretrained_config) - - self.do_reshape = reshape - - if n_channels is None: - n_channels = self.pretrained_model.encoder.ch - - self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) - self.proj = nn.Conv2d( - in_channels, n_channels, kernel_size=3, stride=1, padding=1 - ) - - blocks = [] - downs = [] - ch_in = n_channels - for m in ch_mult: - blocks.append( - ResnetBlock( - in_channels=ch_in, out_channels=m * n_channels, dropout=dropout - ) + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor ) - ch_in = m * n_channels - downs.append(Downsample(ch_in, with_conv=False)) - - self.model = nn.ModuleList(blocks) - self.downsampler = nn.ModuleList(downs) - - def instantiate_pretrained(self, config): - model = instantiate_from_config(config) - self.pretrained_model = model.eval() - # self.pretrained_model.train = False - for param in self.pretrained_model.parameters(): - param.requires_grad = False - - @torch.no_grad() - def encode_with_pretrained(self, x): - c = self.pretrained_model.encode(x) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - return c - - def forward(self, x): - z_fs = self.encode_with_pretrained(x) - z = self.proj_norm(z_fs) - z = self.proj(z) - z = nonlinearity(z) - - for submodel, downmodel in zip(self.model, self.downsampler): - z = submodel(z, temb=None) - z = downmodel(z) - - if self.do_reshape: - z = rearrange(z, "b c h w -> b (h w) c") - return z + return x diff --git a/imaginairy/modules/diffusion/openaimodel.py b/imaginairy/modules/diffusion/openaimodel.py index bc41840..84f6bb2 100644 --- a/imaginairy/modules/diffusion/openaimodel.py +++ b/imaginairy/modules/diffusion/openaimodel.py @@ -4,7 
+4,6 @@ from abc import abstractmethod import numpy as np import torch as th import torch.nn.functional as F -from omegaconf.listconfig import ListConfig from torch import nn from imaginairy.modules.attention import SpatialTransformer @@ -12,6 +11,7 @@ from imaginairy.modules.diffusion.util import ( avg_pool_nd, checkpoint, conv_nd, + linear, normalization, timestep_embedding, zero_module, @@ -478,6 +478,10 @@ class UNetModel(nn.Module): context_dim=None, # custom transformer support n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, ): super().__init__() if use_spatial_transformer: @@ -489,8 +493,9 @@ class UNetModel(nn.Module): assert ( use_spatial_transformer ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + from omegaconf.listconfig import ListConfig - if isinstance(context_dim, ListConfig): + if type(context_dim) == ListConfig: context_dim = list(context_dim) if num_heads_upsample == -1: @@ -510,7 +515,33 @@ class UNetModel(nn.Module): self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels - self.num_res_blocks = num_res_blocks + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) + self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult @@ -525,13 +556,19 @@ class UNetModel(nn.Module): time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( - nn.Linear(model_channels, time_embed_dim), + linear(model_channels, time_embed_dim), nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), + linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + # print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + else: + raise ValueError() self.input_blocks = nn.ModuleList( [ @@ -545,7 +582,7 @@ class UNetModel(nn.Module): ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): + for nr in range(self.num_res_blocks[level]): layers = [ ResBlock( ch, @@ -571,23 +608,32 @@ class UNetModel(nn.Module): if use_spatial_transformer else num_head_channels ) - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, + if disable_self_attentions is not None: + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if num_attention_blocks is None or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + ) ) - ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) @@ -641,12 +687,15 @@ class UNetModel(nn.Module): use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer - else SpatialTransformer( + else SpatialTransformer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, ), ResBlock( ch, @@ -661,7 +710,7 @@ class UNetModel(nn.Module): self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): + for i in range(self.num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ ResBlock( @@ -688,24 +737,33 @@ class UNetModel(nn.Module): if use_spatial_transformer else num_head_channels ) - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, + if disable_self_attentions is not None: + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if num_attention_blocks is None or i < 
num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + ) ) - ) - if level and i == num_res_blocks: + if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( ResBlock( @@ -753,7 +811,7 @@ class UNetModel(nn.Module): self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) - def forward(self, x, timesteps=None, context=None, y=None, **kwargs): # noqa + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. @@ -770,7 +828,7 @@ class UNetModel(nn.Module): emb = self.time_embed(t_emb) if self.num_classes is not None: - assert y.shape == (x.shape[0],) + assert y.shape[0] == x.shape[0] emb = emb + self.label_emb(y) h = x.type(self.dtype) @@ -784,5 +842,5 @@ class UNetModel(nn.Module): h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) - - return self.out(h) + else: + return self.out(h) diff --git a/imaginairy/modules/diffusion/util.py b/imaginairy/modules/diffusion/util.py index dc50baf..880ebb7 100644 --- a/imaginairy/modules/diffusion/util.py +++ b/imaginairy/modules/diffusion/util.py @@ -269,6 +269,13 @@ def conv_nd(dims, *args, **kwargs): raise ValueError(f"unsupported dimensions: {dims}") +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + def avg_pool_nd(dims, *args, **kwargs): """Create a 1D, 2D, or 3D average pooling module.""" if dims == 1: diff --git a/imaginairy/modules/ema.py b/imaginairy/modules/ema.py new file mode 100644 index 0000000..9db7015 --- /dev/null +++ b/imaginairy/modules/ema.py @@ -0,0 +1,88 @@ +import torch +from torch import nn + +# https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/ema.py + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 
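+                    # in-place EMA update: shadow -= (1 - decay) * (shadow - param),
+                    # equivalently shadow = decay * shadow + (1 - decay) * param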
+ shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/imaginairy/modules/encoders.py b/imaginairy/modules/encoders.py new file mode 100644 index 0000000..2cd2716 --- /dev/null +++ b/imaginairy/modules/encoders.py @@ -0,0 +1,263 @@ +import open_clip +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer + +from imaginairy.utils import get_device + +# https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/encoders/modules.py + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0.0 and not disable_dropout: + mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = ( + self.n_classes - 1 + ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? 
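+        # freeze() puts the T5 encoder in eval mode and sets requires_grad=False on
+        # every parameter, so the text embedder stays fixed at inference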
+ if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + ): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer( + input_ids=tokens, output_hidden_states=self.layer == "hidden" + ) + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + + LAYERS = [ + # "pooled", + "last", + "penultimate", + ] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device=None, + max_length=77, + freeze=True, + layer="last", + ): + super().__init__() + assert layer in self.LAYERS + if device is None: + device = get_device() + model, _, _ = open_clip.create_model_and_transforms( + arch, device=torch.device("cpu"), pretrained=version + ) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + 
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if ( + self.model.transformer.grad_checkpointing + and not torch.jit.is_scripting() + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder( + clip_version, device, max_length=clip_max_length + ) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + # print( + # f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + # f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params." + # ) + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] diff --git a/imaginairy/samplers/ddim.py b/imaginairy/samplers/ddim.py index 654464a..2f2dedf 100644 --- a/imaginairy/samplers/ddim.py +++ b/imaginairy/samplers/ddim.py @@ -123,6 +123,13 @@ class DDIMSampler: signal_amplification=guidance_scale, ) + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v( + noisy_latent, time_encoding, noise_pred + ) + else: + e_t = noise_pred + batch_size = noisy_latent.shape[0] # select parameters corresponding to the currently considered timestep @@ -146,11 +153,15 @@ class DDIMSampler: schedule.ddim_sqrt_one_minus_alphas[index], device=noisy_latent.device, ) + noisy_latent, predicted_latent = self._p_sample_ddim_formula( + model=self.model, noisy_latent=noisy_latent, noise_pred=noise_pred, + e_t=e_t, sqrt_one_minus_at=sqrt_one_minus_at, a_t=a_t, + time_encoding=time_encoding, sigma_t=sigma_t, a_prev=a_prev, noise_dropout=noise_dropout, @@ -161,19 +172,27 @@ class DDIMSampler: @staticmethod def _p_sample_ddim_formula( + model, noisy_latent, noise_pred, + e_t, sqrt_one_minus_at, a_t, + time_encoding, sigma_t, a_prev, noise_dropout, repeat_noise, temperature, ): - predicted_latent = (noisy_latent - sqrt_one_minus_at * noise_pred) / a_t.sqrt() + if model.parameterization != "v": + predicted_latent = (noisy_latent - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + predicted_latent = model.predict_start_from_z_and_v( + noisy_latent, time_encoding, noise_pred + ) # direction pointing to x_t - dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * noise_pred + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = ( sigma_t * noise_like(noisy_latent.shape, noisy_latent.device, repeat_noise) diff --git a/setup.py b/setup.py index cbc8941..86a7974 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ setup( "psutil", "pytorch-lightning==1.4.2", "omegaconf==2.1.1", + "open-clip-torch", "einops==0.3.0", "timm>=0.4.12", # for vendored blip "torchdiffeq", diff --git a/tests/samplers/__init__.py b/tests/samplers/__init__.py deleted file mode 100644 index e69de29..0000000
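
Note on the v-prediction path added to `imaginairy/samplers/ddim.py` above: when `self.model.parameterization == "v"`, the UNet predicts `v` rather than noise, and the sampler converts that prediction back into an epsilon estimate (`predict_eps_from_z_and_v`) and an x0 estimate (`predict_start_from_z_and_v`) before running the usual DDIM update. The sketch below is illustrative only: it restates the standard v-parameterization identities with plain tensors, whereas the real methods live on `LatentDiffusion` and gather the per-timestep `alphas_cumprod` buffers; the standalone function names here are hypothetical stand-ins.

```python
import torch

# Illustrative conversions assumed by the DDIM changes above; a_t is the
# cumulative alpha (alphas_cumprod) for the current timestep, passed in directly.

def eps_from_z_and_v(z_t, v, a_t):
    # eps = sqrt(a_t) * v + sqrt(1 - a_t) * z_t
    return a_t.sqrt() * v + (1 - a_t).sqrt() * z_t

def x0_from_z_and_v(z_t, v, a_t):
    # x0 = sqrt(a_t) * z_t - sqrt(1 - a_t) * v
    return a_t.sqrt() * z_t - (1 - a_t).sqrt() * v

if __name__ == "__main__":
    z_t = torch.randn(1, 4, 64, 64)  # noisy latent at timestep t
    v = torch.randn_like(z_t)        # what a v-parameterized UNet predicts
    a_t = torch.tensor(0.5)          # cumulative alpha for timestep t
    eps = eps_from_z_and_v(z_t, v, a_t)
    x0 = x0_from_z_and_v(z_t, v, a_t)
    # consistency check: z_t == sqrt(a_t) * x0 + sqrt(1 - a_t) * eps
    assert torch.allclose(a_t.sqrt() * x0 + (1 - a_t).sqrt() * eps, z_t, atol=1e-5)
```

With `e_t` and `predicted_latent` in hand, `_p_sample_ddim_formula` proceeds identically for both parameterizations: `e_t` feeds the direction term `dir_xt`, and `predicted_latent` is the denoised estimate.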