@@ -5,7 +5,9 @@ import numpy as np
 import torch
 from tqdm import tqdm
 
+from imaginairy.img_log import log_latent
 from imaginairy.modules.diffusion.util import (
+    extract_into_tensor,
     make_ddim_sampling_parameters,
     make_ddim_timesteps,
     noise_like,
@@ -172,7 +174,7 @@ class PLMSSampler:
             img = torch.randn(shape, device="cpu").to(device)
         else:
             img = x_T
-
+        log_latent(img, "initial img")
         if timesteps is None:
             timesteps = (
                 self.ddpm_num_timesteps
@@ -217,7 +219,7 @@ class PLMSSampler:
                 ) # TODO: deterministic forward pass?
                 img = img_orig * mask + (1.0 - mask) * img
 
-            outs = self.p_sample_plms(
+            img, pred_x0, e_t = self.p_sample_plms(
                 img,
                 cond,
                 ts,
@@ -233,7 +235,6 @@ class PLMSSampler:
                 old_eps=old_eps,
                 t_next=ts_next,
             )
-            img, pred_x0, e_t = outs
             old_eps.append(e_t)
             if len(old_eps) >= 4:
                 old_eps.pop(0)
@@ -277,7 +278,11 @@ class PLMSSampler:
                 t_in = torch.cat([t] * 2)
                 c_in = torch.cat([unconditional_conditioning, c])
                 e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+                log_latent(e_t_uncond, "noise pred uncond")
+                log_latent(e_t, "noise pred cond")
+
                 e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+                log_latent(e_t, "noise pred combined")
 
             if score_corrector is not None:
                 assert self.model.parameterization == "eps"
@@ -326,6 +331,7 @@ class PLMSSampler:
             return x_prev, pred_x0
 
         e_t = get_model_output(x, t)
         if len(old_eps) == 0:
             # Pseudo Improved Euler (2nd order)
             x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
@@ -344,5 +350,97 @@ class PLMSSampler:
             ) / 24
 
         x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+        log_latent(x_prev, "x_prev")
+        log_latent(pred_x0, "pred_x0")
 
         return x_prev, pred_x0, e_t
+
+    @torch.no_grad()
+    def stochastic_encode(self, init_latent, t, noise=None):
+        # fast, but does not allow for exact reconstruction
+        # t serves as an index to gather the correct alphas
+        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+        sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+        if noise is None:
+            noise = torch.randn_like(init_latent, device="cpu").to(get_device())
+        return (
+            extract_into_tensor(sqrt_alphas_cumprod, t, init_latent.shape) * init_latent
+            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
+            * noise
+        )
+
+    @torch.no_grad()
+    def decode(
+        self,
+        x_latent,
+        cond,
+        t_start,
+        unconditional_guidance_scale=1.0,
+        unconditional_conditioning=None,
+        img_callback=None,
+        score_corrector=None,
+        temperature=1.0,
+        mask=None,
+        orig_latent=None,
+    ):
+
+        timesteps = self.ddim_timesteps[:t_start]
+
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
+
+        iterator = tqdm(time_range, desc="PLMS altering image", total=total_steps)
+        x_dec = x_latent
+        old_eps = []
+
+        for i, step in enumerate(iterator):
+            index = total_steps - i - 1
+            ts = torch.full(
+                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+            )
+            ts_next = torch.full(
+                (x_latent.shape[0],),
+                time_range[min(i + 1, len(time_range) - 1)],
+                device=x_latent.device,
+                dtype=torch.long,
+            )
+
+            if mask is not None:
+                assert orig_latent is not None
+                xdec_orig = self.model.q_sample(orig_latent, ts)
+                log_latent(xdec_orig, "xdec_orig")
+                log_latent(xdec_orig * mask, "masked_xdec_orig")
+                x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
+                log_latent(x_dec, "x_dec")
+
+            x_dec, pred_x0, e_t = self.p_sample_plms(
+                x_dec,
+                cond,
+                ts,
+                index=index,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=unconditional_conditioning,
+                temperature=temperature,
+                old_eps=old_eps,
+                t_next=ts_next,
+            )
+            # original_loss = ((x_dec - x_latent).abs().mean()*70)
+            # sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())
+            # x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
+            # cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
+            # x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
+            ## x_dec_alt = x_dec + (original_loss * 0.1) ** 2
+
+            old_eps.append(e_t)
+            if len(old_eps) >= 4:
+                old_eps.pop(0)
+
+            if img_callback:
+                img_callback(x_dec, "x_dec")
+                img_callback(pred_x0, "pred_x0")
+
+            log_latent(x_dec, f"x_dec {i}")
+            log_latent(pred_x0, f"pred_x0 {i}")
+        return x_dec
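
Note: the new stochastic_encode method above is the closed-form forward-diffusion step
x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, evaluated over the sampler's
DDIM alpha schedule, and decode then walks the PLMS update loop back down from t_start under
the prompt conditioning. The snippet below is a minimal, self-contained sketch of that noising
formula only; the schedule length, latent shape, and timestep are illustrative assumptions,
not values taken from this diff.

import torch

# Illustrative linear-beta schedule; in the sampler the values come from the
# precomputed ddim_alphas / ddim_sqrt_one_minus_alphas buffers.
num_timesteps = 50
betas = torch.linspace(1e-4, 2e-2, num_timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def stochastic_encode_sketch(x0, t, noise=None):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    if noise is None:
        noise = torch.randn_like(x0)
    a_bar_t = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a_bar_t.sqrt() * x0 + (1.0 - a_bar_t).sqrt() * noise

init_latent = torch.randn(1, 4, 64, 64)  # stand-in for an encoded image latent
t = torch.tensor([30])                   # how far into the schedule to noise
noisy_latent = stochastic_encode_sketch(init_latent, t)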