|
|
|
@ -11,7 +11,7 @@ from functools import partial
|
|
|
|
|
import numpy as np
|
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from torch import nn
|
|
|
|
|
from einops import rearrange
|
|
|
|
|
from torchvision.utils import make_grid
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
@ -28,9 +28,11 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def disabled_train(self, mode=True):
    """
    Overwrite model.train with this function to make sure train/eval mode
    does not change anymore.

    Args:
        self: The object the call is made on; it is handed back untouched.
        mode: Accepted and ignored. It is kept (with a default) so this
            stand-in can be called the same way as the ``train`` method it
            replaces, e.g. ``train()``, ``train(False)`` or
            ``train(mode=False)``, without raising ``TypeError``.

    Returns:
        ``self``, unchanged.
    """
    # Deliberately a no-op: nothing is switched, the caller just gets the
    # same object back.
    return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -265,6 +267,7 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
self.instantiate_first_stage(first_stage_config)
|
|
|
|
|
self.instantiate_cond_stage(cond_stage_config)
|
|
|
|
|
self.cond_stage_forward = cond_stage_forward
|
|
|
|
|
self.cond_ids = None
|
|
|
|
|
self.clip_denoised = False
|
|
|
|
|
self.bbox_tokenizer = None
|
|
|
|
|
|
|
|
|
@ -346,13 +349,13 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
self.cond_stage_model = model
|
|
|
|
|
|
|
|
|
|
def _get_denoise_row_from_list(
|
|
|
|
|
self, samples, desc="", force_no_decoder_quantization=False
|
|
|
|
|
self, samples, desc=""
|
|
|
|
|
):
|
|
|
|
|
denoise_row = []
|
|
|
|
|
for zd in tqdm(samples, desc=desc):
|
|
|
|
|
denoise_row.append(
|
|
|
|
|
self.decode_first_stage(
|
|
|
|
|
zd.to(self.device), force_not_quantize=force_no_decoder_quantization
|
|
|
|
|
zd.to(self.device)
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
n_imgs_per_row = len(denoise_row)
|
|
|
|
@ -547,7 +550,7 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
else:
|
|
|
|
|
xc = x
|
|
|
|
|
if not self.cond_stage_trainable or force_c_encode:
|
|
|
|
|
if isinstance(xc, dict) or isinstance(xc, list):
|
|
|
|
|
if isinstance(xc, (dict, list)):
|
|
|
|
|
# import pudb; pudb.set_trace()
|
|
|
|
|
c = self.get_learned_conditioning(xc)
|
|
|
|
|
else:
|
|
|
|
@ -577,7 +580,7 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
|
|
|
|
def decode_first_stage(self, z, predict_cids=False):
|
|
|
|
|
if predict_cids:
|
|
|
|
|
if z.dim() == 4:
|
|
|
|
|
z = torch.argmax(z.exp(), dim=1).long()
|
|
|
|
|