|
|
|
@ -631,49 +631,48 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def encode_first_stage(self, x):
    """Encode ``x`` with the first-stage model, optionally patch by patch.

    If the instance has a ``split_input_params`` mapping whose
    ``"patch_distributed_vq"`` entry is truthy, the input is cut into
    overlapping crops, every crop is encoded separately by
    ``self.first_stage_model.encode``, and the encoded crops are weighted,
    stitched back together and normalized.  Otherwise the whole input is
    passed to ``self.first_stage_model.encode`` in one call.

    Args:
        x: Input tensor.  The patch-wise path unpacks ``x.shape`` into four
            values, so it must be 4-dimensional there (presumably
            ``(batch, channels, height, width)`` -- confirm against callers).

    Returns:
        Whatever ``self.first_stage_model.encode`` returns on the direct
        path, or the stitched and normalized tensor on the patch-wise path.

    Side effects:
        On the patch-wise path ``self.split_input_params["original_image_size"]``
        is overwritten with ``x.shape[-2:]``.
    """
    if (
        hasattr(self, "split_input_params")
        and self.split_input_params["patch_distributed_vq"]
    ):
        ks = self.split_input_params["ks"]  # eg. (128, 128)
        stride = self.split_input_params["stride"]  # eg. (64, 64)
        df = self.split_input_params["vqf"]
        # Remember the spatial size of the input on the shared params dict.
        self.split_input_params["original_image_size"] = x.shape[-2:]

        # Only the spatial extents are needed; the 4-way unpack is kept so a
        # non-4D input still fails here, as it did before.
        _, _, h, w = x.shape

        # Kernel and stride may not exceed the input's spatial size.
        if ks[0] > h or ks[1] > w:
            ks = (min(ks[0], h), min(ks[1], w))
            logger.info("reducing Kernel")

        if stride[0] > h or stride[1] > w:
            stride = (min(stride[0], h), min(stride[1], w))
            logger.info("reducing stride")

        fold, unfold, normalization, weighting = self.get_fold_unfold(
            x, ks, stride, df=df
        )

        z = unfold(x)  # (bn, nc * prod(**ks), L)
        # Reshape to img shape
        z = z.view(
            (z.shape[0], -1, ks[0], ks[1], z.shape[-1])
        )  # (bn, nc, ks[0], ks[1], L )

        # Encode each crop on its own; the last axis indexes the crops.
        output_list = [
            self.first_stage_model.encode(z[:, :, :, :, i])
            for i in range(z.shape[-1])
        ]

        o = torch.stack(output_list, dim=-1)
        o = o * weighting

        # Reverse reshape to img shape
        o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
        # stitch crops together
        decoded = fold(o)
        # Divide by the normalization returned by get_fold_unfold
        # (presumably compensates for overlapping crops -- confirm there).
        decoded = decoded / normalization
        return decoded

    return self.first_stage_model.encode(x)
|
|
|
|
|
|
|
|
|
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
|
|
|
|
|
|
|
|
@ -814,8 +813,8 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
|
|
|
|
|
if isinstance(x_recon, tuple) and not return_ids:
|
|
|
|
|
return x_recon[0]
|
|
|
|
|
else:
|
|
|
|
|
return x_recon
|
|
|
|
|
|
|
|
|
|
return x_recon
|
|
|
|
|
|
|
|
|
|
def p_mean_variance(
|
|
|
|
|
self,
|
|
|
|
@ -851,16 +850,16 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
if clip_denoised:
|
|
|
|
|
x_recon.clamp_(-1.0, 1.0)
|
|
|
|
|
if quantize_denoised:
|
|
|
|
|
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
|
|
|
|
|
x_recon, _, _ = self.first_stage_model.quantize(x_recon)
|
|
|
|
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
|
|
|
|
|
x_start=x_recon, x_t=x, t=t
|
|
|
|
|
)
|
|
|
|
|
if return_codebook_ids:
|
|
|
|
|
return model_mean, posterior_variance, posterior_log_variance, logits
|
|
|
|
|
elif return_x0:
|
|
|
|
|
if return_x0:
|
|
|
|
|
return model_mean, posterior_variance, posterior_log_variance, x_recon
|
|
|
|
|
else:
|
|
|
|
|
return model_mean, posterior_variance, posterior_log_variance
|
|
|
|
|
|
|
|
|
|
return model_mean, posterior_variance, posterior_log_variance
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def p_sample(
|
|
|
|
@ -890,10 +889,7 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
score_corrector=score_corrector,
|
|
|
|
|
corrector_kwargs=corrector_kwargs,
|
|
|
|
|
)
|
|
|
|
|
if return_codebook_ids:
|
|
|
|
|
raise DeprecationWarning("Support dropped.")
|
|
|
|
|
model_mean, _, model_log_variance, logits = outputs
|
|
|
|
|
elif return_x0:
|
|
|
|
|
if return_x0:
|
|
|
|
|
model_mean, _, model_log_variance, x0 = outputs
|
|
|
|
|
else:
|
|
|
|
|
model_mean, _, model_log_variance = outputs
|
|
|
|
@ -904,17 +900,13 @@ class LatentDiffusion(DDPM):
|
|
|
|
|
# no noise when t == 0
|
|
|
|
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
|
|
|
|
|
|
|
|
|
if return_codebook_ids:
|
|
|
|
|
return model_mean + nonzero_mask * (
|
|
|
|
|
0.5 * model_log_variance
|
|
|
|
|
).exp() * noise, logits.argmax(dim=1)
|
|
|
|
|
if return_x0:
|
|
|
|
|
return (
|
|
|
|
|
model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise,
|
|
|
|
|
x0,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
|
|
|
|
|
|
|
|
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DiffusionWrapper(pl.LightningModule):
|
|
|
|
|