|
|
|
@ -7,7 +7,7 @@ https://github.com/CompVis/taming-transformers
|
|
|
|
|
"""
|
|
|
|
|
import itertools
|
|
|
|
|
import logging
|
|
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from contextlib import contextmanager, nullcontext
|
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
@ -39,6 +39,18 @@ def disabled_train(self):
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ismap(x):
    """Return True when ``x`` is a 4-D ``torch.Tensor`` whose second dimension exceeds 3.

    Anything that is not a ``torch.Tensor`` yields ``False``.
    """
    if isinstance(x, torch.Tensor):
        has_four_dims = x.dim() == 4
        return has_four_dims and x.shape[1] > 3
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def isimage(x):
    """Return True when ``x`` is a 4-D ``torch.Tensor`` whose second dimension is 1 or 3.

    Anything that is not a ``torch.Tensor`` yields ``False``.
    """
    if isinstance(x, torch.Tensor):
        has_four_dims = x.dim() == 4
        return has_four_dims and x.shape[1] in (3, 1)
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def uniform_on_device(r1, r2, shape, device):
    """Draw ``torch.rand`` samples of the given ``shape`` on ``device`` and map them affinely.

    A sample ``u`` from ``torch.rand`` is mapped to ``(r1 - r2) * u + r2``,
    so ``u == 0`` corresponds to ``r2``.
    """
    span = r1 - r2
    unit_noise = torch.rand(*shape, device=device)
    return span * unit_noise + r2
|
|
|
|
|
|
|
|
|
@ -1324,6 +1336,188 @@ class DiffusionWrapper(pl.LightningModule):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LatentFinetuneDiffusion(LatentDiffusion):
    """
    Basis for different finetunes, such as inpainting or depth2image.
    To disable finetuning mode, set finetune_keys to None.
    """

    def __init__(
        self,
        concat_keys: tuple,
        finetune_keys=(
            "model.diffusion_model.input_blocks.0.0.weight",
            "model_ema.diffusion_modelinput_blocks00weight",
        ),
        keep_finetune_dims=4,
        # if model was trained without concat mode before and we would like to keep these channels
        c_concat_log_start=None,  # to log reconstruction of c_concat codes
        c_concat_log_end=None,
        **kwargs,
    ):
        """Set up the finetuning wrapper and optionally restore a checkpoint.

        Args:
            concat_keys: batch keys stored on ``self.concat_keys``.
            finetune_keys: state-dict keys whose checkpoint tensors are
                embedded into zero-initialised tensors by ``init_from_ckpt``;
                ``None`` disables finetuning mode.
            keep_finetune_dims: number of leading entries along dim 1 that are
                filled from the checkpoint tensor for each finetune key.
            c_concat_log_start, c_concat_log_end: slice bounds along dim 1 of
                ``c_concat`` decoded by ``log_images``; nothing is decoded
                when both are ``None``.
            **kwargs: forwarded to the parent constructor after ``ckpt_path``
                and ``ignore_keys`` have been removed.

        Raises:
            AssertionError: if ``finetune_keys`` is set but no ``ckpt_path``
                was given.
        """
        # The checkpoint is loaded by this class (after the attributes below
        # exist), so both options are removed before calling the parent.
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(**kwargs)
        self.finetune_keys = finetune_keys
        self.concat_keys = concat_keys
        self.keep_dims = keep_finetune_dims
        self.c_concat_log_start = c_concat_log_start
        self.c_concat_log_end = c_concat_log_end
        if self.finetune_keys is not None:
            assert ckpt_path is not None, "can only finetune from a given checkpoint"
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=tuple(), only_model=False):
        """Load weights from ``path``, adapting the finetune keys on the way.

        Args:
            path: checkpoint file passed to ``torch.load``.
            ignore_keys: iterable of string prefixes; every state-dict key that
                starts with one of them is dropped before loading.
            only_model: when true the state dict is loaded into ``self.model``
                instead of ``self``.
        """
        # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in sd:
            sd = sd["state_dict"]

        ignored_prefixes = tuple(ignore_keys)
        for k in list(sd.keys()):
            # Delete an ignored key exactly once and skip further processing:
            # deleting once per matching prefix (or touching sd[k] afterwards)
            # would raise KeyError.
            if ignored_prefixes and k.startswith(ignored_prefixes):
                print(f"Deleting key {k} from state_dict.")
                del sd[k]
                continue

            # make it explicit, finetune by including extra input channels
            if self.finetune_keys is not None and k in self.finetune_keys:
                new_entry = None
                for name, param in self.named_parameters():
                    if name in self.finetune_keys:
                        print(
                            f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
                        )
                        new_entry = torch.zeros_like(param)  # zero init
                assert (
                    new_entry is not None
                ), "did not find matching parameter to modify"
                # Copy the checkpoint tensor into the leading entries of dim 1;
                # the remaining entries stay zero.
                new_entry[:, : self.keep_dims, ...] = sd[k]
                sd[k] = new_entry

        target = self.model if only_model else self
        missing, unexpected = target.load_state_dict(sd, strict=False)
        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    @torch.no_grad()
    def log_images(  # noqa
        self,
        batch,
        N=8,
        n_row=4,
        sample=True,
        ddim_steps=200,
        ddim_eta=1.0,
        return_keys=None,
        quantize_denoised=True,
        inpaint=True,
        plot_denoise_rows=False,
        plot_progressive_rows=True,
        plot_diffusion_rows=True,
        unconditional_guidance_scale=1.0,
        unconditional_guidance_label=None,
        use_ema_scope=True,
        **kwargs,
    ):
        """Collect inputs, reconstructions, conditioning and samples for logging.

        Args:
            batch: batch passed to ``self.get_input``.
            N: requested number of items (``bs=N``), clipped to the batch size.
            n_row: number of items used for the diffusion row, clipped to the
                batch size.
            sample: when true, draw samples via ``self.sample_log``.
            ddim_steps: forwarded to ``sample_log``; ``None`` turns the
                ``ddim`` flag off.
            ddim_eta: forwarded to ``sample_log`` as ``eta``.
            plot_denoise_rows: also log the denoising row of the sampling run.
            plot_diffusion_rows: log a grid of noised latents decoded at every
                ``self.log_every_t``-th timestep and at the last timestep.
            unconditional_guidance_scale: values above 1.0 trigger an extra
                sampling run with unconditional conditioning.
            unconditional_guidance_label: forwarded to
                ``self.get_unconditional_conditioning``.
            use_ema_scope: sample inside ``self.ema_scope`` instead of a
                ``nullcontext``.
            return_keys, quantize_denoised, inpaint, plot_progressive_rows,
            **kwargs: accepted but not used by this implementation.

        Returns:
            dict mapping log names to tensors.
        """
        ema_scope = self.ema_scope if use_ema_scope else nullcontext
        use_ddim = ddim_steps is not None

        log = {}
        z, c, x, xrec, xc = self.get_input(
            batch, self.first_stage_key, bs=N, return_first_stage_outputs=True
        )
        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption", "txt"]:
                # Rendering the text via log_txt_as_img is disabled here; the
                # conditioning returned by get_input is logged as-is.
                log["conditioning"] = xc
            elif self.cond_stage_key in ["class_label", "cls"]:
                # Same as above: no log_txt_as_img rendering of the label.
                log["conditioning"] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
            log["c_concat_decoded"] = self.decode_first_stage(
                c_cat[:, self.c_concat_log_start : self.c_concat_log_end]  # noqa
            )

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = []
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t_batch = repeat(torch.tensor([t]), "1 -> b", b=n_row)
                    t_batch = t_batch.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t_batch, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
            diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with ema_scope("Sampling"):
                samples, z_denoise_row = self.sample_log(
                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                )
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if unconditional_guidance_scale > 1.0:
            uc_cross = self.get_unconditional_conditioning(
                N, unconditional_guidance_label
            )
            # The concat conditioning is reused unchanged for the unconditional branch.
            uc_cat = c_cat
            uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
            with ema_scope("Sampling with classifier-free guidance"):
                samples_cfg, _ = self.sample_log(
                    cond={"c_concat": [c_cat], "c_crossattn": [c]},
                    batch_size=N,
                    ddim=use_ddim,
                    ddim_steps=ddim_steps,
                    eta=ddim_eta,
                    unconditional_guidance_scale=unconditional_guidance_scale,
                    unconditional_conditioning=uc_full,
                )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[
                    f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
                ] = x_samples_cfg

        return log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LatentInpaintDiffusion(LatentDiffusion):
|
|
|
|
|
def __init__( # noqa
|
|
|
|
|
self,
|
|
|
|
@ -1377,3 +1571,152 @@ class LatentInpaintDiffusion(LatentDiffusion):
|
|
|
|
|
if return_first_stage_outputs:
|
|
|
|
|
return z, all_conds, x, xrec, xc
|
|
|
|
|
return z, all_conds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
    """
    condition on monocular depth estimation
    """

    def __init__(self, depth_stage_config, concat_keys=("midas_in",), **kwargs):
        """Build the depth model from ``depth_stage_config``.

        The first entry of ``concat_keys`` is stored as ``depth_stage_key``,
        the batch key that ``log_images`` feeds to the depth model.
        """
        super().__init__(concat_keys=concat_keys, **kwargs)
        self.depth_model = instantiate_from_config(depth_stage_config)
        self.depth_stage_key = concat_keys[0]

    @staticmethod
    def _normalize_depth(depth, eps=0.001):
        """Rescale ``depth`` per sample with its own min/max over dims 1-3.

        ``eps`` is added to the denominator so a constant depth map does not
        divide by zero.
        """
        depth_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
        depth_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
        return 2.0 * (depth - depth_min) / (depth_max - depth_min + eps) - 1.0

    @torch.no_grad()
    def get_input(
        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
    ):
        """Return the latent plus a conditioning dict with the depth map.

        Args:
            batch: mapping holding the first-stage input and the concat key.
            k: not used; ``self.first_stage_key`` is always passed to the parent.
            cond_key: not used.
            bs: optional number of leading batch items to keep.
            return_first_stage_outputs: also return ``x``, ``xrec`` and ``xc``.

        Returns:
            ``(z, all_conds)`` or ``(z, all_conds, x, xrec, xc)`` where
            ``all_conds`` has the keys ``"c_concat"`` and ``"c_crossattn"``.
        """
        # note: restricted to non-trainable encoders currently
        assert (
            not self.cond_stage_trainable
        ), "trainable cond stages not yet supported for depth2img"
        z, c, x, xrec, xc = super().get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=bs,
        )

        assert self.concat_keys is not None
        assert len(self.concat_keys) == 1
        c_cat = []
        for ck in self.concat_keys:
            cc = batch[ck]
            if bs is not None:
                cc = cc[:bs]
            cc = cc.to(self.device)
            cc = self.depth_model(cc)
            # Resize the depth prediction to the spatial size of the latent.
            cc = torch.nn.functional.interpolate(
                cc,
                size=z.shape[2:],
                mode="bicubic",
                align_corners=False,
            )
            c_cat.append(self._normalize_depth(cc))
        c_cat = torch.cat(c_cat, dim=1)
        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, *args, **kwargs):
        """Extend the parent's log with the normalised depth map under ``"depth"``.

        The batch may be given positionally (first argument) or as ``batch=``.
        """
        log = super().log_images(*args, **kwargs)
        batch = args[0] if args else kwargs["batch"]
        # Same device handling and normalisation as get_input.
        depth = self.depth_model(batch[self.depth_stage_key].to(self.device))
        log["depth"] = self._normalize_depth(depth)
        return log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
    """
    condition on low-res image (and optionally on some spatial noise augmentation)
    """

    def __init__(
        self,
        concat_keys=("lr",),
        reshuffle_patch_size=None,
        low_scale_config=None,
        low_scale_key=None,
        **kwargs,
    ):
        """Configure the upscaling finetune model.

        Args:
            concat_keys: batch keys of the concat conditioning (exactly one is
                supported by ``get_input``).
            reshuffle_patch_size: optional int; when set, spatial patches of
                the conditioning are folded into channels in ``get_input``.
            low_scale_config: optional config for a low-scale model that is
                instantiated and frozen.
            low_scale_key: concat key whose tensor is passed through the
                low-scale model; required when ``low_scale_config`` is given.
            **kwargs: forwarded to the parent constructor.
        """
        super().__init__(concat_keys=concat_keys, **kwargs)
        self.reshuffle_patch_size = reshuffle_patch_size
        self.low_scale_model = None
        # Always defined so the attribute exists even without a low-scale model.
        self.low_scale_key = low_scale_key
        if low_scale_config is not None:
            print("Initializing a low-scale model")
            assert low_scale_key is not None
            self.instantiate_low_stage(low_scale_config)

    def instantiate_low_stage(self, config):
        """Instantiate the low-scale model from ``config`` in eval mode.

        ``train`` is replaced by ``disabled_train`` and gradients are switched
        off for all of its parameters.
        """
        model = instantiate_from_config(config)
        self.low_scale_model = model.eval()
        self.low_scale_model.train = disabled_train
        for param in self.low_scale_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def get_input(
        self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False
    ):
        """Return the latent plus a conditioning dict with the low-res input.

        Args:
            batch: mapping holding the first-stage input and the concat key.
            k: not used; ``self.first_stage_key`` is always passed to the parent.
            cond_key: not used.
            bs: optional number of leading batch items to keep.
            return_first_stage_outputs: also return ``x``, ``xrec`` and ``xc``.

        Returns:
            ``(z, all_conds)`` or ``(z, all_conds, x, xrec, xc)``. ``all_conds``
            has ``"c_concat"`` and ``"c_crossattn"`` and, when the low-scale
            model produced a noise level, ``"c_adm"``.
        """
        # note: restricted to non-trainable encoders currently
        assert (
            not self.cond_stage_trainable
        ), "trainable cond stages not yet supported for upscaling-ft"
        z, c, x, xrec, xc = super().get_input(
            batch,
            self.first_stage_key,
            return_first_stage_outputs=True,
            force_c_encode=True,
            return_original_cond=True,
            bs=bs,
        )

        assert self.concat_keys is not None
        assert len(self.concat_keys) == 1
        # optionally make spatial noise_level here
        c_cat = []
        noise_level = None
        for ck in self.concat_keys:
            cc = batch[ck]
            cc = rearrange(cc, "b h w c -> b c h w")
            if self.reshuffle_patch_size is not None:
                assert isinstance(self.reshuffle_patch_size, int)
                # Fold a p1 x p2 grid of spatial patches into the channel axis.
                cc = rearrange(
                    cc,
                    "b c (p1 h) (p2 w) -> b (p1 p2 c) h w",
                    p1=self.reshuffle_patch_size,
                    p2=self.reshuffle_patch_size,
                )
            if bs is not None:
                cc = cc[:bs]
            cc = cc.to(self.device)
            if self.low_scale_model is not None and ck == self.low_scale_key:
                cc, noise_level = self.low_scale_model(cc)
            c_cat.append(cc)
        c_cat = torch.cat(c_cat, dim=1)
        all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
        if noise_level is not None:
            all_conds["c_adm"] = noise_level
        if return_first_stage_outputs:
            return z, all_conds, x, xrec, xc
        return z, all_conds

    @torch.no_grad()
    def log_images(self, *args, **kwargs):
        """Extend the parent's log with the low-res conditioning under ``"lr"``.

        The batch may be given positionally (first argument) or as ``batch=``.
        The tensor is read from the configured concat key (``"lr"`` by
        default) rather than a hard-coded batch key.
        """
        log = super().log_images(*args, **kwargs)
        batch = args[0] if args else kwargs["batch"]
        log["lr"] = rearrange(batch[self.concat_keys[0]], "b h w c -> b c h w")
        return log
|
|
|
|
|