# pytorch_diffusion + derived encoder decoder import gc import logging import math import numpy as np import torch import torch.nn as nn from einops import rearrange from imaginairy.modules.attention import LinearAttention from imaginairy.modules.distributions import DiagonalGaussianDistribution from imaginairy.utils import get_device, instantiate_from_config logger = logging.getLogger(__name__) def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) return emb def nonlinearity(x): # swish t = torch.sigmoid(x) x *= t del t return x def Normalize(in_channels, num_groups=32): return torch.nn.GroupNorm( num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True ) class Upsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: self.conv = torch.nn.Conv2d( in_channels, in_channels, kernel_size=3, stride=1, padding=1 ) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) return x class Downsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves self.conv = torch.nn.Conv2d( in_channels, in_channels, kernel_size=3, stride=2, padding=0 ) def forward(self, x): if self.with_conv: pad = (0, 1, 0, 1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x class ResnetBlock(nn.Module): def __init__( self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512, ): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) self.conv1 = torch.nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) if temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) self.conv2 = torch.nn.Conv2d( out_channels, out_channels, kernel_size=3, stride=1, padding=1 ) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d( in_channels, out_channels, kernel_size=3, stride=1, padding=1 ) else: self.nin_shortcut = torch.nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0 ) def forward(self, x, temb): h1 = x h2 = self.norm1(h1) del h1 h3 = nonlinearity(h2) del h2 h4 = self.conv1(h3) del h3 if temb is not None: h4 = h4 + self.temb_proj(nonlinearity(temb))[:, :, None, None] h5 = self.norm2(h4) del h4 h6 = nonlinearity(h5) del h5 h7 = self.dropout(h6) del h6 h8 = self.conv2(h7) del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x + h8 class LinAttnBlock(LinearAttention): """to match AttnBlock usage""" def __init__(self, in_channels): super().__init__(dim=in_channels, heads=1, dim_head=in_channels) class AttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.k = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.v = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) self.proj_out = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) def forward(self, x): if get_device() == "cuda": return self.forward_cuda(x) h_ = x h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h * w) q = q.permute(0, 2, 1) # b,hw,c k = k.reshape(b, c, h * w) # b,c,hw w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h * w) w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) return x + h_ def forward_cuda(self, x): h_ = x h_ = self.norm(h_) q1 = self.q(h_) k1 = self.k(h_) v = self.v(h_) # compute attention b, c, h, w = q1.shape q2 = q1.reshape(b, c, h * w) del q1 q = q2.permute(0, 2, 1) # b,hw,c del q2 k = k1.reshape(b, c, h * w) # b,c,hw del k1 h_ = torch.zeros_like(k, device=q.device) stats = torch.cuda.memory_stats(q.device) mem_active = stats["active_bytes.all.current"] mem_reserved = stats["reserved_bytes.all.current"] mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() mem_required = tensor_size * 2.5 steps = 1 if mem_required > mem_free_total: steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w2 = w1 * (int(c) ** (-0.5)) del w1 w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) del w2 # attend to values v1 = v.reshape(b, c, h * w) w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) del w3 h_[:, :, i:end] = torch.bmm( v1, w4 ) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] del v1, w4 h2 = h_.reshape(b, c, h, w) del h_ h3 = self.proj_out(h2) del h2 h3 += x return h3 def make_attn(in_channels, attn_type="vanilla"): assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" logger.debug( f"making attention of type '{attn_type}' with {in_channels} in_channels" ) if attn_type == "vanilla": return AttnBlock(in_channels) elif attn_type == "none": return nn.Identity(in_channels) else: return LinAttnBlock(in_channels) class Encoder(nn.Module): def __init__( self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", **ignore_kwargs, ): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels # downsampling self.conv_in = torch.nn.Conv2d( in_channels, self.ch, kernel_size=3, stride=1, padding=1 ) curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append( ResnetBlock( in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.block_2 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d( block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1, ) def forward(self, x): # timestep embedding temb = None # downsampling hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) # end h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h class Decoder(nn.Module): def __init__( self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, attn_type="vanilla", **ignorekwargs, ): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.give_pre_end = give_pre_end self.tanh_out = tanh_out # compute in_ch_mult, block_in and curr_res at lowest res in_ch_mult = (1,) + tuple(ch_mult) block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) logger.debug( f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions." ) # z to block_in self.conv_in = torch.nn.Conv2d( z_channels, block_in, kernel_size=3, stride=1, padding=1 ) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.block_2 = ResnetBlock( in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout, ) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): block.append( ResnetBlock( in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout, ) ) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d( block_in, out_ch, kernel_size=3, stride=1, padding=1 ) def forward(self, z): # assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding temb = None # z to block_in h1 = self.conv_in(z) # middle h2 = self.mid.block_1(h1, temb) del h1 h3 = self.mid.attn_1(h2) del h2 h = self.mid.block_2(h3, temb) del h3 # prepare for up sampling gc.collect() torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: t = h h = self.up[i_level].attn[i_block](t) del t if i_level != 0: t = h h = self.up[i_level].upsample(t) del t # end if self.give_pre_end: return h h1 = self.norm_out(h) del h h2 = nonlinearity(h1) del h1 h = self.conv_out(h2) del h2 if self.tanh_out: t = h h = torch.tanh(t) del t return h class LatentRescaler(nn.Module): def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): super().__init__() # residual block, interpolate, residual block self.factor = factor self.conv_in = nn.Conv2d( in_channels, mid_channels, kernel_size=3, stride=1, padding=1 ) self.res_block1 = nn.ModuleList( [ ResnetBlock( in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0, ) for _ in range(depth) ] ) self.attn = AttnBlock(mid_channels) self.res_block2 = nn.ModuleList( [ ResnetBlock( in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0, ) for _ in range(depth) ] ) self.conv_out = nn.Conv2d( mid_channels, out_channels, kernel_size=1, ) def forward(self, x): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) x = torch.nn.functional.interpolate( x, size=( int(round(x.shape[2] * self.factor)), int(round(x.shape[3] * self.factor)), ), ) x = self.attn(x) for block in self.res_block2: x = block(x, None) x = self.conv_out(x) return x class Upsampler(nn.Module): def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): super().__init__() assert out_size >= in_size num_blocks = int(np.log2(out_size // in_size)) + 1 factor_up = 1.0 + (out_size % in_size) logger.debug( f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" ) self.rescaler = LatentRescaler( factor=factor_up, in_channels=in_channels, mid_channels=2 * in_channels, out_channels=in_channels, ) self.decoder = Decoder( out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, attn_resolutions=[], in_channels=None, ch=in_channels, ch_mult=[ch_mult for _ in range(num_blocks)], ) def forward(self, x): x = self.rescaler(x) x = self.decoder(x) return x class Resize(nn.Module): def __init__(self, in_channels=None, learned=False, mode="bilinear"): super().__init__() self.with_conv = learned self.mode = mode if self.with_conv: logger.info( f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" ) raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves self.conv = torch.nn.Conv2d( in_channels, in_channels, kernel_size=4, stride=2, padding=1 ) def forward(self, x, scale_factor=1.0): if scale_factor == 1.0: return x else: x = torch.nn.functional.interpolate( x, mode=self.mode, align_corners=False, scale_factor=scale_factor ) return x class FirstStagePostProcessor(nn.Module): def __init__( self, ch_mult: list, in_channels, pretrained_model: nn.Module = None, reshape=False, n_channels=None, dropout=0.0, pretrained_config=None, ): super().__init__() if pretrained_config is None: assert ( pretrained_model is not None ), 'Either "pretrained_model" or "pretrained_config" must not be None' self.pretrained_model = pretrained_model else: assert ( pretrained_config is not None ), 'Either "pretrained_model" or "pretrained_config" must not be None' self.instantiate_pretrained(pretrained_config) self.do_reshape = reshape if n_channels is None: n_channels = self.pretrained_model.encoder.ch self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) self.proj = nn.Conv2d( in_channels, n_channels, kernel_size=3, stride=1, padding=1 ) blocks = [] downs = [] ch_in = n_channels for m in ch_mult: blocks.append( ResnetBlock( in_channels=ch_in, out_channels=m * n_channels, dropout=dropout ) ) ch_in = m * n_channels downs.append(Downsample(ch_in, with_conv=False)) self.model = nn.ModuleList(blocks) self.downsampler = nn.ModuleList(downs) def instantiate_pretrained(self, config): model = instantiate_from_config(config) self.pretrained_model = model.eval() # self.pretrained_model.train = False for param in self.pretrained_model.parameters(): param.requires_grad = False @torch.no_grad() def encode_with_pretrained(self, x): c = self.pretrained_model.encode(x) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() return c def forward(self, x): z_fs = self.encode_with_pretrained(x) z = self.proj_norm(z_fs) z = self.proj(z) z = nonlinearity(z) for submodel, downmodel in zip(self.model, self.downsampler): z = submodel(z, temb=None) z = downmodel(z) if self.do_reshape: z = rearrange(z, "b c h w -> b (h w) c") return z