# imaginairy/samplers/plms.py
# pylama:ignore=W0613
import logging

import numpy as np
import torch
from tqdm import tqdm

from imaginairy.log_utils import increment_step, log_latent
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import (
    ImageSampler,
    NoiseSchedule,
    SamplerName,
    get_noise_prediction,
    mask_blend,
)
from imaginairy.utils import get_device

logger = logging.getLogger(__name__)
2022-09-08 03:59:30 +00:00
class PLMSSampler(ImageSampler):
    """
    Probabilistic least-mean-squares sampler (PLMS, a.k.a. PNDM).

    Uses pseudo linear multistep (Adams-Bashforth style) updates over a DDIM
    noise schedule so fewer denoising steps are needed for a given quality.

    Provenance:
    https://github.com/CompVis/latent-diffusion/commit/f0c4e092c156986e125f48c61a0edd38ba8ad059
    https://arxiv.org/abs/2202.09778
    https://github.com/luping-liu/PNDM
    """

    short_name = SamplerName.PLMS
    name = "probabilistic least-mean-squares sampler"
    default_steps = 40

    @torch.no_grad()
    def sample(
        self,
        num_steps,
        shape,
        neutral_conditioning,
        positive_conditioning,
        guidance_scale=1.0,
        batch_size=1,
        mask=None,
        orig_latent=None,
        temperature=1.0,
        noise_dropout=0.0,
        initial_latent=None,
        t_start=None,
        quantize_denoised=False,
        **kwargs,
    ):
        """
        Run the full PLMS denoising loop and return the final latent.

        Args:
            num_steps: number of DDIM timesteps in the schedule.
            shape: shape of the latent to generate when `initial_latent` is None.
            neutral_conditioning: unconditional ("negative") conditioning tensor.
            positive_conditioning: conditioning for the desired output.
            guidance_scale: classifier-free guidance strength (>= 1).
            batch_size: number of samples denoised in parallel.
            mask: optional blend mask for inpainting; requires `orig_latent`.
            orig_latent: source latent that masked regions are blended from.
            temperature: multiplier on the stochastic noise added each step.
            noise_dropout: dropout probability applied to that noise.
            initial_latent: starting latent; random noise is used when None.
            t_start: optional index truncating the schedule (for img2img-style
                starts); None uses every timestep.
            quantize_denoised: pass-through to the first-stage quantizer.

        Returns:
            The denoised latent tensor after all steps.
        """
        schedule = NoiseSchedule(
            model_num_timesteps=self.model.num_timesteps,
            ddim_num_steps=num_steps,
            model_alphas_cumprod=self.model.alphas_cumprod,
            ddim_discretize="uniform",
        )

        if initial_latent is None:
            # Noise is generated on the CPU so results are reproducible
            # across devices, then moved to the sampling device.
            initial_latent = torch.randn(shape, device="cpu").to(self.device)

        log_latent(initial_latent, "initial latent")

        timesteps = schedule.ddim_timesteps[:t_start]
        # Sampling walks the schedule from the noisiest timestep down to 0.
        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]

        old_eps = []  # recent noise predictions consumed by the multistep update
        noisy_latent = initial_latent
        mask_noise = None
        if mask is not None:
            # Fixed noise reused every step when re-noising the masked-out
            # (original) region, again seeded on CPU for reproducibility.
            mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
                noisy_latent.device
            )

        for i, step in enumerate(tqdm(time_range, total=total_steps)):
            index = total_steps - i - 1
            ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long)
            # The lookahead timestep; clamped at the final entry so the last
            # iteration reuses the last timestep.
            ts_next = torch.full(
                (batch_size,),
                time_range[min(i + 1, len(time_range) - 1)],
                device=self.device,
                dtype=torch.long,
            )

            if mask is not None:
                # Keep unmasked regions pinned to a re-noised copy of the
                # original latent so only the masked area is generated.
                noisy_latent = mask_blend(
                    noisy_latent=noisy_latent,
                    orig_latent=orig_latent,
                    mask=mask,
                    mask_noise=mask_noise,
                    ts=ts,
                    model=self.model,
                )

            noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
                noisy_latent=noisy_latent,
                neutral_conditioning=neutral_conditioning,
                positive_conditioning=positive_conditioning,
                guidance_scale=guidance_scale,
                time_encoding=ts,
                schedule=schedule,
                index=index,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                old_eps=old_eps,
                t_next=ts_next,
            )

            # Keep only the 3 most recent predictions: the 4th-order
            # multistep update needs old_eps[-1..-3] plus the current one.
            old_eps.append(noise_pred)
            if len(old_eps) >= 4:
                old_eps.pop(0)

            log_latent(noisy_latent, "noisy_latent")
            log_latent(predicted_latent, "predicted_latent")
            increment_step()

        return noisy_latent

    @torch.no_grad()
    def p_sample_plms(
        self,
        noisy_latent,
        neutral_conditioning,
        positive_conditioning,
        guidance_scale,
        time_encoding,
        schedule: NoiseSchedule,
        index,
        repeat_noise=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
        old_eps=None,
        t_next=None,
    ):
        """
        Perform one PLMS step.

        Combines the current noise prediction with up to three previous ones
        (`old_eps`) using Adams-Bashforth coefficients of increasing order,
        then takes a DDIM-style update with the combined prediction.

        Returns:
            (x_prev, pred_x0, noise_pred): the latent for the next (less
            noisy) timestep, the current estimate of the fully denoised
            latent, and the raw noise prediction for this step (appended to
            `old_eps` by the caller).
        """
        assert guidance_scale >= 1
        noise_pred = get_noise_prediction(
            denoise_func=self.model.apply_model,
            noisy_latent=noisy_latent,
            time_encoding=time_encoding,
            neutral_conditioning=neutral_conditioning,
            positive_conditioning=positive_conditioning,
            signal_amplification=guidance_scale,
        )

        batch_size = noisy_latent.shape[0]

        def get_x_prev_and_pred_x0(e_t, index):
            """DDIM update: derive x_{t-1} and the x_0 estimate from e_t."""
            # Select the schedule parameters for the current timestep,
            # broadcastable over (batch, channels, height, width).
            alpha_at_t = torch.full(
                (batch_size, 1, 1, 1), schedule.ddim_alphas[index], device=self.device
            )
            alpha_prev_at_t = torch.full(
                (batch_size, 1, 1, 1),
                schedule.ddim_alphas_prev[index],
                device=self.device,
            )
            sigma_t = torch.full(
                (batch_size, 1, 1, 1), schedule.ddim_sigmas[index], device=self.device
            )
            sqrt_one_minus_at = torch.full(
                (batch_size, 1, 1, 1),
                schedule.ddim_sqrt_one_minus_alphas[index],
                device=self.device,
            )

            # Current prediction for x_0.
            pred_x0 = (noisy_latent - sqrt_one_minus_at * e_t) / alpha_at_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

            # Direction pointing back toward x_t.
            dir_xt = (1.0 - alpha_prev_at_t - sigma_t**2).sqrt() * e_t
            noise = (
                sigma_t
                * noise_like(noisy_latent.shape, self.device, repeat_noise)
                * temperature
            )
            if noise_dropout > 0.0:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)

            x_prev = alpha_prev_at_t.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        # Combine predictions; order grows as history accumulates so the
        # first few steps stay stable.
        if len(old_eps) == 0:
            # Pseudo improved Euler (2nd order): average the prediction here
            # with a lookahead prediction at t_next.
            x_prev, pred_x0 = get_x_prev_and_pred_x0(noise_pred, index)
            e_t_next = get_noise_prediction(
                denoise_func=self.model.apply_model,
                noisy_latent=x_prev,
                time_encoding=t_next,
                neutral_conditioning=neutral_conditioning,
                positive_conditioning=positive_conditioning,
                signal_amplification=guidance_scale,
            )
            e_t_prime = (noise_pred + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order pseudo linear multistep (Adams-Bashforth).
            e_t_prime = (3 * noise_pred - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order pseudo linear multistep (Adams-Bashforth).
            e_t_prime = (23 * noise_pred - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order pseudo linear multistep (Adams-Bashforth).
            e_t_prime = (
                55 * noise_pred - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
            ) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
        log_latent(x_prev, "x_prev")
        log_latent(pred_x0, "pred_x0")

        return x_prev, pred_x0, noise_pred

    @torch.no_grad()
    def noise_an_image(self, init_latent, t, schedule, noise=None):
        """
        Add schedule-appropriate noise to `init_latent` for timestep `t`.

        Fast forward-diffusion used for img2img starts; it does not allow
        exact reconstruction of the original latent. `t` serves as an index
        to gather the correct alphas from the schedule.
        """
        # NOTE(review): 1000 presumably matches the model's training timestep
        # count — confirm against self.model.num_timesteps.
        t = t.clamp(0, 1000)
        sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
        sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas
        if noise is None:
            # CPU-seeded noise for reproducibility, then moved to the device.
            noise = torch.randn_like(init_latent, device="cpu").to(get_device())
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, init_latent.shape) * init_latent
            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
            * noise
        )