@@ -9,6 +9,7 @@ import itertools
 import logging
 from contextlib import contextmanager, nullcontext
 from functools import partial
+from typing import Optional
 
 import numpy as np
 import pytorch_lightning as pl
@@ -371,7 +372,7 @@ class DDPM(pl.LightningModule):
                     # we only modify first two axes
                     assert new_shape[2:] == old_shape[2:]
                 # assumes first axis corresponds to output dim
-                if not new_shape == old_shape:
+                if new_shape != old_shape:
                     new_param = param.clone()
                     old_param = sd[name]
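+                    # fill the resized parameter from the checkpoint weights below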
                     if len(new_shape) == 1:
@@ -495,7 +496,7 @@ class DDPM(pl.LightningModule):
         img = torch.randn(shape, device=device)
         intermediates = [img]
         for i in tqdm(
-            reversed(range(0, self.num_timesteps)),
+            reversed(range(self.num_timesteps)),
             desc="Sampling t",
             total=self.num_timesteps,
         ):
@@ -563,9 +564,8 @@ class DDPM(pl.LightningModule):
         elif self.parameterization == "v":
             target = self.get_v(x_start, noise, t)
         else:
-            raise NotImplementedError(
-                f"Parameterization {self.parameterization} not yet supported"
-            )
+            msg = f"Parameterization {self.parameterization} not yet supported"
+            raise NotImplementedError(msg)
 
         loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
@@ -706,7 +706,7 @@ class DDPM(pl.LightningModule):
         lr = self.learning_rate
         params = list(self.model.parameters())
         if self.learn_logvar:
-            params = params + [self.logvar]
+            params = [*params, self.logvar]
         opt = torch.optim.AdamW(params, lr=lr)
         return opt
@@ -716,7 +716,7 @@ def _TileModeConv2DConvForward(
 ):
     if self.padding_modeX == self.padding_modeY:
         self.padding_mode = self.padding_modeX
-        return self._orig_conv_forward(input, weight, bias)  # noqa
+        return self._orig_conv_forward(input, weight, bias)
 
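+    # the modes differ, so pad the width and height axes separately, each in its own mode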
     w1 = F.pad(input, self.paddingX, mode=self.padding_modeX)
     del input
@@ -790,9 +790,7 @@ class LatentDiffusion(DDPM):
             if isinstance(m, nn.Conv2d):
                 m._initial_padding_mode = m.padding_mode
                 m._orig_conv_forward = m._conv_forward
-                m._conv_forward = _TileModeConv2DConvForward.__get__(  # noqa
-                    m, nn.Conv2d
-                )
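+                # __get__ binds the plain function to this Conv2d instance so it
+                # behaves like an ordinary method when PyTorch calls _conv_forward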
+                m._conv_forward = _TileModeConv2DConvForward.__get__(m, nn.Conv2d)
         self.tile_mode(tile_mode=False)
 
     def tile_mode(self, tile_mode):
@@ -807,16 +805,16 @@ class LatentDiffusion(DDPM):
                 if m.padding_modeY == m.padding_modeX:
                     m.padding_mode = m.padding_modeX
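+                # F.pad tuples run (left, right, top, bottom), so paddingX pads only
+                # the width axis and paddingY only the height axis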
                 m.paddingX = (
-                    m._reversed_padding_repeated_twice[0],  # noqa
-                    m._reversed_padding_repeated_twice[1],  # noqa
+                    m._reversed_padding_repeated_twice[0],
+                    m._reversed_padding_repeated_twice[1],
                     0,
                     0,
                 )
                 m.paddingY = (
                     0,
                     0,
-                    m._reversed_padding_repeated_twice[2],  # noqa
-                    m._reversed_padding_repeated_twice[3],  # noqa
+                    m._reversed_padding_repeated_twice[2],
+                    m._reversed_padding_repeated_twice[3],
                 )
 
     def make_cond_schedule(
@@ -896,9 +894,8 @@ class LatentDiffusion(DDPM):
         elif isinstance(encoder_posterior, torch.Tensor):
             z = encoder_posterior
         else:
-            raise NotImplementedError(
-                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
-            )
+            msg = f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+            raise NotImplementedError(msg)
         return self.scale_factor * z
 
     def get_learned_conditioning(self, c):
@@ -967,7 +964,7 @@ class LatentDiffusion(DDPM):
         :param x: img of size (bs, c, h, w)
         :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
         """
-        bs, nc, h, w = x.shape  # noqa
+        bs, nc, h, w = x.shape
 
         # number of crops in image
         Ly = (h - kernel_size[0]) // stride[0] + 1
@@ -1167,7 +1164,7 @@ class LatentDiffusion(DDPM):
             ks = self.split_input_params["ks"]  # eg. (128, 128)
             stride = self.split_input_params["stride"]  # eg. (64, 64)
 
-            h, w = x_noisy.shape[-2:]  # noqa
+            h, w = x_noisy.shape[-2:]
 
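+            # unfold cuts x_noisy into overlapping ks-sized crops; fold, together with
+            # the normalization map, stitches the per-crop outputs back together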
             fold, unfold, normalization, weighting = self.get_fold_unfold(
                 x_noisy, ks, stride
@@ -1239,9 +1236,7 @@ class LatentDiffusion(DDPM):
                 # tokenize crop coordinates for the bounding boxes of the respective patches
                 patch_limits_tknzd = [
-                    torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[  # noqa
-                        None
-                    ].to(  # noqa
+                    torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(
                         self.device
                     )
                     for bbox in patch_limits
                 ]
@@ -1292,7 +1287,7 @@ class LatentDiffusion(DDPM):
         return x_recon
 
-    def p_losses(self, x_start, cond, t, noise=None):  # noqa
+    def p_losses(self, x_start, cond, t, noise=None):
         noise = noise if noise is not None else torch.randn_like(x_start)
         x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
         model_output = self.apply_model(x_noisy, t, cond)
@@ -1374,7 +1369,7 @@ class LatentDiffusion(DDPM):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(  # noqa
+    def p_sample(
         self,
         x,
         c,
@@ -1609,7 +1604,7 @@ class LatentDiffusion(DDPM):
         if inpaint:
             # make a simple center square
-            b, h, w = z.shape[0], z.shape[2], z.shape[3]
+            b, h, w = z.shape[0], z.shape[2], z.shape[3]  # noqa
             mask = torch.ones(N, h, w).to(self.device)
             # zeros will be filled in
             mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
@@ -1674,9 +1669,8 @@ class LatentDiffusion(DDPM):
             logger.info("Training the full unet")
             params = list(self.model.parameters())
         else:
-            raise ValueError(
-                f"Unrecognised setting for unet_trainable: {self.unet_trainable}"
-            )
+            msg = f"Unrecognised setting for unet_trainable: {self.unet_trainable}"
+            raise ValueError(msg)
 
         if self.cond_stage_trainable:
             logger.info(
@@ -1706,7 +1700,7 @@ class LatentDiffusion(DDPM):
     def to_rgb(self, x):
         x = x.float()
         if not hasattr(self, "colorize"):
-            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)  # noqa
+            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
         x = nn.functional.conv2d(x, weight=self.colorize)
         x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
         return x
@@ -1719,17 +1713,19 @@ class DiffusionWrapper(pl.LightningModule):
         self.conditioning_key = conditioning_key
         assert self.conditioning_key in [None, "concat", "crossattn", "hybrid", "adm"]
 
-    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
+    def forward(
+        self, x, t, c_concat: Optional[list] = None, c_crossattn: Optional[list] = None
+    ):
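+        # "concat" feeds conditioning in as extra input channels, "crossattn" as
+        # attention context, "hybrid" does both, and "adm" as a class-style embedding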
         if self.conditioning_key is None:
             out = self.diffusion_model(x, t)
         elif self.conditioning_key == "concat":
-            xc = torch.cat([x] + c_concat, dim=1)
+            xc = torch.cat([x, *c_concat], dim=1)
             out = self.diffusion_model(xc, t)
         elif self.conditioning_key == "crossattn":
             cc = torch.cat(c_crossattn, 1)
             out = self.diffusion_model(x, t, context=cc)
         elif self.conditioning_key == "hybrid":
-            xc = torch.cat([x] + c_concat, dim=1)
+            xc = torch.cat([x, *c_concat], dim=1)
             cc = torch.cat(c_crossattn, 1)
             out = self.diffusion_model(xc, t, context=cc)
         elif self.conditioning_key == "adm":
@@ -1818,7 +1814,7 @@ class LatentFinetuneDiffusion(LatentDiffusion):
         # print(f"Unexpected Keys: {unexpected}")
 
     @torch.no_grad()
-    def log_images(  # noqa
+    def log_images(
         self,
         batch,
         N=8,
@@ -1866,7 +1862,7 @@ class LatentFinetuneDiffusion(LatentDiffusion):
         if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
             log["c_concat_decoded"] = self.decode_first_stage(
-                c_cat[:, self.c_concat_log_start : self.c_concat_log_end]  # noqa
+                c_cat[:, self.c_concat_log_start : self.c_concat_log_end]
             )
 
         if plot_diffusion_rows:
@@ -1929,11 +1925,11 @@ class LatentFinetuneDiffusion(LatentDiffusion):
 class LatentInpaintDiffusion(LatentDiffusion):
-    def __init__(  # noqa
+    def __init__(
         self,
         concat_keys=("mask", "masked_image"),
         masked_image_key="masked_image",
-        finetune_keys=None,  # noqa
+        finetune_keys=None,
         *args,
         **kwargs,
     ):