mirror of https://github.com/brycedrennan/imaginAIry (synced 2024-11-05 12:00:15 +00:00)

refactor: begin to standardize samplers

commit 9ba302a5f4
parent 62e4e9cc9d
@@ -26,6 +26,7 @@ from imaginairy.img_log import (
 from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
 from imaginairy.safety import is_nsfw
 from imaginairy.samplers.base import get_sampler
+from imaginairy.samplers.plms import PLMSSchedule
 from imaginairy.schema import ImaginePrompt, ImagineResult
 from imaginairy.utils import (
     fix_torch_group_norm,
@@ -208,6 +209,7 @@ def imagine(
             log_conditioning(c, "positive conditioning")

             shape = [
+                1,
                 latent_channels,
                 prompt.height // downsampling_factor,
                 prompt.width // downsampling_factor,
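The shape list above is the latent-space tensor shape for a single sample. A rough sketch of the arithmetic, using the usual Stable Diffusion values (4 latent channels, downsampling factor 8 — assumed here for illustration, not taken from this diff):

    # Illustrative values; the real ones come from the loaded model config.
    latent_channels = 4
    downsampling_factor = 8

    def latent_shape(height, width):
        # Mirrors the list built in imagine(): [batch, channels, h/8, w/8]
        return [1, latent_channels, height // downsampling_factor, width // downsampling_factor]

    print(latent_shape(512, 512))  # [1, 4, 64, 64]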
@@ -228,9 +230,6 @@ def imagine(
             if prompt.init_image:
                 generation_strength = 1 - prompt.init_image_strength
                 t_enc = int(prompt.steps * generation_strength)
-                sampler.make_schedule(
-                    ddim_num_steps=prompt.steps, ddim_eta=ddim_eta
-                )
                 try:
                     init_image = pillow_fit_image_within(
                         prompt.init_image,
@@ -284,24 +283,35 @@ def imagine(
                 # encode (scaled latent)
                 seed_everything(prompt.seed)
                 noise = torch.randn_like(init_latent, device="cpu").to(get_device())
+                schedule = PLMSSchedule(
+                    ddpm_num_timesteps=model.num_timesteps,
+                    ddim_num_steps=prompt.steps,
+                    alphas_cumprod=model.alphas_cumprod,
+                    alphas_cumprod_prev=model.alphas_cumprod_prev,
+                    betas=model.betas,
+                    ddim_discretize="uniform",
+                    ddim_eta=0.0,
+                )
                 if generation_strength >= 1:
                     # prompt strength gets converted to time encodings,
                     # which means you can't get to true 0 without this hack
                     # (or setting steps=1000)
                     z_enc = noise
                 else:
-                    z_enc = sampler.stochastic_encode(
+                    z_enc = sampler.noise_an_image(
                         init_latent,
                         torch.tensor([t_enc - 1]).to(get_device()),
+                        schedule=schedule,
                         noise=noise,
                     )
                 log_latent(z_enc, "z_enc")

                 # decode it
                 samples = sampler.decode(
-                    x_latent=z_enc,
+                    initial_latent=z_enc,
                     cond=c,
                     t_start=t_enc,
+                    schedule=schedule,
                     unconditional_guidance_scale=prompt.prompt_strength,
                     unconditional_conditioning=uc,
                     img_callback=_img_callback,
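A quick check of the strength-to-timestep arithmetic used in the img2img branch above (illustrative numbers):

    def encode_steps(steps, init_image_strength):
        # generation_strength = 1 - init_image_strength; t_enc is how many of the
        # schedule's steps get re-noised and then re-denoised.
        generation_strength = 1 - init_image_strength
        return int(steps * generation_strength)

    print(encode_steps(50, 0.3))  # 35 denoising steps applied to the noised init latent
    print(encode_steps(50, 0.0))  # 50 -> generation_strength >= 1, so the code above just uses pure noise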
@@ -56,7 +56,7 @@ def get_img_mask(
     mask[mask >= 0.5] = 1
     log_img(mask, f"mask threshold {0.5}")

-    mask_np = mask.cpu().numpy()
+    mask_np = mask.to(torch.float32).cpu().numpy()
     smoother_strength = 2
     # grow the mask area to make sure we've masked the thing we care about
     for _ in range(smoother_strength):

@@ -2,7 +2,7 @@
 import torch
 from torch import nn

-from imaginairy.utils import get_device
+from imaginairy.img_log import log_latent

 SAMPLER_TYPE_OPTIONS = [
     "plms",
@@ -51,11 +51,19 @@ class CFGDenoiser(nn.Module):
         self.inner_model = model

     def forward(self, x, sigma, uncond, cond, cond_scale, mask=None, orig_latent=None):
-        x_in = torch.cat([x] * 2)
-        sigma_in = torch.cat([sigma] * 2)
-        cond_in = torch.cat([uncond, cond])
-        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
-        denoised = uncond + (cond - uncond) * cond_scale
+        def _wrapper(noisy_latent_in, time_encoding_in, conditioning_in):
+            return self.inner_model(
+                noisy_latent_in, time_encoding_in, cond=conditioning_in
+            )
+
+        denoised = get_noise_prediction(
+            denoise_func=_wrapper,
+            noisy_latent=x,
+            time_encoding=sigma,
+            neutral_conditioning=uncond,
+            positive_conditioning=cond,
+            signal_amplification=cond_scale,
+        )

         if mask is not None:
             assert orig_latent is not None
@@ -65,51 +73,37 @@ class CFGDenoiser(nn.Module):
         return denoised


-class DiffusionSampler:
-    """
-    wip
+def ensure_4_dim(t: torch.Tensor):
+    if len(t.shape) == 3:
+        t = t.unsqueeze(dim=0)
+    return t

-    hope to enforce an api upon samplers
-    """

-    def __init__(self, noise_prediction_model, sampler_func, device=get_device()):
-        self.noise_prediction_model = noise_prediction_model
-        self.cfg_noise_prediction_model = CFGDenoiser(noise_prediction_model)
-        self.sampler_func = sampler_func
-        self.device = device
+def get_noise_prediction(
+    denoise_func,
+    noisy_latent,
+    time_encoding,
+    neutral_conditioning,
+    positive_conditioning,
+    signal_amplification=7.5,
+):
+    noisy_latent = ensure_4_dim(noisy_latent)

-    def zzsample(
-        self,
-        num_steps,
-        text_conditioning,
-        batch_size,
-        shape,
-        unconditional_guidance_scale,
-        unconditional_conditioning,
-        eta,
-        initial_noise_tensor=None,
-        img_callback=None,
-    ):
-        size = (batch_size, *shape)
+    noisy_latent_in = torch.cat([noisy_latent] * 2)
+    time_encoding_in = torch.cat([time_encoding] * 2)
+    conditioning_in = torch.cat([neutral_conditioning, positive_conditioning])

-        initial_noise_tensor = (
-            torch.randn(size, device="cpu").to(get_device())
-            if initial_noise_tensor is None
-            else initial_noise_tensor
-        )
-        sigmas = self.noise_prediction_model.get_sigmas(num_steps)
-        x = initial_noise_tensor * sigmas[0]
+    pred_noise_neutral, pred_noise_positive = denoise_func(
+        noisy_latent_in, time_encoding_in, conditioning_in
+    ).chunk(2)
+
+    amplified_noise_pred = signal_amplification * (
+        pred_noise_positive - pred_noise_neutral
+    )
+    pred_noise = pred_noise_neutral + amplified_noise_pred

-        samples = self.sampler_func(
-            self.cfg_noise_prediction_model,
-            x,
-            sigmas,
-            extra_args={
-                "cond": text_conditioning,
-                "uncond": unconditional_conditioning,
-                "cond_scale": unconditional_guidance_scale,
-            },
-            disable=False,
-        )
+    log_latent(pred_noise_neutral, "neutral noise prediction")
+    log_latent(pred_noise_positive, "positive noise prediction")
+    log_latent(pred_noise, "noise prediction")

-        return samples, None
+    return pred_noise
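get_noise_prediction is classifier-free guidance: one batched call covers the neutral and positive conditionings, and the positive direction is then amplified. A self-contained sketch of the same combination with a toy denoiser (the lambda stands in for model.apply_model and is purely illustrative):

    import torch

    def cfg_noise_prediction(denoise_func, noisy_latent, time_encoding,
                             neutral_conditioning, positive_conditioning,
                             signal_amplification=7.5):
        latent_in = torch.cat([noisy_latent] * 2)
        t_in = torch.cat([time_encoding] * 2)
        cond_in = torch.cat([neutral_conditioning, positive_conditioning])
        pred_neutral, pred_positive = denoise_func(latent_in, t_in, cond_in).chunk(2)
        # guided prediction: neutral + scale * (positive - neutral)
        return pred_neutral + signal_amplification * (pred_positive - pred_neutral)

    toy_denoiser = lambda x, t, c: 0.5 * x + 0.5 * c  # ignores t; illustration only
    latent = torch.randn(1, 4, 8, 8)
    pred = cfg_noise_prediction(
        toy_denoiser, latent, torch.tensor([10.0]),
        torch.zeros_like(latent), torch.ones_like(latent),
    )
    print(pred.shape)  # torch.Size([1, 4, 8, 8])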
@@ -1,5 +1,4 @@
 # pylama:ignore=W0613
-"""SAMPLING ONLY."""
 import logging

 import numpy as np

@@ -13,41 +12,19 @@ from imaginairy.modules.diffusion.util import (
     make_ddim_timesteps,
     noise_like,
 )
+from imaginairy.samplers.base import get_noise_prediction
 from imaginairy.utils import get_device

 logger = logging.getLogger(__name__)


-class DDIMSampler:
-    """
-    Denoising Diffusion Implicit Models
-
-    https://arxiv.org/abs/2010.02502
-    """
-
-    def __init__(self, model):
-        self.model = model
-
-    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0):
-        buffers = self._make_schedule(
-            model_num_timesteps=self.model.num_timesteps,
-            model_alphas_cumprod=self.model.alphas_cumprod,
-            model_betas=self.model.betas,
-            model_alphas_cumprod_prev=self.model.alphas_cumprod_prev,
-            ddim_num_steps=ddim_num_steps,
-            ddim_discretize=ddim_discretize,
-            ddim_eta=ddim_eta,
-            device=self.model.device,
-        )
-        for k, v in buffers.items():
-            setattr(self, k, v)
-
-    @staticmethod
-    def _make_schedule(
+class DDIMSchedule:
+    def __init__(
+        self,
         model_num_timesteps,
         model_alphas_cumprod,
-        model_betas,
         model_alphas_cumprod_prev,
+        model_betas,
         ddim_num_steps,
         ddim_discretize="uniform",
         ddim_eta=0.0,
@@ -71,41 +48,37 @@ class DDIMSampler:
             ddim_timesteps=ddim_timesteps,
             eta=ddim_eta,
         )
-        buffers = {
-            "ddim_timesteps": ddim_timesteps,
-            "betas": to_torch(model_betas),
-            "alphas_cumprod": to_torch(alphas_cumprod),
-            "alphas_cumprod_prev": to_torch(model_alphas_cumprod_prev),
-            # calculations for diffusion q(x_t | x_{t-1}) and others
-            "sqrt_alphas_cumprod": to_torch(np.sqrt(alphas_cumprod.cpu())),
-            "sqrt_one_minus_alphas_cumprod": to_torch(
-                np.sqrt(1.0 - alphas_cumprod.cpu())
-            ),
-            "log_one_minus_alphas_cumprod": to_torch(
-                np.log(1.0 - alphas_cumprod.cpu())
-            ),
-            "sqrt_recip_alphas_cumprod": to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
-            "sqrt_recipm1_alphas_cumprod": to_torch(
-                np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
-            ),
-            "ddim_sigmas": ddim_sigmas.to(torch.float32).to(device),
-            "ddim_alphas": ddim_alphas.to(torch.float32).to(device),
-            "ddim_alphas_prev": ddim_alphas_prev,
-            "ddim_sqrt_one_minus_alphas": np.sqrt(1.0 - ddim_alphas)
-            .to(torch.float32)
-            .to(device),
-        }
-
-        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
-            (1 - buffers["alphas_cumprod_prev"])
-            / (1 - buffers["alphas_cumprod"])
-            * (1 - buffers["alphas_cumprod"] / buffers["alphas_cumprod_prev"])
-        )
-        buffers[
-            "ddim_sigmas_for_original_num_steps"
-        ] = sigmas_for_original_sampling_steps
-        return buffers
+        self.ddim_timesteps = ddim_timesteps
+        self.betas = to_torch(model_betas)
+        self.alphas_cumprod = to_torch(alphas_cumprod)
+        self.alphas_cumprod_prev = to_torch(model_alphas_cumprod_prev)
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
+        self.sqrt_one_minus_alphas_cumprod = to_torch(
+            np.sqrt(1.0 - alphas_cumprod.cpu())
+        )
+        self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+        self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+        self.sqrt_recipm1_alphas_cumprod = to_torch(
+            np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
+        )
+        self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device)
+        self.ddim_alphas = ddim_alphas.to(torch.float32).to(device)
+        self.ddim_alphas_prev = ddim_alphas_prev
+        self.ddim_sqrt_one_minus_alphas = (
+            np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(device)
+        )
+
+
+class DDIMSampler:
+    """
+    Denoising Diffusion Implicit Models
+
+    https://arxiv.org/abs/2010.02502
+    """
+
+    def __init__(self, model):
+        self.model = model

     @torch.no_grad()
     def sample(
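The eta-scaled sigma the removed code attached as ddim_sigmas_for_original_num_steps (and which make_ddim_sampling_parameters computes per DDIM step) follows the DDIM paper's formula. The expression in isolation, with made-up cumulative alphas:

    import torch

    def ddim_sigma(alphas_cumprod, alphas_cumprod_prev, eta):
        # sigma = eta * sqrt((1 - a_prev) / (1 - a) * (1 - a / a_prev))
        return eta * torch.sqrt(
            (1 - alphas_cumprod_prev)
            / (1 - alphas_cumprod)
            * (1 - alphas_cumprod / alphas_cumprod_prev)
        )

    a = torch.tensor([0.90, 0.70, 0.50])        # illustrative values only
    a_prev = torch.tensor([0.95, 0.90, 0.70])
    print(ddim_sigma(a, a_prev, eta=0.0))  # zeros -> deterministic DDIM
    print(ddim_sigma(a, a_prev, eta=1.0))  # nonzero -> DDPM-like stochastic sampling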
@@ -123,31 +96,30 @@ class DDIMSampler:
         x0=None,
         temperature=1.0,
         noise_dropout=0.0,
-        score_corrector=None,
-        corrector_kwargs=None,
         x_T=None,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
         **kwargs,
         # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
     ):
-        if isinstance(conditioning, dict):
-            cbs = conditioning[list(conditioning.keys())[0]].shape[0]
-            if cbs != batch_size:
-                logger.warning(
-                    f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
-                )
-        else:
         if conditioning.shape[0] != batch_size:
             logger.warning(
                 f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
             )

-        self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta)
+        schedule = DDIMSchedule(
+            model_num_timesteps=self.model.num_timesteps,
+            model_alphas_cumprod=self.model.alphas_cumprod,
+            model_alphas_cumprod_prev=self.model.alphas_cumprod_prev,
+            model_betas=self.model.betas,
+            ddim_num_steps=num_steps,
+            ddim_discretize="uniform",
+            ddim_eta=0.0,
+        )

         samples = self.ddim_sampling(
             conditioning,
-            shape=(batch_size, *shape),
+            shape=shape,
+            schedule=schedule,
             callback=callback,
             img_callback=img_callback,
             quantize_denoised=quantize_x0,

@@ -155,8 +127,6 @@ class DDIMSampler:
             x0=x0,
             noise_dropout=noise_dropout,
             temperature=temperature,
-            score_corrector=score_corrector,
-            corrector_kwargs=corrector_kwargs,
             x_T=x_T,
             unconditional_guidance_scale=unconditional_guidance_scale,
             unconditional_conditioning=unconditional_conditioning,

@@ -168,6 +138,7 @@ class DDIMSampler:
         self,
         cond,
         shape,
+        schedule,
         x_T=None,
         callback=None,
         timesteps=None,

@@ -177,8 +148,6 @@ class DDIMSampler:
         img_callback=None,
         temperature=1.0,
         noise_dropout=0.0,
-        score_corrector=None,
-        corrector_kwargs=None,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
     ):

@@ -192,16 +161,16 @@ class DDIMSampler:
         log_latent(img, "initial noise")

         if timesteps is None:
-            timesteps = self.ddim_timesteps
+            timesteps = schedule.ddim_timesteps
         else:
             subset_end = (
                 int(
-                    min(timesteps / self.ddim_timesteps.shape[0], 1)
-                    * self.ddim_timesteps.shape[0]
+                    min(timesteps / schedule.ddim_timesteps.shape[0], 1)
+                    * schedule.ddim_timesteps.shape[0]
                 )
                 - 1
             )
-            timesteps = self.ddim_timesteps[:subset_end]
+            timesteps = schedule.ddim_timesteps[:subset_end]

         time_range = np.flip(timesteps)
         total_steps = timesteps.shape[0]

@@ -225,6 +194,7 @@ class DDIMSampler:
                 cond,
                 ts,
                 index=index,
+                schedule=schedule,
                 quantize_denoised=quantize_denoised,
                 temperature=temperature,
                 noise_dropout=noise_dropout,

@@ -245,6 +215,7 @@ class DDIMSampler:
         c,
         t,
         index,
+        schedule,
         repeat_noise=False,
         quantize_denoised=False,
         temperature=1.0,

@@ -254,26 +225,26 @@ class DDIMSampler:
         loss_function=None,
     ):
         assert unconditional_guidance_scale >= 1
-        x_in = torch.cat([x] * 2)
-        t_in = torch.cat([t] * 2)
-        c_in = torch.cat([unconditional_conditioning, c])
-        # with torch.no_grad():
-        noise_pred_uncond, noise_pred = self.model.apply_model(x_in, t_in, c_in).chunk(
-            2
-        )
-        noise_pred = noise_pred_uncond + unconditional_guidance_scale * (
-            noise_pred - noise_pred_uncond
+        noise_pred = get_noise_prediction(
+            denoise_func=self.model.apply_model,
+            noisy_latent=x,
+            time_encoding=t,
+            neutral_conditioning=unconditional_conditioning,
+            positive_conditioning=c,
+            signal_amplification=unconditional_guidance_scale,
         )

         b = x.shape[0]
         log_latent(noise_pred, "noise prediction")

         # select parameters corresponding to the currently considered timestep
-        a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=x.device)
-        a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=x.device)
-        sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=x.device)
+        a_t = torch.full((b, 1, 1, 1), schedule.ddim_alphas[index], device=x.device)
+        a_prev = torch.full(
+            (b, 1, 1, 1), schedule.ddim_alphas_prev[index], device=x.device
+        )
+        sigma_t = torch.full((b, 1, 1, 1), schedule.ddim_sigmas[index], device=x.device)
         sqrt_one_minus_at = torch.full(
-            (b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index], device=x.device
+            (b, 1, 1, 1), schedule.ddim_sqrt_one_minus_alphas[index], device=x.device
         )
         return self._p_sample_ddim_formula(
             x,
@@ -310,12 +281,11 @@ class DDIMSampler:
         return x_prev, pred_x0

     @torch.no_grad()
-    def stochastic_encode(self, init_latent, t, noise=None):
-        # fast, but does not allow for exact reconstruction
+    def noise_an_image(self, init_latent, t, schedule, noise=None):
         # t serves as an index to gather the correct alphas
         t = t.clamp(0, 1000)
-        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
-        sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+        sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
+        sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas

         if noise is None:
             noise = torch.randn_like(init_latent, device="cpu").to(get_device())
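noise_an_image (the renamed stochastic_encode) gathers these sqrt(alpha-bar) terms from the schedule and applies the standard forward-diffusion noising. A standalone restatement of that formula with toy values (the real method indexes the schedule's ddim buffers; this sketch is not its actual body):

    import torch

    def noise_at_step(init_latent, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod, noise):
        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        return sqrt_alphas_cumprod[t] * init_latent + sqrt_one_minus_alphas_cumprod[t] * noise

    alpha_bar = torch.linspace(0.99, 0.01, 50)  # illustrative schedule
    x0 = torch.randn(1, 4, 8, 8)
    noised = noise_at_step(
        x0, torch.tensor([24]), alpha_bar.sqrt(), (1 - alpha_bar).sqrt(), torch.randn_like(x0)
    )
    print(noised.shape)  # torch.Size([1, 4, 8, 8])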
@@ -328,31 +298,34 @@ class DDIMSampler:
     @torch.no_grad()
     def decode(
         self,
-        x_latent,
+        initial_latent,
         cond,
         t_start,
+        schedule,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
         img_callback=None,
-        score_corrector=None,
         temperature=1.0,
         mask=None,
         orig_latent=None,
     ):

-        timesteps = self.ddim_timesteps[:t_start]
+        timesteps = schedule.ddim_timesteps[:t_start]

         time_range = np.flip(timesteps)
         total_steps = timesteps.shape[0]
         logger.debug(f"Running DDIM Sampling with {total_steps} timesteps")

         iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
-        x_dec = x_latent
+        x_dec = initial_latent

         for i, step in enumerate(iterator):
             index = total_steps - i - 1
             ts = torch.full(
-                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+                (initial_latent.shape[0],),
+                step,
+                device=initial_latent.device,
+                dtype=torch.long,
             )

             if mask is not None:

@@ -374,17 +347,12 @@ class DDIMSampler:
                 x_dec,
                 cond,
                 ts,
+                schedule=schedule,
                 index=index,
                 unconditional_guidance_scale=unconditional_guidance_scale,
                 unconditional_conditioning=unconditional_conditioning,
                 temperature=temperature,
             )
-            # original_loss = ((x_dec - x_latent).abs().mean()*70)
-            # sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())
-            # x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
-            # cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
-            # x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
-            # x_dec_alt = x_dec + (original_loss * 0.1) ** 2

             log_latent(x_dec, f"x_dec {i}")
             log_latent(pred_x0, f"pred_x0 {i}")
@@ -8,11 +8,15 @@ from imaginairy.vendored.k_diffusion import sampling as k_sampling
 from imaginairy.vendored.k_diffusion.external import CompVisDenoiser


+class StandardCompVisDenoiser(CompVisDenoiser):
+    def apply_model(self, *args, **kwargs):
+        return self.inner_model.apply_model(*args, **kwargs)
+
+
 class KDiffusionSampler:
     def __init__(self, model, sampler_name):
         self.model = model
-        self.cv_denoiser = CompVisDenoiser(model)
-        # self.cfg_denoiser = CompVisDenoiser(self.cv_denoiser)
+        self.cv_denoiser = StandardCompVisDenoiser(model)
         self.sampler_name = sampler_name
         self.sampler_func = getattr(k_sampling, f"sample_{sampler_name}")

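StandardCompVisDenoiser just forwards apply_model to the wrapped latent-diffusion model, so the k-diffusion wrapper exposes the same apply_model entry point the other samplers use. The delegation pattern in isolation (names here are illustrative, not the vendored k-diffusion API):

    class ToyModel:
        def apply_model(self, x, t, cond):
            return x  # stand-in for the UNet noise prediction

    class PassthroughDenoiser:
        # mirrors the idea behind StandardCompVisDenoiser: delegate apply_model to the wrapped model
        def __init__(self, inner_model):
            self.inner_model = inner_model

        def apply_model(self, *args, **kwargs):
            return self.inner_model.apply_model(*args, **kwargs)

    print(PassthroughDenoiser(ToyModel()).apply_model("latent", 0, cond=None))  # "latent"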
@@ -28,10 +32,8 @@ class KDiffusionSampler:
         initial_noise_tensor=None,
         img_callback=None,
     ):
-        size = (batch_size, *shape)
-
         initial_noise_tensor = (
-            torch.randn(size, device="cpu").to(get_device())
+            torch.randn(shape, device="cpu").to(get_device())
             if initial_noise_tensor is None
             else initial_noise_tensor
         )
@@ -1,5 +1,4 @@
 # pylama:ignore=W0613
-"""SAMPLING ONLY."""
 import logging

 import numpy as np

@@ -13,65 +12,52 @@ from imaginairy.modules.diffusion.util import (
     make_ddim_timesteps,
     noise_like,
 )
+from imaginairy.samplers.base import get_noise_prediction
 from imaginairy.utils import get_device

 logger = logging.getLogger(__name__)


-class PLMSSampler:
-    """probabilistic least-mean-squares"""
-
-    def __init__(self, model):
-        self.model = model
-        self.ddpm_num_timesteps = model.num_timesteps
-        self.device_available = get_device()
-        self.ddim_timesteps = None
-
-    def register_buffer(self, name, attr):
-        if isinstance(attr, torch.Tensor):
-            if attr.device != torch.device(self.device_available):
-                attr = attr.to(torch.float32).to(torch.device(self.device_available))
-        setattr(self, name, attr)
-
-    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0):
+class PLMSSchedule:
+    def __init__(
+        self,
+        ddpm_num_timesteps,  # 1000?
+        ddim_num_steps,  # prompt.steps?
+        alphas_cumprod,
+        alphas_cumprod_prev,
+        betas,
+        ddim_discretize="uniform",
+        ddim_eta=0.0,
+    ):
         if ddim_eta != 0:
             raise ValueError("ddim_eta must be 0 for PLMS")
-        self.ddim_timesteps = make_ddim_timesteps(
-            ddim_discr_method=ddim_discretize,
-            num_ddim_timesteps=ddim_num_steps,
-            num_ddpm_timesteps=self.ddpm_num_timesteps,
-        )
-        alphas_cumprod = self.model.alphas_cumprod
+        device = get_device()
+
         assert (
-            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
+            alphas_cumprod.shape[0] == ddpm_num_timesteps
         ), "alphas have to be defined for each timestep"

         def to_torch(x):
-            return x.clone().detach().to(torch.float32).to(self.model.device)
+            return x.clone().detach().to(torch.float32).to(device)

-        self.register_buffer("betas", to_torch(self.model.betas))
-        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
-        self.register_buffer(
-            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
-        )
-
+        self.betas = to_torch(betas)
+        self.alphas_cumprod = to_torch(alphas_cumprod)
+        self.alphas_cumprod_prev = to_torch(alphas_cumprod_prev)
         # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer(
-            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
-        )
-        self.register_buffer(
-            "sqrt_one_minus_alphas_cumprod",
-            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
-        )
-        self.register_buffer(
-            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
-        )
-        self.register_buffer(
-            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
-        )
-        self.register_buffer(
-            "sqrt_recipm1_alphas_cumprod",
-            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
-        )
+        self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
+        self.sqrt_one_minus_alphas_cumprod = to_torch(
+            np.sqrt(1.0 - alphas_cumprod.cpu())
+        )
+        self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
+        self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
+        self.sqrt_recipm1_alphas_cumprod = to_torch(
+            np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
+        )
+
+        self.ddim_timesteps = make_ddim_timesteps(
+            ddim_discr_method=ddim_discretize,
+            num_ddim_timesteps=ddim_num_steps,
+            num_ddpm_timesteps=ddpm_num_timesteps,
+        )

         # ddim sampling parameters
@@ -80,19 +66,21 @@ class PLMSSampler:
             ddim_timesteps=self.ddim_timesteps,
             eta=ddim_eta,
         )
-        self.register_buffer("ddim_sigmas", ddim_sigmas)
-        self.register_buffer("ddim_alphas", ddim_alphas)
-        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
-        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
-        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
-            (1 - self.alphas_cumprod_prev)
-            / (1 - self.alphas_cumprod)
-            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
-        )
-        self.register_buffer(
-            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
-        )
+        self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
+        self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))
+        self.ddim_alphas_prev = ddim_alphas_prev
+        self.ddim_sqrt_one_minus_alphas = (
+            np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(torch.device(device))
+        )
+
+
+class PLMSSampler:
+    """probabilistic least-mean-squares"""
+
+    def __init__(self, model):
+        self.model = model
+        self.device = get_device()

     @torch.no_grad()
     def sample(
         self,
@@ -101,145 +89,90 @@ class PLMSSampler:
         shape,
         conditioning=None,
         callback=None,
-        normals_sequence=None,
         img_callback=None,
         quantize_x0=False,
         eta=0.0,
         mask=None,
-        x0=None,
+        orig_latent=None,
         temperature=1.0,
         noise_dropout=0.0,
-        score_corrector=None,
-        corrector_kwargs=None,
-        x_T=None,
+        initial_latent=None,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
+        timesteps=None,
+        quantize_denoised=False,
         # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
         **kwargs,
     ):
-        if conditioning is not None:
-            if isinstance(conditioning, dict):
-                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
-                if cbs != batch_size:
-                    logger.warning(
-                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
-                    )
-            else:
         if conditioning.shape[0] != batch_size:
             logger.warning(
                 f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
             )

-        self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta)
-
-        samples = self.plms_sampling(
-            conditioning,
-            (batch_size, *shape),
-            callback=callback,
-            img_callback=img_callback,
-            quantize_denoised=quantize_x0,
-            mask=mask,
-            x0=x0,
-            ddim_use_original_steps=False,
-            noise_dropout=noise_dropout,
-            temperature=temperature,
-            score_corrector=score_corrector,
-            corrector_kwargs=corrector_kwargs,
-            x_T=x_T,
-            unconditional_guidance_scale=unconditional_guidance_scale,
-            unconditional_conditioning=unconditional_conditioning,
-        )
-        return samples
-
-    @torch.no_grad()
-    def plms_sampling(
-        self,
-        cond,
-        shape,
-        x_T=None,
-        ddim_use_original_steps=False,
-        callback=None,
-        timesteps=None,
-        quantize_denoised=False,
-        mask=None,
-        x0=None,
-        img_callback=None,
-        temperature=1.0,
-        noise_dropout=0.0,
-        score_corrector=None,
-        corrector_kwargs=None,
-        unconditional_guidance_scale=1.0,
-        unconditional_conditioning=None,
-    ):
-        device = self.model.betas.device
-        b = shape[0]
-        if x_T is None:
-            img = torch.randn(shape, device="cpu").to(device)
-        else:
-            img = x_T
-        log_latent(img, "initial img")
+        schedule = PLMSSchedule(
+            ddpm_num_timesteps=self.model.num_timesteps,
+            ddim_num_steps=num_steps,
+            alphas_cumprod=self.model.alphas_cumprod,
+            alphas_cumprod_prev=self.model.alphas_cumprod_prev,
+            betas=self.model.betas,
+            ddim_discretize="uniform",
+            ddim_eta=0.0,
+        )
+        device = self.device
+        # batch_size = shape[0]
+        if initial_latent is None:
+            initial_latent = torch.randn(shape, device="cpu").to(device)
+        log_latent(initial_latent, "initial latent")

         if timesteps is None:
-            timesteps = (
-                self.ddpm_num_timesteps
-                if ddim_use_original_steps
-                else self.ddim_timesteps
-            )
-        elif timesteps is not None and not ddim_use_original_steps:
+            timesteps = schedule.ddim_timesteps
+        elif timesteps is not None:
             subset_end = (
                 int(
-                    min(timesteps / self.ddim_timesteps.shape[0], 1)
-                    * self.ddim_timesteps.shape[0]
+                    min(timesteps / schedule.ddim_timesteps.shape[0], 1)
+                    * schedule.ddim_timesteps.shape[0]
                 )
                 - 1
             )
-            timesteps = self.ddim_timesteps[:subset_end]
+            timesteps = schedule.ddim_timesteps[:subset_end]

-        time_range = (
-            list(reversed(range(0, timesteps)))
-            if ddim_use_original_steps
-            else np.flip(timesteps)
-        )
-        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+        time_range = np.flip(timesteps)
+        total_steps = timesteps.shape[0]
         logger.debug(f"Running PLMS Sampling with {total_steps} timesteps")

         iterator = tqdm(time_range, desc=" PLMS Sampler", total=total_steps)
         old_eps = []
+        img = initial_latent

         for i, step in enumerate(iterator):
             index = total_steps - i - 1
-            ts = torch.full((b,), step, device=device, dtype=torch.long)
+            ts = torch.full((batch_size,), step, device=device, dtype=torch.long)
             ts_next = torch.full(
-                (b,),
+                (batch_size,),
                 time_range[min(i + 1, len(time_range) - 1)],
                 device=device,
                 dtype=torch.long,
             )

             if mask is not None:
-                assert x0 is not None
-                img_orig = self.model.q_sample(
-                    x0, ts
-                )  # TODO: deterministic forward pass?
+                assert orig_latent is not None
+                img_orig = self.model.q_sample(orig_latent, ts)
                 img = img_orig * mask + (1.0 - mask) * img

-            img, pred_x0, e_t = self.p_sample_plms(
+            img, pred_x0, noise_prediction = self.p_sample_plms(
                 img,
-                cond,
+                conditioning,
                 ts,
+                schedule=schedule,
                 index=index,
-                use_original_steps=ddim_use_original_steps,
                 quantize_denoised=quantize_denoised,
                 temperature=temperature,
                 noise_dropout=noise_dropout,
-                score_corrector=score_corrector,
-                corrector_kwargs=corrector_kwargs,
                 unconditional_guidance_scale=unconditional_guidance_scale,
                 unconditional_conditioning=unconditional_conditioning,
                 old_eps=old_eps,
                 t_next=ts_next,
             )
-            old_eps.append(e_t)
+            old_eps.append(noise_prediction)
             if len(old_eps) >= 4:
                 old_eps.pop(0)
             if callback:
@@ -253,119 +186,108 @@ class PLMSSampler:
     @torch.no_grad()
     def p_sample_plms(
         self,
-        x,
-        c,
-        t,
+        noisy_latent,
+        positive_conditioning,
+        time_encoding,
+        schedule: PLMSSchedule,
         index,
         repeat_noise=False,
-        use_original_steps=False,
         quantize_denoised=False,
         temperature=1.0,
         noise_dropout=0.0,
-        score_corrector=None,
-        corrector_kwargs=None,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
         old_eps=None,
         t_next=None,
     ):
-        b, *_, device = *x.shape, x.device
-
-        def get_model_output(x, t):
-            if (
-                unconditional_conditioning is None
-                or unconditional_guidance_scale == 1.0
-            ):
-                e_t = self.model.apply_model(x, t, c)
-            else:
-                x_in = torch.cat([x] * 2)
-                t_in = torch.cat([t] * 2)
-                c_in = torch.cat([unconditional_conditioning, c])
-                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-                log_latent(e_t_uncond, "noise pred uncond")
-                log_latent(e_t, "noise pred cond")
-
-                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-                log_latent(e_t, "noise pred combined")
-
-            if score_corrector is not None:
-                assert self.model.parameterization == "eps"
-                e_t = score_corrector.modify_score(
-                    self.model, e_t, x, t, c, **corrector_kwargs
-                )
-
-            return e_t
-
-        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
-        alphas_prev = (
-            self.model.alphas_cumprod_prev
-            if use_original_steps
-            else self.ddim_alphas_prev
-        )
-        sqrt_one_minus_alphas = (
-            self.model.sqrt_one_minus_alphas_cumprod
-            if use_original_steps
-            else self.ddim_sqrt_one_minus_alphas
-        )
-        sigmas = (
-            self.model.ddim_sigmas_for_original_num_steps
-            if use_original_steps
-            else self.ddim_sigmas
-        )
+        batch_size = noisy_latent.shape[0]
+        noise_prediction = get_noise_prediction(
+            denoise_func=self.model.apply_model,
+            noisy_latent=noisy_latent,
+            time_encoding=time_encoding,
+            neutral_conditioning=unconditional_conditioning,
+            positive_conditioning=positive_conditioning,
+            signal_amplification=unconditional_guidance_scale,
+        )

         def get_x_prev_and_pred_x0(e_t, index):
             # select parameters corresponding to the currently considered timestep
-            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
-            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
-            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+            alpha_at_t = torch.full(
+                (batch_size, 1, 1, 1), schedule.ddim_alphas[index], device=self.device
+            )
+            alpha_prev_at_t = torch.full(
+                (batch_size, 1, 1, 1),
+                schedule.ddim_alphas_prev[index],
+                device=self.device,
+            )
+            sigma_t = torch.full(
+                (batch_size, 1, 1, 1), schedule.ddim_sigmas[index], device=self.device
+            )
             sqrt_one_minus_at = torch.full(
-                (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+                (batch_size, 1, 1, 1),
+                schedule.ddim_sqrt_one_minus_alphas[index],
+                device=self.device,
             )

             # current prediction for x_0
-            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+            pred_x0 = (noisy_latent - sqrt_one_minus_at * e_t) / alpha_at_t.sqrt()
             if quantize_denoised:
                 pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
             # direction pointing to x_t
-            dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
-            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+            dir_xt = (1.0 - alpha_prev_at_t - sigma_t**2).sqrt() * e_t
+            noise = (
+                sigma_t
+                * noise_like(noisy_latent.shape, self.device, repeat_noise)
+                * temperature
+            )
             if noise_dropout > 0.0:
                 noise = torch.nn.functional.dropout(noise, p=noise_dropout)
-            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+            x_prev = alpha_prev_at_t.sqrt() * pred_x0 + dir_xt + noise
             return x_prev, pred_x0

-        e_t = get_model_output(x, t)
-
         if len(old_eps) == 0:
             # Pseudo Improved Euler (2nd order)
-            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
-            e_t_next = get_model_output(x_prev, t_next)
-            e_t_prime = (e_t + e_t_next) / 2
+            x_prev, pred_x0 = get_x_prev_and_pred_x0(noise_prediction, index)
+            e_t_next = get_noise_prediction(
+                denoise_func=self.model.apply_model,
+                noisy_latent=x_prev,
+                time_encoding=t_next,
+                neutral_conditioning=unconditional_conditioning,
+                positive_conditioning=positive_conditioning,
+                signal_amplification=unconditional_guidance_scale,
+            )
+            e_t_prime = (noise_prediction + e_t_next) / 2
         elif len(old_eps) == 1:
             # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
-            e_t_prime = (3 * e_t - old_eps[-1]) / 2
+            e_t_prime = (3 * noise_prediction - old_eps[-1]) / 2
         elif len(old_eps) == 2:
             # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
-            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+            e_t_prime = (
+                23 * noise_prediction - 16 * old_eps[-1] + 5 * old_eps[-2]
+            ) / 12
         elif len(old_eps) >= 3:
             # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
             e_t_prime = (
-                55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
+                55 * noise_prediction
+                - 59 * old_eps[-1]
+                + 37 * old_eps[-2]
+                - 9 * old_eps[-3]
             ) / 24

         x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
         log_latent(x_prev, "x_prev")
         log_latent(pred_x0, "pred_x0")

-        return x_prev, pred_x0, e_t
+        return x_prev, pred_x0, noise_prediction

     @torch.no_grad()
-    def stochastic_encode(self, init_latent, t, noise=None):
+    def noise_an_image(self, init_latent, t, schedule, noise=None):
+        # replace with ddpm.q_sample?
         # fast, but does not allow for exact reconstruction
         # t serves as an index to gather the correct alphas
         t = t.clamp(0, 1000)
-        sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
-        sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+        sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
+        sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas

         if noise is None:
             noise = torch.randn_like(init_latent, device="cpu").to(get_device())
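The len(old_eps) branches above are pseudo linear multistep (Adams–Bashforth) combinations of the current and previous noise predictions. The coefficients can be sanity-checked on their own — each set sums to its denominator, so a constant prediction passes through unchanged:

    def plms_combine(e_t, old_eps):
        # mirrors the order-dependent branches in p_sample_plms
        # (the empty-history branch, which needs a second model call, is omitted here)
        if len(old_eps) == 1:
            return (3 * e_t - old_eps[-1]) / 2
        if len(old_eps) == 2:
            return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

    print(plms_combine(1.0, [1.0]))            # 1.0
    print(plms_combine(1.0, [1.0, 1.0]))       # 1.0
    print(plms_combine(1.0, [1.0, 1.0, 1.0]))  # 1.0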
@@ -378,26 +300,26 @@ class PLMSSampler:
     @torch.no_grad()
     def decode(
        self,
-        x_latent,
         cond,
-        t_start,
+        schedule,
+        initial_latent=None,
+        t_start=None,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
         img_callback=None,
-        score_corrector=None,
         temperature=1.0,
         mask=None,
         orig_latent=None,
         noise=None,
     ):
-        timesteps = self.ddim_timesteps[:t_start]
+        device = self.device
+        timesteps = schedule.ddim_timesteps[:t_start]

         time_range = np.flip(timesteps)
         total_steps = timesteps.shape[0]

-        iterator = tqdm(time_range, desc="PLMS altering image", total=total_steps)
-        x_dec = x_latent
+        iterator = tqdm(time_range, desc="PLMS img2img", total=total_steps)
+        x_dec = initial_latent
         old_eps = []
         log_latent(x_dec, "x_dec")

@@ -411,12 +333,15 @@ class PLMSSampler:
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
             ts = torch.full(
-                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+                (initial_latent.shape[0],),
+                step,
+                device=initial_latent.device,
+                dtype=torch.long,
             )
             ts_next = torch.full(
-                (x_latent.shape[0],),
+                (initial_latent.shape[0],),
                 time_range[min(i + 1, len(time_range) - 1)],
-                device=x_latent.device,
+                device=device,
                 dtype=torch.long,
             )

@@ -435,10 +360,11 @@ class PLMSSampler:
                 x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
                 log_latent(x_dec, f"x_dec {ts}")

-            x_dec, pred_x0, e_t = self.p_sample_plms(
+            x_dec, pred_x0, noise_prediction = self.p_sample_plms(
                 x_dec,
                 cond,
                 ts,
+                schedule=schedule,
                 index=index,
                 unconditional_guidance_scale=unconditional_guidance_scale,
                 unconditional_conditioning=unconditional_conditioning,

@@ -446,14 +372,8 @@ class PLMSSampler:
                 old_eps=old_eps,
                 t_next=ts_next,
             )
-            # original_loss = ((x_dec - x_latent).abs().mean()*70)
-            # sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())
-            # x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
-            # cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
-            # x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
-            # x_dec_alt = x_dec + (original_loss * 0.1) ** 2

-            old_eps.append(e_t)
+            old_eps.append(noise_prediction)
             if len(old_eps) >= 4:
                 old_eps.pop(0)

@@ -47,6 +47,11 @@ def pre_setup():
     yield


+@pytest.fixture(autouse=True)
+def reset_get_device():
+    get_device.cache_clear()
+
+
 @pytest.fixture()
 def filename_base_for_outputs(request):
     filename_base = f"{TESTS_FOLDER}/test_output/{request.node.name}_{get_device()}_"
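The new fixture calls get_device.cache_clear(), which implies get_device is memoized (functools.lru_cache exposes cache_clear on the wrapped function). A minimal illustration of why an autouse fixture would want to reset it between tests — the body of get_device here is invented for the example:

    from functools import lru_cache

    @lru_cache
    def get_device():
        # imagine this probes cuda/mps availability once and caches the answer
        return "cpu"

    print(get_device())       # computed once, then cached
    get_device.cache_clear()  # what the autouse fixture does for each test
    print(get_device())       # a test that patches device detection now sees a fresh probe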
@@ -54,7 +54,6 @@ def experiment_step_repeats():
     embedder.to(get_device())

     sampler = DDIMSampler(model)
-    sampler.make_schedule(1000)

     img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
     init_image, _, _ = pillow_img_to_torch_image(

@@ -89,7 +88,9 @@ def experiment_step_repeats():
     # noise_pred = model.apply_model(init_latent, t, neutral_embedding)
     # log_latent(noise_pred, "noise prediction")
     for _ in range(100):
-        x_prev, pred_x0 = sampler.p_sample_ddim(x_prev, neutral_embedding, t, index)
+        x_prev, pred_x0 = sampler.p_sample_ddim(  # noqa
+            x_prev, neutral_embedding, t, index
+        )
         log_latent(pred_x0, "pred_x0")
         x_prev = pred_x0

@@ -12,7 +12,6 @@ from imaginairy.utils import (
     get_hardware_description,
     get_obj_from_str,
     instantiate_from_config,
-    platform_appropriate_autocast,
 )

@@ -79,6 +78,7 @@ def test_instantiate_from_config():
     instantiate_from_config(config)


-def test_platform_appropriate_autocast():
-    with platform_appropriate_autocast("autocast"):
-        pass
+#
+# def test_platform_appropriate_autocast():
+#     with platform_appropriate_autocast("autocast"):
+#         pass