|
|
|
@ -5,14 +5,14 @@ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bb
|
|
|
|
|
https://github.com/CompVis/taming-transformers
|
|
|
|
|
-- merci
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from einops import rearrange, repeat
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
from torchvision.utils import make_grid
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
@ -21,12 +21,12 @@ from imaginairy.models.autoencoder import (
|
|
|
|
|
)
|
|
|
|
|
from imaginairy.modules.diffusionmodules.util import (
|
|
|
|
|
make_beta_schedule,
|
|
|
|
|
extract_into_tensor,
|
|
|
|
|
noise_like,
|
|
|
|
|
)
|
|
|
|
|
from imaginairy.modules.distributions import DiagonalGaussianDistribution
|
|
|
|
|
from imaginairy.utils import print_params, instantiate_from_config
|
|
|
|
|
from imaginairy.utils import log_params, instantiate_from_config
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -77,7 +77,7 @@ class DDPM(pl.LightningModule):
|
|
|
|
|
"x0",
|
|
|
|
|
], 'currently only supporting "eps" and "x0"'
|
|
|
|
|
self.parameterization = parameterization
|
|
|
|
|
print(
|
|
|
|
|
logger.info(
|
|
|
|
|
f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
|
|
|
|
|
)
|
|
|
|
|
self.cond_stage_model = None
|
|
|
|
@ -88,7 +88,7 @@ class DDPM(pl.LightningModule):
|
|
|
|
|
self.channels = channels
|
|
|
|
|
self.use_positional_encodings = use_positional_encodings
|
|
|
|
|
self.model = DiffusionWrapper(unet_config, conditioning_key)
|
|
|
|
|
print_params(self.model)
|
|
|
|
|
log_params(self.model)
|
|
|
|
|
|
|
|
|
|
self.use_scheduler = scheduler_config is not None
|
|
|
|
|
if self.use_scheduler:
|
|
|
|
@ -309,10 +309,12 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
def instantiate_cond_stage(self, config):
|
|
|
|
|
if not self.cond_stage_trainable:
|
|
|
|
|
if config == "__is_first_stage__":
|
|
|
|
|
print("Using first stage also as cond stage.")
|
|
|
|
|
logger.info("Using first stage also as cond stage.")
|
|
|
|
|
self.cond_stage_model = self.first_stage_model
|
|
|
|
|
elif config == "__is_unconditional__":
|
|
|
|
|
print(f"Training {self.__class__.__name__} as an unconditional model.")
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Training {self.__class__.__name__} as an unconditional model."
|
|
|
|
|
)
|
|
|
|
|
self.cond_stage_model = None
|
|
|
|
|
# self.be_unconditional = True
|
|
|
|
|
else:
|
|
|
|
@ -576,11 +578,11 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
bs, nc, h, w = z.shape
|
|
|
|
|
if ks[0] > h or ks[1] > w:
|
|
|
|
|
ks = (min(ks[0], h), min(ks[1], w))
|
|
|
|
|
print("reducing Kernel")
|
|
|
|
|
logger.info("reducing Kernel")
|
|
|
|
|
|
|
|
|
|
if stride[0] > h or stride[1] > w:
|
|
|
|
|
stride = (min(stride[0], h), min(stride[1], w))
|
|
|
|
|
print("reducing stride")
|
|
|
|
|
logger.info("reducing stride")
|
|
|
|
|
|
|
|
|
|
fold, unfold, normalization, weighting = self.get_fold_unfold(
|
|
|
|
|
z, ks, stride, uf=uf
|
|
|
|
@ -643,11 +645,11 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
bs, nc, h, w = x.shape
|
|
|
|
|
if ks[0] > h or ks[1] > w:
|
|
|
|
|
ks = (min(ks[0], h), min(ks[1], w))
|
|
|
|
|
print("reducing Kernel")
|
|
|
|
|
logger.info("reducing Kernel")
|
|
|
|
|
|
|
|
|
|
if stride[0] > h or stride[1] > w:
|
|
|
|
|
stride = (min(stride[0], h), min(stride[1], w))
|
|
|
|
|
print("reducing stride")
|
|
|
|
|
logger.info("reducing stride")
|
|
|
|
|
|
|
|
|
|
fold, unfold, normalization, weighting = self.get_fold_unfold(
|
|
|
|
|
x, ks, stride, df=df
|
|
|
|
@ -774,23 +776,21 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
)
|
|
|
|
|
for bbox in patch_limits
|
|
|
|
|
] # list of length l with tensors of shape (1, 2)
|
|
|
|
|
print(patch_limits_tknzd[0].shape)
|
|
|
|
|
|
|
|
|
|
# cut tknzd crop position from conditioning
|
|
|
|
|
assert isinstance(cond, dict), "cond must be dict to be fed into model"
|
|
|
|
|
cut_cond = cond["c_crossattn"][0][..., :-2].to(self.device)
|
|
|
|
|
print(cut_cond.shape)
|
|
|
|
|
|
|
|
|
|
adapted_cond = torch.stack(
|
|
|
|
|
[torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]
|
|
|
|
|
)
|
|
|
|
|
adapted_cond = rearrange(adapted_cond, "l b n -> (l b) n")
|
|
|
|
|
print(adapted_cond.shape)
|
|
|
|
|
|
|
|
|
|
adapted_cond = self.get_learned_conditioning(adapted_cond)
|
|
|
|
|
print(adapted_cond.shape)
|
|
|
|
|
|
|
|
|
|
adapted_cond = rearrange(
|
|
|
|
|
adapted_cond, "(l b) n d -> l b n d", l=z.shape[-1]
|
|
|
|
|
)
|
|
|
|
|
print(adapted_cond.shape)
|
|
|
|
|
|
|
|
|
|
cond_list = [{"c_crossattn": [e]} for e in adapted_cond]
|
|
|
|
|
|
|
|
|
|