2022-09-24 07:29:45 +00:00
|
|
|
# pylama:ignore=W0613
|
2022-10-14 03:49:48 +00:00
|
|
|
import logging
|
|
|
|
|
2022-10-13 06:45:08 +00:00
|
|
|
import numpy as np
|
2022-09-14 07:40:25 +00:00
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
|
2022-10-23 21:46:45 +00:00
|
|
|
from imaginairy.log_utils import log_latent
|
2022-10-13 06:45:08 +00:00
|
|
|
from imaginairy.modules.diffusion.util import (
|
2022-10-14 03:49:48 +00:00
|
|
|
extract_into_tensor,
|
2022-10-13 06:45:08 +00:00
|
|
|
make_ddim_sampling_parameters,
|
|
|
|
make_ddim_timesteps,
|
|
|
|
)
|
|
|
|
from imaginairy.utils import get_device
|
2022-09-14 07:40:25 +00:00
|
|
|
|
2022-10-14 03:49:48 +00:00
|
|
|
logger = logging.getLogger(__name__)

# Every sampler name that ``get_sampler`` knows how to build.
# ``get_sampler`` lower-cases its input before matching against these names.
SAMPLER_TYPE_OPTIONS = [
    "plms",
    "ddim",
    "k_dpm_fast",
    "k_dpm_adaptive",
    "k_lms",
    "k_dpm_2",
    "k_dpm_2_a",
    "k_euler",
    "k_euler_a",
    "k_heun",
]

# Translates this module's "k_*" sampler names into the name that
# ``get_sampler`` passes on to ``KDiffusionSampler``.
_k_sampler_type_lookup = dict(
    k_dpm_fast="dpm_fast",
    k_dpm_adaptive="dpm_adaptive",
    k_dpm_2="dpm_2",
    k_dpm_2_a="dpm_2_ancestral",
    k_euler="euler",
    k_euler_a="euler_ancestral",
    k_heun="heun",
    k_lms="lms",
)
|
|
|
|
|
|
|
|
|
|
|
|
def get_sampler(sampler_type, model):
    """
    Build the sampler instance that corresponds to ``sampler_type``.

    ``sampler_type`` is lower-cased before matching. "plms" and "ddim" map to
    their dedicated sampler classes. Every "k_*" name listed in
    ``_k_sampler_type_lookup`` is translated to the k-diffusion sampler name
    and wrapped in ``KDiffusionSampler``.

    Args:
        sampler_type: one of ``SAMPLER_TYPE_OPTIONS`` (case-insensitive).
        model: the diffusion model handed to the sampler constructor.

    Raises:
        ValueError: if ``sampler_type`` is not a recognised sampler name.
            This includes unrecognised "k_*" names, which are not allowed
            to surface as a bare KeyError.
    """
    # Imported inside the function, presumably to avoid import cycles or
    # heavy imports at module load time.
    from imaginairy.samplers.ddim import DDIMSampler  # noqa
    from imaginairy.samplers.kdiff import KDiffusionSampler  # noqa
    from imaginairy.samplers.plms import PLMSSampler  # noqa

    sampler_type = sampler_type.lower()
    if sampler_type == "plms":
        return PLMSSampler(model)
    if sampler_type == "ddim":
        return DDIMSampler(model)

    # .get() so that an unknown "k_*" name falls through to the ValueError
    # below, instead of raising KeyError from the lookup table.
    k_sampler_name = _k_sampler_type_lookup.get(sampler_type)
    if k_sampler_name is not None:
        return KDiffusionSampler(model, k_sampler_name)

    raise ValueError("invalid sampler_type")
