|
|
@ -41,6 +41,7 @@ class PLMSSampler:
|
|
|
|
temperature=1.0,
|
|
|
|
temperature=1.0,
|
|
|
|
noise_dropout=0.0,
|
|
|
|
noise_dropout=0.0,
|
|
|
|
initial_latent=None,
|
|
|
|
initial_latent=None,
|
|
|
|
|
|
|
|
t_start=None,
|
|
|
|
quantize_denoised=False,
|
|
|
|
quantize_denoised=False,
|
|
|
|
**kwargs,
|
|
|
|
**kwargs,
|
|
|
|
):
|
|
|
|
):
|
|
|
@ -61,13 +62,18 @@ class PLMSSampler:
|
|
|
|
|
|
|
|
|
|
|
|
log_latent(initial_latent, "initial latent")
|
|
|
|
log_latent(initial_latent, "initial latent")
|
|
|
|
|
|
|
|
|
|
|
|
timesteps = schedule.ddim_timesteps
|
|
|
|
timesteps = schedule.ddim_timesteps[:t_start]
|
|
|
|
|
|
|
|
|
|
|
|
time_range = np.flip(timesteps)
|
|
|
|
time_range = np.flip(timesteps)
|
|
|
|
total_steps = timesteps.shape[0]
|
|
|
|
total_steps = timesteps.shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
old_eps = []
|
|
|
|
old_eps = []
|
|
|
|
noisy_latent = initial_latent
|
|
|
|
noisy_latent = initial_latent
|
|
|
|
|
|
|
|
mask_noise = None
|
|
|
|
|
|
|
|
if mask is not None:
|
|
|
|
|
|
|
|
mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
|
|
|
|
|
|
|
|
noisy_latent.device
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
for i, step in enumerate(tqdm(time_range, total=total_steps)):
|
|
|
|
for i, step in enumerate(tqdm(time_range, total=total_steps)):
|
|
|
|
index = total_steps - i - 1
|
|
|
|
index = total_steps - i - 1
|
|
|
@ -81,8 +87,18 @@ class PLMSSampler:
|
|
|
|
|
|
|
|
|
|
|
|
if mask is not None:
|
|
|
|
if mask is not None:
|
|
|
|
assert orig_latent is not None
|
|
|
|
assert orig_latent is not None
|
|
|
|
img_orig = self.model.q_sample(orig_latent, ts)
|
|
|
|
xdec_orig = self.model.q_sample(orig_latent, ts, mask_noise)
|
|
|
|
noisy_latent = img_orig * mask + (1.0 - mask) * noisy_latent
|
|
|
|
log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
|
|
|
|
|
|
|
|
# this helps prevent the weird disjointed images that can happen with masking
|
|
|
|
|
|
|
|
hint_strength = 0.8
|
|
|
|
|
|
|
|
if i < 2:
|
|
|
|
|
|
|
|
xdec_orig_with_hints = (
|
|
|
|
|
|
|
|
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
xdec_orig_with_hints = xdec_orig
|
|
|
|
|
|
|
|
noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
|
|
|
|
|
|
|
|
log_latent(noisy_latent, f"x_dec {ts}")
|
|
|
|
|
|
|
|
|
|
|
|
noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
|
|
|
|
noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
|
|
|
|
noisy_latent=noisy_latent,
|
|
|
|
noisy_latent=noisy_latent,
|
|
|
@ -202,7 +218,6 @@ class PLMSSampler:
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
@torch.no_grad()
|
|
|
|
def noise_an_image(self, init_latent, t, schedule, noise=None):
|
|
|
|
def noise_an_image(self, init_latent, t, schedule, noise=None):
|
|
|
|
# replace with ddpm.q_sample?
|
|
|
|
|
|
|
|
# fast, but does not allow for exact reconstruction
|
|
|
|
# fast, but does not allow for exact reconstruction
|
|
|
|
# t serves as an index to gather the correct alphas
|
|
|
|
# t serves as an index to gather the correct alphas
|
|
|
|
t = t.clamp(0, 1000)
|
|
|
|
t = t.clamp(0, 1000)
|
|
|
@ -216,84 +231,3 @@ class PLMSSampler:
|
|
|
|
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
|
|
|
|
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
|
|
|
|
* noise
|
|
|
|
* noise
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def decode(
    self,
    neutral_conditioning,
    positive_conditioning,
    guidance_scale,
    schedule,
    initial_latent=None,
    t_start=None,
    temperature=1.0,
    mask=None,
    orig_latent=None,
    noise=None,
):
    """Step ``initial_latent`` through the PLMS sampler and return the result.

    The first ``t_start`` entries of ``schedule.ddim_timesteps`` are visited
    in reverse order; each iteration calls ``self.p_sample_plms`` once and
    feeds its first return value into the next iteration.

    Args:
        neutral_conditioning: forwarded unchanged to ``p_sample_plms``.
        positive_conditioning: forwarded unchanged to ``p_sample_plms``.
        guidance_scale: forwarded unchanged to ``p_sample_plms``.
        schedule: object exposing ``ddim_timesteps`` (sliced, flipped with
            ``np.flip`` and read via ``.shape[0]``); also forwarded to
            ``p_sample_plms``.
        initial_latent: starting latent. Although the default is ``None``,
            its ``.shape`` and ``.device`` are read on every iteration, so a
            tensor is required whenever at least one step is run.
        t_start: number of leading schedule timesteps to use; ``None`` uses
            all of them (plain slice semantics).
        temperature: forwarded unchanged to ``p_sample_plms``.
        mask: optional blend weights. Where given, each step replaces
            ``mask``-weighted regions of the working latent with a noised
            copy of ``orig_latent`` before sampling.
        orig_latent: latent blended in under ``mask``; must not be ``None``
            when ``mask`` is provided.
        noise: optional noise tensor passed to ``self.model.q_sample`` when
            masking. If ``None``, one is drawn once and reused every step.

    Returns:
        The latent produced by the last ``p_sample_plms`` call, or
        ``initial_latent`` unchanged if there were no steps to run.

    Raises:
        AssertionError: if ``mask`` is given without ``orig_latent``.
    """
    timesteps = schedule.ddim_timesteps[:t_start]

    time_range = np.flip(timesteps)
    total_steps = timesteps.shape[0]

    x_dec = initial_latent
    # Recent noise predictions handed to p_sample_plms; trimmed below so that
    # at most three are retained between iterations.
    old_eps = []
    log_latent(x_dec, "x_dec")

    # not sure what the downside of using the same noise throughout the process would be...
    # seems to work fine. maybe it runs faster?
    # The noise is drawn on the CPU and then moved to the latent's device.
    noise = (
        torch.randn_like(x_dec, device="cpu").to(x_dec.device)
        if noise is None
        else noise
    )

    # Weight given to the un-noised original latent during the first two
    # masked steps (loop-invariant, so defined once here).
    hint_strength = 0.8

    for i, step in enumerate(tqdm(time_range, total=total_steps)):
        index = total_steps - i - 1

        # Both timestep tensors are built on the latent's device. Previously
        # `ts_next` used `self.device` while `ts` used `initial_latent.device`,
        # which could hand p_sample_plms tensors on two different devices.
        batch_shape = (initial_latent.shape[0],)
        latent_device = initial_latent.device
        ts = torch.full(
            batch_shape,
            step,
            device=latent_device,
            dtype=torch.long,
        )
        # On the final iteration the "next" timestep is clamped to the last
        # entry of time_range, i.e. it equals the current step.
        ts_next = torch.full(
            batch_shape,
            time_range[min(i + 1, len(time_range) - 1)],
            device=latent_device,
            dtype=torch.long,
        )

        if mask is not None:
            assert orig_latent is not None
            xdec_orig = self.model.q_sample(orig_latent, ts, noise)
            log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
            # this helps prevent the weird disjointed images that can happen with masking
            if i < 2:
                xdec_orig_with_hints = (
                    xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
                )
            else:
                xdec_orig_with_hints = xdec_orig
            # mask-weighted regions come from the (hinted) noised original,
            # the remainder keeps the latent being sampled.
            x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
            log_latent(x_dec, f"x_dec {ts}")

        x_dec, pred_x0, noise_prediction = self.p_sample_plms(
            noisy_latent=x_dec,
            guidance_scale=guidance_scale,
            neutral_conditioning=neutral_conditioning,
            positive_conditioning=positive_conditioning,
            time_encoding=ts,
            schedule=schedule,
            index=index,
            temperature=temperature,
            old_eps=old_eps,
            t_next=ts_next,
        )

        old_eps.append(noise_prediction)
        if len(old_eps) >= 4:
            old_eps.pop(0)

        log_latent(x_dec, f"x_dec {i}")
        log_latent(pred_x0, f"pred_x0 {i}")
    return x_dec
|
|
|
|
|
|
|
|