# pylama:ignore=W0613
import logging

import numpy as np
import torch
from tqdm import tqdm

from imaginairy.log_utils import increment_step, log_latent
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import (
    ImageSampler,
    NoiseSchedule,
    SamplerName,
    get_noise_prediction,
    mask_blend,
)
from imaginairy.utils import get_device

logger = logging.getLogger(__name__)


class PLMSSampler(ImageSampler):
    """
    Pseudo Linear Multi-Step (PLMS) sampler.

    Provenance:
    https://github.com/CompVis/latent-diffusion/commit/f0c4e092c156986e125f48c61a0edd38ba8ad059
    https://arxiv.org/abs/2202.09778
    https://github.com/luping-liu/PNDM
    """

    short_name = SamplerName.PLMS
    name = "probabilistic least-mean-squares sampler"
    default_steps = 40
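
    # Hedged usage sketch (assumed API, not taken from this file): samplers
    # here are built around a loaded diffusion model and queried for a
    # denoised latent. The constructor signature and variable names below are
    # illustrative assumptions, not the confirmed interface.
    #
    #     sampler = PLMSSampler(model)
    #     latent = sampler.sample(
    #         num_steps=40,
    #         shape=(1, 4, 64, 64),
    #         neutral_conditioning=neutral_cond,
    #         positive_conditioning=positive_cond,
    #         guidance_scale=7.5,
    #     )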

    @torch.no_grad()
    def sample(
        self,
        num_steps,
        shape,
        neutral_conditioning,
        positive_conditioning,
        guidance_scale=1.0,
        batch_size=1,
        mask=None,
        orig_latent=None,
        temperature=1.0,
        noise_dropout=0.0,
        initial_latent=None,
        t_start=None,
        quantize_denoised=False,
        **kwargs,
    ):
        # if positive_conditioning.shape[0] != batch_size:
        #     raise ValueError(
        #         f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
        #     )
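
        # Build a DDIM-style noise schedule that maps the model's full
        # timestep range onto num_steps uniformly spaced sampling steps.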
|
2022-09-08 03:59:30 +00:00
|
|
|
|
2022-10-13 06:45:08 +00:00
|
|
|
schedule = NoiseSchedule(
|
2022-10-13 05:32:17 +00:00
|
|
|
model_num_timesteps=self.model.num_timesteps,
|
2022-10-06 04:43:00 +00:00
|
|
|
ddim_num_steps=num_steps,
|
2022-10-13 05:32:17 +00:00
|
|
|
model_alphas_cumprod=self.model.alphas_cumprod,
|
2022-10-06 04:43:00 +00:00
|
|
|
ddim_discretize="uniform",
|
2022-09-08 03:59:30 +00:00
|
|
|
)

        if initial_latent is None:
            initial_latent = torch.randn(shape, device="cpu").to(self.device)

        log_latent(initial_latent, "initial latent")

        timesteps = schedule.ddim_timesteps[:t_start]
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]

        old_eps = []
        noisy_latent = initial_latent
        mask_noise = None
        if mask is not None:
            mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
                noisy_latent.device
            )
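
        # Walk the schedule from the noisiest timestep down to zero; old_eps
        # accumulates recent noise predictions for the multistep update.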
        for i, step in enumerate(tqdm(time_range, total=total_steps)):
            index = total_steps - i - 1
            ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long)
            ts_next = torch.full(
                (batch_size,),
                time_range[min(i + 1, len(time_range) - 1)],
                device=self.device,
                dtype=torch.long,
            )

            if mask is not None:
                noisy_latent = mask_blend(
                    noisy_latent=noisy_latent,
                    orig_latent=orig_latent,
                    mask=mask,
                    mask_noise=mask_noise,
                    ts=ts,
                    model=self.model,
                )

            noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
                noisy_latent=noisy_latent,
                neutral_conditioning=neutral_conditioning,
                positive_conditioning=positive_conditioning,
                guidance_scale=guidance_scale,
                time_encoding=ts,
                schedule=schedule,
                index=index,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                old_eps=old_eps,
                t_next=ts_next,
            )
            old_eps.append(noise_pred)
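            # the 4th-order update never looks further back than old_eps[-3],
            # so at most the 3 most recent predictions are kept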
            if len(old_eps) >= 4:
                old_eps.pop(0)

            log_latent(noisy_latent, "noisy_latent")
            log_latent(predicted_latent, "predicted_latent")
            increment_step()

        return noisy_latent

    @torch.no_grad()
    def p_sample_plms(
        self,
        noisy_latent,
        neutral_conditioning,
        positive_conditioning,
        guidance_scale,
        time_encoding,
        schedule: NoiseSchedule,
        index,
        repeat_noise=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        old_eps=None,
        t_next=None,
    ):
        assert guidance_scale >= 1
        noise_pred = get_noise_prediction(
            denoise_func=self.model.apply_model,
            noisy_latent=noisy_latent,
            time_encoding=time_encoding,
            neutral_conditioning=neutral_conditioning,
            positive_conditioning=positive_conditioning,
            signal_amplification=guidance_scale,
        )
        batch_size = noisy_latent.shape[0]

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            alpha_at_t = torch.full(
                (batch_size, 1, 1, 1), schedule.ddim_alphas[index], device=self.device
            )
            alpha_prev_at_t = torch.full(
                (batch_size, 1, 1, 1),
                schedule.ddim_alphas_prev[index],
                device=self.device,
            )
            sigma_t = torch.full(
                (batch_size, 1, 1, 1), schedule.ddim_sigmas[index], device=self.device
            )
            sqrt_one_minus_at = torch.full(
                (batch_size, 1, 1, 1),
                schedule.ddim_sqrt_one_minus_alphas[index],
                device=self.device,
            )

            # current prediction for x_0
            pred_x0 = (noisy_latent - sqrt_one_minus_at * e_t) / alpha_at_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1.0 - alpha_prev_at_t - sigma_t**2).sqrt() * e_t
            noise = (
                sigma_t
                * noise_like(noisy_latent.shape, self.device, repeat_noise)
                * temperature
            )
            if noise_dropout > 0.0:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = alpha_prev_at_t.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0
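
        # Choose the multistep noise estimate e_t_prime: with no history a
        # pseudo improved-Euler step bootstraps it; with 1-3 stored
        # predictions the 2nd/3rd/4th-order Adams-Bashforth formulas apply.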
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(noise_pred, index)
            e_t_next = get_noise_prediction(
                denoise_func=self.model.apply_model,
                noisy_latent=x_prev,
                time_encoding=t_next,
                neutral_conditioning=neutral_conditioning,
                positive_conditioning=positive_conditioning,
                signal_amplification=guidance_scale,
            )
            e_t_prime = (noise_pred + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * noise_pred - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * noise_pred - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (
                55 * noise_pred - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
            ) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
        log_latent(x_prev, "x_prev")
        log_latent(pred_x0, "pred_x0")

        return x_prev, pred_x0, noise_pred

    @torch.no_grad()
    def noise_an_image(self, init_latent, t, schedule, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        t = t.clamp(0, 1000)
        sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
        sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(init_latent, device="cpu").to(get_device())
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, init_latent.shape) * init_latent
            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
            * noise
        )
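

# A minimal, self-contained sanity check of the multistep weights used above.
# Illustrative only and never called by the sampler; the function name is our
# own. A 4th-order Adams-Bashforth rule reproduces the integral of any cubic
# history exactly, which pins down the (55, -59, 37, -9) / 24 coefficients.
def _adams_bashforth4_sanity_check():
    weights = np.array([55.0, -59.0, 37.0, -9.0]) / 24.0
    nodes = np.array([0.0, -1.0, -2.0, -3.0])  # current step and 3 prior steps
    for j in range(4):
        # integrating t**j over one unit step should give 1 / (j + 1)
        assert np.isclose(weights @ nodes**j, 1.0 / (j + 1))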