|
2022-09-14 07:40:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
class CFGDenoiser(nn.Module):
    """
    Guidance wrapper around a k-diffusion style model.

    Calling the module returns the guided noise prediction produced by
    ``get_noise_prediction``: the wrapped model is evaluated on ``uncond``
    and ``cond``, and the two results are combined using ``cond_scale``.
    When a ``mask`` is supplied, the latent is first blended with
    ``orig_latent`` through ``mask_blend``.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model
        self.device = get_device()

    def forward(
        self,
        x,
        sigma,
        uncond,
        cond,
        cond_scale,
        mask=None,
        mask_noise=None,
        orig_latent=None,
    ):
        """
        Return the guided noise prediction for latent ``x`` at noise level ``sigma``.

        Args:
            x: current noisy latent.
            sigma: noise level. It is converted to a timestep with
                ``self.inner_model.sigma_to_t`` when masking.
            uncond: neutral (unconditional) conditioning.
            cond: positive conditioning.
            cond_scale: guidance strength, passed as ``signal_amplification``.
            mask: optional blend mask. If given, ``orig_latent`` must also be
                given, and ``mask_noise`` is multiplied, so it must not be None.
            mask_noise: noise used to re-noise ``orig_latent`` in ``mask_blend``.
            orig_latent: original (un-noised) latent for the masked region.
        """
        if mask is not None:
            assert orig_latent is not None
            timestep = self.inner_model.sigma_to_t(sigma, quantize=True)
            # NOTE(review): the built-in max() needs sigma to compare as a
            # scalar; a multi-element sigma tensor would raise here. This
            # presumably assumes batch size 1 — confirm.
            sigma_scale = max(sigma, 1)
            x = mask_blend(
                noisy_latent=x,
                orig_latent=orig_latent * sigma_scale,
                mask=mask,
                mask_noise=mask_noise * sigma_scale,
                ts=timestep,
                model=self.inner_model.inner_model,
            )

        inner = self.inner_model

        def denoise(latent_batch, time_batch, conditioning_batch):
            # Adapts get_noise_prediction's positional call to the inner
            # model's ``cond=`` keyword.
            return inner(latent_batch, time_batch, cond=conditioning_batch)

        return get_noise_prediction(
            denoise_func=denoise,
            noisy_latent=x,
            time_encoding=sigma,
            neutral_conditioning=uncond,
            positive_conditioning=cond,
            signal_amplification=cond_scale,
        )
|
2022-09-14 07:40:25 +00:00
|
|
|
|
|
|
|
|
2022-10-06 04:43:00 +00:00
|
|
|
def ensure_4_dim(t: torch.Tensor):
    """
    Prepend a size-1 leading axis to a 3-dimensional tensor.

    Tensors of any other rank are returned unchanged (the very same object).
    """
    return t.unsqueeze(dim=0) if t.dim() == 3 else t
|
2022-09-14 07:40:25 +00:00
|
|
|
|
|
|
|
|
2022-10-06 04:43:00 +00:00
|
|
|
def _batch_conditioning(neutral_conditioning, positive_conditioning):
    """Concatenate neutral and positive conditioning along dim 0, neutral first."""
    if not isinstance(positive_conditioning, dict):
        return torch.cat([neutral_conditioning, positive_conditioning])

    assert isinstance(neutral_conditioning, dict)
    batched = {}
    # Keys come from the positive side. List values are concatenated
    # element-wise, position by position.
    for key, positive_value in positive_conditioning.items():
        if isinstance(positive_value, list):
            batched[key] = [
                torch.cat([neutral_conditioning[key][idx], positive_item])
                for idx, positive_item in enumerate(positive_value)
            ]
        else:
            batched[key] = torch.cat([neutral_conditioning[key], positive_value])
    return batched


def get_noise_prediction(
    denoise_func,
    noisy_latent,
    time_encoding,
    neutral_conditioning,
    positive_conditioning,
    signal_amplification=7.5,
):
    """
    Return the guided noise prediction for ``noisy_latent``.

    The neutral and positive passes run as a single doubled batch. The
    latent and ``time_encoding`` are each concatenated with themselves, and
    the conditionings are concatenated neutral-first. ``denoise_func`` is
    called once on that batch and its output is split in half with
    ``chunk(2)``.

    Result: ``neutral + signal_amplification * (positive - neutral)``.

    Args:
        denoise_func: callable ``(latent, time, conditioning) -> prediction``.
        noisy_latent: latent tensor; a 3-D tensor gains a leading axis.
        time_encoding: tensor concatenated along dim 0, so it presumably has
            a leading batch dimension — confirm against callers.
        neutral_conditioning: tensor, or dict of tensors / lists of tensors.
        positive_conditioning: same structure as ``neutral_conditioning``.
        signal_amplification: guidance scale.
    """
    latent = ensure_4_dim(noisy_latent)
    latent_batch = torch.cat([latent, latent])
    time_batch = torch.cat([time_encoding, time_encoding])
    conditioning_batch = _batch_conditioning(
        neutral_conditioning, positive_conditioning
    )

    prediction = denoise_func(latent_batch, time_batch, conditioning_batch)
    pred_neutral, pred_positive = prediction.chunk(2)

    guided = pred_neutral + signal_amplification * (pred_positive - pred_neutral)

    log_latent(pred_neutral, "noise_pred_neutral")
    log_latent(pred_positive, "noise_pred_positive")
    log_latent(guided, "noise_pred")

    return guided
|
2022-10-13 06:45:08 +00:00
|
|
|
|
|
|
|
|
2022-10-13 08:31:25 +00:00
|
|
|
def mask_blend(noisy_latent, orig_latent, mask, mask_noise, ts, model):
    """
    Blend a re-noised copy of ``orig_latent`` into ``noisy_latent`` under ``mask``.

    ``orig_latent`` is noised to timestep ``ts`` with
    ``model.q_sample(orig_latent, ts, mask_noise)``. The result takes that
    noised original where ``mask`` is 1 and keeps ``noisy_latent`` where
    ``mask`` is 0.

    ts is a decreasing value between 1000 and 1
    """
    assert orig_latent is not None
    log_latent(orig_latent, "orig_latent")
    renoised_orig = model.q_sample(orig_latent, ts, mask_noise)

    # this helps prevent the weird disjointed images that can happen with masking
    hint_strength = 1
    hinted_orig = renoised_orig
    # Intended to keep the unchanged region less noisy early in sampling, so
    # the model can use it as context.
    # NOTE(review): with ts documented as at most 1000, this branch never
    # runs, so the hint is effectively disabled — confirm intent.
    if ts > 1000:
        hinted_orig = renoised_orig * (1 - hint_strength) + orig_latent * hint_strength
        log_latent(hinted_orig, f"hinted_orig_latent {ts}")

    kept_region = hinted_orig * mask
    log_latent(kept_region, f"hinted_orig_latent_masked {ts}")
    generated_region = (1.0 - mask) * noisy_latent
    log_latent(generated_region, f"noisy_latent_masked {ts}")

    blended = kept_region + generated_region
    log_latent(blended, f"mask-blended noisy_latent {ts}")
    return blended
|
|
|
|
|
|
|
|
|
2022-10-13 06:45:08 +00:00
|
|
|
def to_torch(x):
    """Return a detached float32 copy of tensor ``x`` on the active device."""
    detached_copy = x.clone().detach().to(torch.float32)
    return detached_copy.to(get_device())
|
|
|
|
|
|
|
|
|
|
|
|
class NoiseSchedule:
    """
    Precomputed DDIM schedule tensors derived from a model's ``alphas_cumprod``.

    Args:
        model_num_timesteps: number of model timesteps. It must equal the
            length of ``model_alphas_cumprod``.
        model_alphas_cumprod: cumulative alpha products, one per timestep
            (``.cpu()`` is called on it, so a torch tensor is expected).
        ddim_num_steps: number of DDIM timesteps passed to ``make_ddim_timesteps``.
        ddim_discretize: discretization method for ``make_ddim_timesteps``.
        ddim_eta: ``eta`` passed to ``make_ddim_sampling_parameters``.

    Raises:
        ValueError: if ``model_alphas_cumprod`` does not have one entry per timestep.

    ``ddim_sigmas``, ``ddim_alphas`` and ``ddim_sqrt_one_minus_alphas`` are
    stored as float32 on the active device. ``ddim_alphas_prev`` is stored
    exactly as returned by ``make_ddim_sampling_parameters``, with no
    conversion.
    """

    def __init__(
        self,
        model_num_timesteps,
        model_alphas_cumprod,
        ddim_num_steps,
        ddim_discretize="uniform",
        ddim_eta=0.0,
    ):
        target_device = get_device()
        if model_alphas_cumprod.shape[0] != model_num_timesteps:
            raise ValueError("alphas have to be defined for each timestep")

        def on_device(values):
            # float32 on the active device, like to_torch but without a copy step.
            return values.to(torch.float32).to(target_device)

        # Full-length schedule values, one per model timestep.
        self.alphas_cumprod = to_torch(model_alphas_cumprod)
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = to_torch(np.sqrt(model_alphas_cumprod.cpu()))
        self.sqrt_one_minus_alphas_cumprod = to_torch(
            np.sqrt(1.0 - model_alphas_cumprod.cpu())
        )

        # Subset of timesteps used for DDIM sampling.
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=model_num_timesteps,
        )

        # ddim sampling parameters
        sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
            alphacums=model_alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
        )
        self.ddim_sigmas = on_device(sigmas)
        self.ddim_alphas = on_device(alphas)
        self.ddim_alphas_prev = alphas_prev
        self.ddim_sqrt_one_minus_alphas = on_device(np.sqrt(1.0 - alphas))
|
2022-10-14 03:49:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def noise_an_image(init_latent, t, schedule, noise=None):
    """
    Noise ``init_latent`` forward to DDIM step index ``t`` in a single step.

    Computes ``sqrt(alpha_t) * init_latent + sqrt(1 - alpha_t) * noise`` using
    ``schedule.ddim_alphas`` and ``schedule.ddim_sqrt_one_minus_alphas``.
    This is fast, but does not allow exact reconstruction.

    Args:
        init_latent: latent tensor to noise.
        t: integer index tensor into the schedule's DDIM arrays, used as a
            gather index by ``extract_into_tensor``. Values outside the
            valid index range are clamped into it.
        schedule: object providing ``ddim_alphas`` and
            ``ddim_sqrt_one_minus_alphas`` (e.g. a ``NoiseSchedule``).
        noise: optional noise tensor. If omitted, standard-normal noise is
            drawn on the CPU and then moved to the active device
            (presumably so results do not depend on the device's RNG).

    Returns:
        The noised latent, same shape as ``init_latent``.
    """
    sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
    sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas

    # t is a gather index (along the last dim) into the arrays above, so it
    # must be clamped to their real index range [0, len - 1]. A fixed upper
    # bound of 1000 exceeds the last valid index and still allows
    # out-of-range gathers.
    last_index = sqrt_alphas_cumprod.shape[-1] - 1
    t = t.clamp(0, last_index)

    if noise is None:
        noise = torch.randn_like(init_latent, device="cpu").to(get_device())

    return (
        extract_into_tensor(sqrt_alphas_cumprod, t, init_latent.shape) * init_latent
        + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
        * noise
    )
|