refactor: begin to standardize samplers

pull/51/head
Bryce 2 years ago committed by Bryce Drennan
parent 62e4e9cc9d
commit 9ba302a5f4

@ -26,6 +26,7 @@ from imaginairy.img_log import (
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
from imaginairy.safety import is_nsfw
from imaginairy.samplers.base import get_sampler
from imaginairy.samplers.plms import PLMSSchedule
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
fix_torch_group_norm,
@ -208,6 +209,7 @@ def imagine(
log_conditioning(c, "positive conditioning")
shape = [
1,
latent_channels,
prompt.height // downsampling_factor,
prompt.width // downsampling_factor,
@ -228,9 +230,6 @@ def imagine(
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
t_enc = int(prompt.steps * generation_strength)
sampler.make_schedule(
ddim_num_steps=prompt.steps, ddim_eta=ddim_eta
)
try:
init_image = pillow_fit_image_within(
prompt.init_image,
@ -284,24 +283,35 @@ def imagine(
# encode (scaled latent)
seed_everything(prompt.seed)
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
schedule = PLMSSchedule(
ddpm_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps,
alphas_cumprod=model.alphas_cumprod,
alphas_cumprod_prev=model.alphas_cumprod_prev,
betas=model.betas,
ddim_discretize="uniform",
ddim_eta=0.0,
)
if generation_strength >= 1:
# prompt strength gets converted to time encodings,
# which means you can't get to true 0 without this hack
# (or setting steps=1000)
z_enc = noise
else:
z_enc = sampler.stochastic_encode(
z_enc = sampler.noise_an_image(
init_latent,
torch.tensor([t_enc - 1]).to(get_device()),
schedule=schedule,
noise=noise,
)
log_latent(z_enc, "z_enc")
# decode it
samples = sampler.decode(
x_latent=z_enc,
initial_latent=z_enc,
cond=c,
t_start=t_enc,
schedule=schedule,
unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc,
img_callback=_img_callback,

@ -56,7 +56,7 @@ def get_img_mask(
mask[mask >= 0.5] = 1
log_img(mask, f"mask threshold {0.5}")
mask_np = mask.cpu().numpy()
mask_np = mask.to(torch.float32).cpu().numpy()
smoother_strength = 2
# grow the mask area to make sure we've masked the thing we care about
for _ in range(smoother_strength):

@ -2,7 +2,7 @@
import torch
from torch import nn
from imaginairy.utils import get_device
from imaginairy.img_log import log_latent
SAMPLER_TYPE_OPTIONS = [
"plms",
@ -51,11 +51,19 @@ class CFGDenoiser(nn.Module):
self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, mask=None, orig_latent=None):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
def _wrapper(noisy_latent_in, time_encoding_in, conditioning_in):
return self.inner_model(
noisy_latent_in, time_encoding_in, cond=conditioning_in
)
denoised = get_noise_prediction(
denoise_func=_wrapper,
noisy_latent=x,
time_encoding=sigma,
neutral_conditioning=uncond,
positive_conditioning=cond,
signal_amplification=cond_scale,
)
if mask is not None:
assert orig_latent is not None
@ -65,51 +73,37 @@ class CFGDenoiser(nn.Module):
return denoised
class DiffusionSampler:
"""
wip
def ensure_4_dim(t: torch.Tensor):
if len(t.shape) == 3:
t = t.unsqueeze(dim=0)
return t
hope to enforce an api upon samplers
"""
def __init__(self, noise_prediction_model, sampler_func, device=get_device()):
self.noise_prediction_model = noise_prediction_model
self.cfg_noise_prediction_model = CFGDenoiser(noise_prediction_model)
self.sampler_func = sampler_func
self.device = device
def zzsample(
self,
num_steps,
text_conditioning,
batch_size,
shape,
unconditional_guidance_scale,
unconditional_conditioning,
eta,
initial_noise_tensor=None,
img_callback=None,
):
size = (batch_size, *shape)
initial_noise_tensor = (
torch.randn(size, device="cpu").to(get_device())
if initial_noise_tensor is None
else initial_noise_tensor
)
sigmas = self.noise_prediction_model.get_sigmas(num_steps)
x = initial_noise_tensor * sigmas[0]
samples = self.sampler_func(
self.cfg_noise_prediction_model,
x,
sigmas,
extra_args={
"cond": text_conditioning,
"uncond": unconditional_conditioning,
"cond_scale": unconditional_guidance_scale,
},
disable=False,
)
def get_noise_prediction(
denoise_func,
noisy_latent,
time_encoding,
neutral_conditioning,
positive_conditioning,
signal_amplification=7.5,
):
noisy_latent = ensure_4_dim(noisy_latent)
noisy_latent_in = torch.cat([noisy_latent] * 2)
time_encoding_in = torch.cat([time_encoding] * 2)
conditioning_in = torch.cat([neutral_conditioning, positive_conditioning])
pred_noise_neutral, pred_noise_positive = denoise_func(
noisy_latent_in, time_encoding_in, conditioning_in
).chunk(2)
amplified_noise_pred = signal_amplification * (
pred_noise_positive - pred_noise_neutral
)
pred_noise = pred_noise_neutral + amplified_noise_pred
log_latent(pred_noise_neutral, "neutral noise prediction")
log_latent(pred_noise_positive, "positive noise prediction")
log_latent(pred_noise, "noise prediction")
return samples, None
return pred_noise

@ -1,5 +1,4 @@
# pylama:ignore=W0613
"""SAMPLING ONLY."""
import logging
import numpy as np
@ -13,41 +12,19 @@ from imaginairy.modules.diffusion.util import (
make_ddim_timesteps,
noise_like,
)
from imaginairy.samplers.base import get_noise_prediction
from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
class DDIMSampler:
"""
Denoising Diffusion Implicit Models
https://arxiv.org/abs/2010.02502
"""
def __init__(self, model):
self.model = model
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0):
buffers = self._make_schedule(
model_num_timesteps=self.model.num_timesteps,
model_alphas_cumprod=self.model.alphas_cumprod,
model_betas=self.model.betas,
model_alphas_cumprod_prev=self.model.alphas_cumprod_prev,
ddim_num_steps=ddim_num_steps,
ddim_discretize=ddim_discretize,
ddim_eta=ddim_eta,
device=self.model.device,
)
for k, v in buffers.items():
setattr(self, k, v)
@staticmethod
def _make_schedule(
class DDIMSchedule:
def __init__(
self,
model_num_timesteps,
model_alphas_cumprod,
model_betas,
model_alphas_cumprod_prev,
model_betas,
ddim_num_steps,
ddim_discretize="uniform",
ddim_eta=0.0,
@ -71,41 +48,37 @@ class DDIMSampler:
ddim_timesteps=ddim_timesteps,
eta=ddim_eta,
)
buffers = {
"ddim_timesteps": ddim_timesteps,
"betas": to_torch(model_betas),
"alphas_cumprod": to_torch(alphas_cumprod),
"alphas_cumprod_prev": to_torch(model_alphas_cumprod_prev),
# calculations for diffusion q(x_t | x_{t-1}) and others
"sqrt_alphas_cumprod": to_torch(np.sqrt(alphas_cumprod.cpu())),
"sqrt_one_minus_alphas_cumprod": to_torch(
np.sqrt(1.0 - alphas_cumprod.cpu())
),
"log_one_minus_alphas_cumprod": to_torch(
np.log(1.0 - alphas_cumprod.cpu())
),
"sqrt_recip_alphas_cumprod": to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
"sqrt_recipm1_alphas_cumprod": to_torch(
np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
),
"ddim_sigmas": ddim_sigmas.to(torch.float32).to(device),
"ddim_alphas": ddim_alphas.to(torch.float32).to(device),
"ddim_alphas_prev": ddim_alphas_prev,
"ddim_sqrt_one_minus_alphas": np.sqrt(1.0 - ddim_alphas)
.to(torch.float32)
.to(device),
}
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - buffers["alphas_cumprod_prev"])
/ (1 - buffers["alphas_cumprod"])
* (1 - buffers["alphas_cumprod"] / buffers["alphas_cumprod_prev"])
self.ddim_timesteps = ddim_timesteps
self.betas = to_torch(model_betas)
self.alphas_cumprod = to_torch(alphas_cumprod)
self.alphas_cumprod_prev = to_torch(model_alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
self.sqrt_one_minus_alphas_cumprod = to_torch(
np.sqrt(1.0 - alphas_cumprod.cpu())
)
self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
self.sqrt_recipm1_alphas_cumprod = to_torch(
np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
)
self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device)
self.ddim_alphas = ddim_alphas.to(torch.float32).to(device)
self.ddim_alphas_prev = ddim_alphas_prev
self.ddim_sqrt_one_minus_alphas = (
np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(device)
)
buffers[
"ddim_sigmas_for_original_num_steps"
] = sigmas_for_original_sampling_steps
return buffers
class DDIMSampler:
"""
Denoising Diffusion Implicit Models
https://arxiv.org/abs/2010.02502
"""
def __init__(self, model):
self.model = model
@torch.no_grad()
def sample(
@ -123,31 +96,30 @@ class DDIMSampler:
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
x_T=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
**kwargs,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
):
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
logger.warning(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
logger.warning(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta)
if conditioning.shape[0] != batch_size:
logger.warning(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
schedule = DDIMSchedule(
model_num_timesteps=self.model.num_timesteps,
model_alphas_cumprod=self.model.alphas_cumprod,
model_alphas_cumprod_prev=self.model.alphas_cumprod_prev,
model_betas=self.model.betas,
ddim_num_steps=num_steps,
ddim_discretize="uniform",
ddim_eta=0.0,
)
samples = self.ddim_sampling(
conditioning,
shape=(batch_size, *shape),
shape=shape,
schedule=schedule,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@ -155,8 +127,6 @@ class DDIMSampler:
x0=x0,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
@ -168,6 +138,7 @@ class DDIMSampler:
self,
cond,
shape,
schedule,
x_T=None,
callback=None,
timesteps=None,
@ -177,8 +148,6 @@ class DDIMSampler:
img_callback=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
@ -192,16 +161,16 @@ class DDIMSampler:
log_latent(img, "initial noise")
if timesteps is None:
timesteps = self.ddim_timesteps
timesteps = schedule.ddim_timesteps
else:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
min(timesteps / schedule.ddim_timesteps.shape[0], 1)
* schedule.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
timesteps = schedule.ddim_timesteps[:subset_end]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
@ -225,6 +194,7 @@ class DDIMSampler:
cond,
ts,
index=index,
schedule=schedule,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
@ -245,6 +215,7 @@ class DDIMSampler:
c,
t,
index,
schedule,
repeat_noise=False,
quantize_denoised=False,
temperature=1.0,
@ -254,26 +225,26 @@ class DDIMSampler:
loss_function=None,
):
assert unconditional_guidance_scale >= 1
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
# with torch.no_grad():
noise_pred_uncond, noise_pred = self.model.apply_model(x_in, t_in, c_in).chunk(
2
)
noise_pred = noise_pred_uncond + unconditional_guidance_scale * (
noise_pred - noise_pred_uncond
noise_pred = get_noise_prediction(
denoise_func=self.model.apply_model,
noisy_latent=x,
time_encoding=t,
neutral_conditioning=unconditional_conditioning,
positive_conditioning=c,
signal_amplification=unconditional_guidance_scale,
)
b = x.shape[0]
log_latent(noise_pred, "noise prediction")
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=x.device)
a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=x.device)
sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=x.device)
a_t = torch.full((b, 1, 1, 1), schedule.ddim_alphas[index], device=x.device)
a_prev = torch.full(
(b, 1, 1, 1), schedule.ddim_alphas_prev[index], device=x.device
)
sigma_t = torch.full((b, 1, 1, 1), schedule.ddim_sigmas[index], device=x.device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index], device=x.device
(b, 1, 1, 1), schedule.ddim_sqrt_one_minus_alphas[index], device=x.device
)
return self._p_sample_ddim_formula(
x,
@ -310,12 +281,11 @@ class DDIMSampler:
return x_prev, pred_x0
@torch.no_grad()
def stochastic_encode(self, init_latent, t, noise=None):
# fast, but does not allow for exact reconstruction
def noise_an_image(self, init_latent, t, schedule, noise=None):
# t serves as an index to gather the correct alphas
t = t.clamp(0, 1000)
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
@ -328,31 +298,34 @@ class DDIMSampler:
@torch.no_grad()
def decode(
self,
x_latent,
initial_latent,
cond,
t_start,
schedule,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
img_callback=None,
score_corrector=None,
temperature=1.0,
mask=None,
orig_latent=None,
):
timesteps = self.ddim_timesteps[:t_start]
timesteps = schedule.ddim_timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
logger.debug(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
x_dec = x_latent
x_dec = initial_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
(initial_latent.shape[0],),
step,
device=initial_latent.device,
dtype=torch.long,
)
if mask is not None:
@ -374,17 +347,12 @@ class DDIMSampler:
x_dec,
cond,
ts,
schedule=schedule,
index=index,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
temperature=temperature,
)
# original_loss = ((x_dec - x_latent).abs().mean()*70)
# sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())
# x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
# cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
# x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
# x_dec_alt = x_dec + (original_loss * 0.1) ** 2
log_latent(x_dec, f"x_dec {i}")
log_latent(pred_x0, f"pred_x0 {i}")

@ -8,11 +8,15 @@ from imaginairy.vendored.k_diffusion import sampling as k_sampling
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser
class StandardCompVisDenoiser(CompVisDenoiser):
def apply_model(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
class KDiffusionSampler:
def __init__(self, model, sampler_name):
self.model = model
self.cv_denoiser = CompVisDenoiser(model)
# self.cfg_denoiser = CompVisDenoiser(self.cv_denoiser)
self.cv_denoiser = StandardCompVisDenoiser(model)
self.sampler_name = sampler_name
self.sampler_func = getattr(k_sampling, f"sample_{sampler_name}")
@ -28,10 +32,8 @@ class KDiffusionSampler:
initial_noise_tensor=None,
img_callback=None,
):
size = (batch_size, *shape)
initial_noise_tensor = (
torch.randn(size, device="cpu").to(get_device())
torch.randn(shape, device="cpu").to(get_device())
if initial_noise_tensor is None
else initial_noise_tensor
)

@ -1,5 +1,4 @@
# pylama:ignore=W0613
"""SAMPLING ONLY."""
import logging
import numpy as np
@ -13,65 +12,52 @@ from imaginairy.modules.diffusion.util import (
make_ddim_timesteps,
noise_like,
)
from imaginairy.samplers.base import get_noise_prediction
from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
class PLMSSampler:
"""probabilistic least-mean-squares"""
def __init__(self, model):
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.device_available = get_device()
self.ddim_timesteps = None
def register_buffer(self, name, attr):
if isinstance(attr, torch.Tensor):
if attr.device != torch.device(self.device_available):
attr = attr.to(torch.float32).to(torch.device(self.device_available))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0):
class PLMSSchedule:
def __init__(
self,
ddpm_num_timesteps, # 1000?
ddim_num_steps, # prompt.steps?
alphas_cumprod,
alphas_cumprod_prev,
betas,
ddim_discretize="uniform",
ddim_eta=0.0,
):
if ddim_eta != 0:
raise ValueError("ddim_eta must be 0 for PLMS")
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
)
alphas_cumprod = self.model.alphas_cumprod
device = get_device()
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
alphas_cumprod.shape[0] == ddpm_num_timesteps
), "alphas have to be defined for each timestep"
def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
return x.clone().detach().to(torch.float32).to(device)
self.betas = to_torch(betas)
self.alphas_cumprod = to_torch(alphas_cumprod)
self.alphas_cumprod_prev = to_torch(alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
self.sqrt_one_minus_alphas_cumprod = to_torch(
np.sqrt(1.0 - alphas_cumprod.cpu())
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
self.sqrt_recipm1_alphas_cumprod = to_torch(
np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=ddpm_num_timesteps,
)
# ddim sampling parameters
@ -80,19 +66,21 @@ class PLMSSampler:
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))
self.ddim_alphas_prev = ddim_alphas_prev
self.ddim_sqrt_one_minus_alphas = (
np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(torch.device(device))
)
class PLMSSampler:
"""probabilistic least-mean-squares"""
def __init__(self, model):
self.model = model
self.device = get_device()
@torch.no_grad()
def sample(
self,
@ -101,145 +89,90 @@ class PLMSSampler:
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
orig_latent=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
x_T=None,
initial_latent=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
timesteps=None,
quantize_denoised=False,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
logger.warning(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
logger.warning(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
if conditioning.shape[0] != batch_size:
logger.warning(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta)
samples = self.plms_sampling(
conditioning,
(batch_size, *shape),
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
schedule = PLMSSchedule(
ddpm_num_timesteps=self.model.num_timesteps,
ddim_num_steps=num_steps,
alphas_cumprod=self.model.alphas_cumprod,
alphas_cumprod_prev=self.model.alphas_cumprod_prev,
betas=self.model.betas,
ddim_discretize="uniform",
ddim_eta=0.0,
)
return samples
@torch.no_grad()
def plms_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device="cpu").to(device)
else:
img = x_T
log_latent(img, "initial img")
device = self.device
# batch_size = shape[0]
if initial_latent is None:
initial_latent = torch.randn(shape, device="cpu").to(device)
log_latent(initial_latent, "initial latent")
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
timesteps = schedule.ddim_timesteps
elif timesteps is not None:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
min(timesteps / schedule.ddim_timesteps.shape[0], 1)
* schedule.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
timesteps = schedule.ddim_timesteps[:subset_end]
time_range = (
list(reversed(range(0, timesteps)))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
logger.debug(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc=" PLMS Sampler", total=total_steps)
old_eps = []
img = initial_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts = torch.full((batch_size,), step, device=device, dtype=torch.long)
ts_next = torch.full(
(b,),
(batch_size,),
time_range[min(i + 1, len(time_range) - 1)],
device=device,
dtype=torch.long,
)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
assert orig_latent is not None
img_orig = self.model.q_sample(orig_latent, ts)
img = img_orig * mask + (1.0 - mask) * img
img, pred_x0, e_t = self.p_sample_plms(
img, pred_x0, noise_prediction = self.p_sample_plms(
img,
cond,
conditioning,
ts,
schedule=schedule,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
)
old_eps.append(e_t)
old_eps.append(noise_prediction)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
@ -253,119 +186,108 @@ class PLMSSampler:
@torch.no_grad()
def p_sample_plms(
self,
x,
c,
t,
noisy_latent,
positive_conditioning,
time_encoding,
schedule: PLMSSchedule,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if (
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
log_latent(e_t_uncond, "noise pred uncond")
log_latent(e_t, "noise pred cond")
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
log_latent(e_t, "noise pred combined")
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
batch_size = noisy_latent.shape[0]
noise_prediction = get_noise_prediction(
denoise_func=self.model.apply_model,
noisy_latent=noisy_latent,
time_encoding=time_encoding,
neutral_conditioning=unconditional_conditioning,
positive_conditioning=positive_conditioning,
signal_amplification=unconditional_guidance_scale,
)
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
alpha_at_t = torch.full(
(batch_size, 1, 1, 1), schedule.ddim_alphas[index], device=self.device
)
alpha_prev_at_t = torch.full(
(batch_size, 1, 1, 1),
schedule.ddim_alphas_prev[index],
device=self.device,
)
sigma_t = torch.full(
(batch_size, 1, 1, 1), schedule.ddim_sigmas[index], device=self.device
)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
(batch_size, 1, 1, 1),
schedule.ddim_sqrt_one_minus_alphas[index],
device=self.device,
)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
pred_x0 = (noisy_latent - sqrt_one_minus_at * e_t) / alpha_at_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
dir_xt = (1.0 - alpha_prev_at_t - sigma_t**2).sqrt() * e_t
noise = (
sigma_t
* noise_like(noisy_latent.shape, self.device, repeat_noise)
* temperature
)
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
x_prev = alpha_prev_at_t.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
x_prev, pred_x0 = get_x_prev_and_pred_x0(noise_prediction, index)
e_t_next = get_noise_prediction(
denoise_func=self.model.apply_model,
noisy_latent=x_prev,
time_encoding=t_next,
neutral_conditioning=unconditional_conditioning,
positive_conditioning=positive_conditioning,
signal_amplification=unconditional_guidance_scale,
)
e_t_prime = (noise_prediction + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
e_t_prime = (3 * noise_prediction - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
e_t_prime = (
23 * noise_prediction - 16 * old_eps[-1] + 5 * old_eps[-2]
) / 12
elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
55 * noise_prediction
- 59 * old_eps[-1]
+ 37 * old_eps[-2]
- 9 * old_eps[-3]
) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
log_latent(x_prev, "x_prev")
log_latent(pred_x0, "pred_x0")
return x_prev, pred_x0, e_t
return x_prev, pred_x0, noise_prediction
@torch.no_grad()
def stochastic_encode(self, init_latent, t, noise=None):
def noise_an_image(self, init_latent, t, schedule, noise=None):
# replace with ddpm.q_sample?
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
t = t.clamp(0, 1000)
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
@ -378,26 +300,26 @@ class PLMSSampler:
@torch.no_grad()
def decode(
self,
x_latent,
cond,
t_start,
schedule,
initial_latent=None,
t_start=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
img_callback=None,
score_corrector=None,
temperature=1.0,
mask=None,
orig_latent=None,
noise=None,
):
timesteps = self.ddim_timesteps[:t_start]
device = self.device
timesteps = schedule.ddim_timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
iterator = tqdm(time_range, desc="PLMS altering image", total=total_steps)
x_dec = x_latent
iterator = tqdm(time_range, desc="PLMS img2img", total=total_steps)
x_dec = initial_latent
old_eps = []
log_latent(x_dec, "x_dec")
@ -411,12 +333,15 @@ class PLMSSampler:
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
(initial_latent.shape[0],),
step,
device=initial_latent.device,
dtype=torch.long,
)
ts_next = torch.full(
(x_latent.shape[0],),
(initial_latent.shape[0],),
time_range[min(i + 1, len(time_range) - 1)],
device=x_latent.device,
device=device,
dtype=torch.long,
)
@ -435,10 +360,11 @@ class PLMSSampler:
x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
log_latent(x_dec, f"x_dec {ts}")
x_dec, pred_x0, e_t = self.p_sample_plms(
x_dec, pred_x0, noise_prediction = self.p_sample_plms(
x_dec,
cond,
ts,
schedule=schedule,
index=index,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
@ -446,14 +372,8 @@ class PLMSSampler:
old_eps=old_eps,
t_next=ts_next,
)
# original_loss = ((x_dec - x_latent).abs().mean()*70)
# sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())
# x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
# cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
# x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
# x_dec_alt = x_dec + (original_loss * 0.1) ** 2
old_eps.append(e_t)
old_eps.append(noise_prediction)
if len(old_eps) >= 4:
old_eps.pop(0)

@ -47,6 +47,11 @@ def pre_setup():
yield
@pytest.fixture(autouse=True)
def reset_get_device():
get_device.cache_clear()
@pytest.fixture()
def filename_base_for_outputs(request):
filename_base = f"{TESTS_FOLDER}/test_output/{request.node.name}_{get_device()}_"

@ -54,7 +54,6 @@ def experiment_step_repeats():
embedder.to(get_device())
sampler = DDIMSampler(model)
sampler.make_schedule(1000)
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
init_image, _, _ = pillow_img_to_torch_image(
@ -89,7 +88,9 @@ def experiment_step_repeats():
# noise_pred = model.apply_model(init_latent, t, neutral_embedding)
# log_latent(noise_pred, "noise prediction")
for _ in range(100):
x_prev, pred_x0 = sampler.p_sample_ddim(x_prev, neutral_embedding, t, index)
x_prev, pred_x0 = sampler.p_sample_ddim( # noqa
x_prev, neutral_embedding, t, index
)
log_latent(pred_x0, "pred_x0")
x_prev = pred_x0

@ -12,7 +12,6 @@ from imaginairy.utils import (
get_hardware_description,
get_obj_from_str,
instantiate_from_config,
platform_appropriate_autocast,
)
@ -79,6 +78,7 @@ def test_instantiate_from_config():
instantiate_from_config(config)
def test_platform_appropriate_autocast():
with platform_appropriate_autocast("autocast"):
pass
#
# def test_platform_appropriate_autocast():
# with platform_appropriate_autocast("autocast"):
# pass

Loading…
Cancel
Save