refactor: begin to standardize samplers

This commit is contained in:
Bryce 2022-10-05 21:43:00 -07:00 committed by Bryce Drennan
parent 62e4e9cc9d
commit 9ba302a5f4
9 changed files with 315 additions and 415 deletions

View File

@ -26,6 +26,7 @@ from imaginairy.img_log import (
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
from imaginairy.safety import is_nsfw from imaginairy.safety import is_nsfw
from imaginairy.samplers.base import get_sampler from imaginairy.samplers.base import get_sampler
from imaginairy.samplers.plms import PLMSSchedule
from imaginairy.schema import ImaginePrompt, ImagineResult from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import ( from imaginairy.utils import (
fix_torch_group_norm, fix_torch_group_norm,
@ -208,6 +209,7 @@ def imagine(
log_conditioning(c, "positive conditioning") log_conditioning(c, "positive conditioning")
shape = [ shape = [
1,
latent_channels, latent_channels,
prompt.height // downsampling_factor, prompt.height // downsampling_factor,
prompt.width // downsampling_factor, prompt.width // downsampling_factor,
@ -228,9 +230,6 @@ def imagine(
if prompt.init_image: if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength generation_strength = 1 - prompt.init_image_strength
t_enc = int(prompt.steps * generation_strength) t_enc = int(prompt.steps * generation_strength)
sampler.make_schedule(
ddim_num_steps=prompt.steps, ddim_eta=ddim_eta
)
try: try:
init_image = pillow_fit_image_within( init_image = pillow_fit_image_within(
prompt.init_image, prompt.init_image,
@ -284,24 +283,35 @@ def imagine(
# encode (scaled latent) # encode (scaled latent)
seed_everything(prompt.seed) seed_everything(prompt.seed)
noise = torch.randn_like(init_latent, device="cpu").to(get_device()) noise = torch.randn_like(init_latent, device="cpu").to(get_device())
schedule = PLMSSchedule(
ddpm_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps,
alphas_cumprod=model.alphas_cumprod,
alphas_cumprod_prev=model.alphas_cumprod_prev,
betas=model.betas,
ddim_discretize="uniform",
ddim_eta=0.0,
)
if generation_strength >= 1: if generation_strength >= 1:
# prompt strength gets converted to time encodings, # prompt strength gets converted to time encodings,
# which means you can't get to true 0 without this hack # which means you can't get to true 0 without this hack
# (or setting steps=1000) # (or setting steps=1000)
z_enc = noise z_enc = noise
else: else:
z_enc = sampler.stochastic_encode( z_enc = sampler.noise_an_image(
init_latent, init_latent,
torch.tensor([t_enc - 1]).to(get_device()), torch.tensor([t_enc - 1]).to(get_device()),
schedule=schedule,
noise=noise, noise=noise,
) )
log_latent(z_enc, "z_enc") log_latent(z_enc, "z_enc")
# decode it # decode it
samples = sampler.decode( samples = sampler.decode(
x_latent=z_enc, initial_latent=z_enc,
cond=c, cond=c,
t_start=t_enc, t_start=t_enc,
schedule=schedule,
unconditional_guidance_scale=prompt.prompt_strength, unconditional_guidance_scale=prompt.prompt_strength,
unconditional_conditioning=uc, unconditional_conditioning=uc,
img_callback=_img_callback, img_callback=_img_callback,

View File

@ -56,7 +56,7 @@ def get_img_mask(
mask[mask >= 0.5] = 1 mask[mask >= 0.5] = 1
log_img(mask, f"mask threshold {0.5}") log_img(mask, f"mask threshold {0.5}")
mask_np = mask.cpu().numpy() mask_np = mask.to(torch.float32).cpu().numpy()
smoother_strength = 2 smoother_strength = 2
# grow the mask area to make sure we've masked the thing we care about # grow the mask area to make sure we've masked the thing we care about
for _ in range(smoother_strength): for _ in range(smoother_strength):

View File

@ -2,7 +2,7 @@
import torch import torch
from torch import nn from torch import nn
from imaginairy.utils import get_device from imaginairy.img_log import log_latent
SAMPLER_TYPE_OPTIONS = [ SAMPLER_TYPE_OPTIONS = [
"plms", "plms",
@ -51,11 +51,19 @@ class CFGDenoiser(nn.Module):
self.inner_model = model self.inner_model = model
def forward(self, x, sigma, uncond, cond, cond_scale, mask=None, orig_latent=None): def forward(self, x, sigma, uncond, cond, cond_scale, mask=None, orig_latent=None):
x_in = torch.cat([x] * 2) def _wrapper(noisy_latent_in, time_encoding_in, conditioning_in):
sigma_in = torch.cat([sigma] * 2) return self.inner_model(
cond_in = torch.cat([uncond, cond]) noisy_latent_in, time_encoding_in, cond=conditioning_in
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) )
denoised = uncond + (cond - uncond) * cond_scale
denoised = get_noise_prediction(
denoise_func=_wrapper,
noisy_latent=x,
time_encoding=sigma,
neutral_conditioning=uncond,
positive_conditioning=cond,
signal_amplification=cond_scale,
)
if mask is not None: if mask is not None:
assert orig_latent is not None assert orig_latent is not None
@ -65,51 +73,37 @@ class CFGDenoiser(nn.Module):
return denoised return denoised
class DiffusionSampler: def ensure_4_dim(t: torch.Tensor):
""" if len(t.shape) == 3:
wip t = t.unsqueeze(dim=0)
return t
hope to enforce an api upon samplers
"""
def __init__(self, noise_prediction_model, sampler_func, device=get_device()): def get_noise_prediction(
self.noise_prediction_model = noise_prediction_model denoise_func,
self.cfg_noise_prediction_model = CFGDenoiser(noise_prediction_model) noisy_latent,
self.sampler_func = sampler_func time_encoding,
self.device = device neutral_conditioning,
positive_conditioning,
signal_amplification=7.5,
):
noisy_latent = ensure_4_dim(noisy_latent)
def zzsample( noisy_latent_in = torch.cat([noisy_latent] * 2)
self, time_encoding_in = torch.cat([time_encoding] * 2)
num_steps, conditioning_in = torch.cat([neutral_conditioning, positive_conditioning])
text_conditioning,
batch_size,
shape,
unconditional_guidance_scale,
unconditional_conditioning,
eta,
initial_noise_tensor=None,
img_callback=None,
):
size = (batch_size, *shape)
initial_noise_tensor = ( pred_noise_neutral, pred_noise_positive = denoise_func(
torch.randn(size, device="cpu").to(get_device()) noisy_latent_in, time_encoding_in, conditioning_in
if initial_noise_tensor is None ).chunk(2)
else initial_noise_tensor
amplified_noise_pred = signal_amplification * (
pred_noise_positive - pred_noise_neutral
) )
sigmas = self.noise_prediction_model.get_sigmas(num_steps) pred_noise = pred_noise_neutral + amplified_noise_pred
x = initial_noise_tensor * sigmas[0]
samples = self.sampler_func( log_latent(pred_noise_neutral, "neutral noise prediction")
self.cfg_noise_prediction_model, log_latent(pred_noise_positive, "positive noise prediction")
x, log_latent(pred_noise, "noise prediction")
sigmas,
extra_args={
"cond": text_conditioning,
"uncond": unconditional_conditioning,
"cond_scale": unconditional_guidance_scale,
},
disable=False,
)
return samples, None return pred_noise

View File

@ -1,5 +1,4 @@
# pylama:ignore=W0613 # pylama:ignore=W0613
"""SAMPLING ONLY."""
import logging import logging
import numpy as np import numpy as np
@ -13,41 +12,19 @@ from imaginairy.modules.diffusion.util import (
make_ddim_timesteps, make_ddim_timesteps,
noise_like, noise_like,
) )
from imaginairy.samplers.base import get_noise_prediction
from imaginairy.utils import get_device from imaginairy.utils import get_device
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DDIMSampler: class DDIMSchedule:
""" def __init__(
Denoising Diffusion Implicit Models self,
https://arxiv.org/abs/2010.02502
"""
def __init__(self, model):
self.model = model
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0):
buffers = self._make_schedule(
model_num_timesteps=self.model.num_timesteps,
model_alphas_cumprod=self.model.alphas_cumprod,
model_betas=self.model.betas,
model_alphas_cumprod_prev=self.model.alphas_cumprod_prev,
ddim_num_steps=ddim_num_steps,
ddim_discretize=ddim_discretize,
ddim_eta=ddim_eta,
device=self.model.device,
)
for k, v in buffers.items():
setattr(self, k, v)
@staticmethod
def _make_schedule(
model_num_timesteps, model_num_timesteps,
model_alphas_cumprod, model_alphas_cumprod,
model_betas,
model_alphas_cumprod_prev, model_alphas_cumprod_prev,
model_betas,
ddim_num_steps, ddim_num_steps,
ddim_discretize="uniform", ddim_discretize="uniform",
ddim_eta=0.0, ddim_eta=0.0,
@ -71,41 +48,37 @@ class DDIMSampler:
ddim_timesteps=ddim_timesteps, ddim_timesteps=ddim_timesteps,
eta=ddim_eta, eta=ddim_eta,
) )
self.ddim_timesteps = ddim_timesteps
buffers = { self.betas = to_torch(model_betas)
"ddim_timesteps": ddim_timesteps, self.alphas_cumprod = to_torch(alphas_cumprod)
"betas": to_torch(model_betas), self.alphas_cumprod_prev = to_torch(model_alphas_cumprod_prev)
"alphas_cumprod": to_torch(alphas_cumprod),
"alphas_cumprod_prev": to_torch(model_alphas_cumprod_prev),
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
"sqrt_alphas_cumprod": to_torch(np.sqrt(alphas_cumprod.cpu())), self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
"sqrt_one_minus_alphas_cumprod": to_torch( self.sqrt_one_minus_alphas_cumprod = to_torch(
np.sqrt(1.0 - alphas_cumprod.cpu()) np.sqrt(1.0 - alphas_cumprod.cpu())
),
"log_one_minus_alphas_cumprod": to_torch(
np.log(1.0 - alphas_cumprod.cpu())
),
"sqrt_recip_alphas_cumprod": to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())),
"sqrt_recipm1_alphas_cumprod": to_torch(
np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
),
"ddim_sigmas": ddim_sigmas.to(torch.float32).to(device),
"ddim_alphas": ddim_alphas.to(torch.float32).to(device),
"ddim_alphas_prev": ddim_alphas_prev,
"ddim_sqrt_one_minus_alphas": np.sqrt(1.0 - ddim_alphas)
.to(torch.float32)
.to(device),
}
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - buffers["alphas_cumprod_prev"])
/ (1 - buffers["alphas_cumprod"])
* (1 - buffers["alphas_cumprod"] / buffers["alphas_cumprod_prev"])
) )
buffers[ self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
"ddim_sigmas_for_original_num_steps" self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
] = sigmas_for_original_sampling_steps self.sqrt_recipm1_alphas_cumprod = to_torch(
return buffers np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
)
self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device)
self.ddim_alphas = ddim_alphas.to(torch.float32).to(device)
self.ddim_alphas_prev = ddim_alphas_prev
self.ddim_sqrt_one_minus_alphas = (
np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(device)
)
class DDIMSampler:
"""
Denoising Diffusion Implicit Models
https://arxiv.org/abs/2010.02502
"""
def __init__(self, model):
self.model = model
@torch.no_grad() @torch.no_grad()
def sample( def sample(
@ -123,31 +96,30 @@ class DDIMSampler:
x0=None, x0=None,
temperature=1.0, temperature=1.0,
noise_dropout=0.0, noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
x_T=None, x_T=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
**kwargs, **kwargs,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
): ):
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
logger.warning(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size: if conditioning.shape[0] != batch_size:
logger.warning( logger.warning(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
) )
schedule = DDIMSchedule(
self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta) model_num_timesteps=self.model.num_timesteps,
model_alphas_cumprod=self.model.alphas_cumprod,
model_alphas_cumprod_prev=self.model.alphas_cumprod_prev,
model_betas=self.model.betas,
ddim_num_steps=num_steps,
ddim_discretize="uniform",
ddim_eta=0.0,
)
samples = self.ddim_sampling( samples = self.ddim_sampling(
conditioning, conditioning,
shape=(batch_size, *shape), shape=shape,
schedule=schedule,
callback=callback, callback=callback,
img_callback=img_callback, img_callback=img_callback,
quantize_denoised=quantize_x0, quantize_denoised=quantize_x0,
@ -155,8 +127,6 @@ class DDIMSampler:
x0=x0, x0=x0,
noise_dropout=noise_dropout, noise_dropout=noise_dropout,
temperature=temperature, temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T, x_T=x_T,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
@ -168,6 +138,7 @@ class DDIMSampler:
self, self,
cond, cond,
shape, shape,
schedule,
x_T=None, x_T=None,
callback=None, callback=None,
timesteps=None, timesteps=None,
@ -177,8 +148,6 @@ class DDIMSampler:
img_callback=None, img_callback=None,
temperature=1.0, temperature=1.0,
noise_dropout=0.0, noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
): ):
@ -192,16 +161,16 @@ class DDIMSampler:
log_latent(img, "initial noise") log_latent(img, "initial noise")
if timesteps is None: if timesteps is None:
timesteps = self.ddim_timesteps timesteps = schedule.ddim_timesteps
else: else:
subset_end = ( subset_end = (
int( int(
min(timesteps / self.ddim_timesteps.shape[0], 1) min(timesteps / schedule.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0] * schedule.ddim_timesteps.shape[0]
) )
- 1 - 1
) )
timesteps = self.ddim_timesteps[:subset_end] timesteps = schedule.ddim_timesteps[:subset_end]
time_range = np.flip(timesteps) time_range = np.flip(timesteps)
total_steps = timesteps.shape[0] total_steps = timesteps.shape[0]
@ -225,6 +194,7 @@ class DDIMSampler:
cond, cond,
ts, ts,
index=index, index=index,
schedule=schedule,
quantize_denoised=quantize_denoised, quantize_denoised=quantize_denoised,
temperature=temperature, temperature=temperature,
noise_dropout=noise_dropout, noise_dropout=noise_dropout,
@ -245,6 +215,7 @@ class DDIMSampler:
c, c,
t, t,
index, index,
schedule,
repeat_noise=False, repeat_noise=False,
quantize_denoised=False, quantize_denoised=False,
temperature=1.0, temperature=1.0,
@ -254,26 +225,26 @@ class DDIMSampler:
loss_function=None, loss_function=None,
): ):
assert unconditional_guidance_scale >= 1 assert unconditional_guidance_scale >= 1
x_in = torch.cat([x] * 2) noise_pred = get_noise_prediction(
t_in = torch.cat([t] * 2) denoise_func=self.model.apply_model,
c_in = torch.cat([unconditional_conditioning, c]) noisy_latent=x,
# with torch.no_grad(): time_encoding=t,
noise_pred_uncond, noise_pred = self.model.apply_model(x_in, t_in, c_in).chunk( neutral_conditioning=unconditional_conditioning,
2 positive_conditioning=c,
) signal_amplification=unconditional_guidance_scale,
noise_pred = noise_pred_uncond + unconditional_guidance_scale * (
noise_pred - noise_pred_uncond
) )
b = x.shape[0] b = x.shape[0]
log_latent(noise_pred, "noise prediction") log_latent(noise_pred, "noise prediction")
# select parameters corresponding to the currently considered timestep # select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), self.ddim_alphas[index], device=x.device) a_t = torch.full((b, 1, 1, 1), schedule.ddim_alphas[index], device=x.device)
a_prev = torch.full((b, 1, 1, 1), self.ddim_alphas_prev[index], device=x.device) a_prev = torch.full(
sigma_t = torch.full((b, 1, 1, 1), self.ddim_sigmas[index], device=x.device) (b, 1, 1, 1), schedule.ddim_alphas_prev[index], device=x.device
)
sigma_t = torch.full((b, 1, 1, 1), schedule.ddim_sigmas[index], device=x.device)
sqrt_one_minus_at = torch.full( sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), self.ddim_sqrt_one_minus_alphas[index], device=x.device (b, 1, 1, 1), schedule.ddim_sqrt_one_minus_alphas[index], device=x.device
) )
return self._p_sample_ddim_formula( return self._p_sample_ddim_formula(
x, x,
@ -310,12 +281,11 @@ class DDIMSampler:
return x_prev, pred_x0 return x_prev, pred_x0
@torch.no_grad() @torch.no_grad()
def stochastic_encode(self, init_latent, t, noise=None): def noise_an_image(self, init_latent, t, schedule, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas # t serves as an index to gather the correct alphas
t = t.clamp(0, 1000) t = t.clamp(0, 1000)
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas
if noise is None: if noise is None:
noise = torch.randn_like(init_latent, device="cpu").to(get_device()) noise = torch.randn_like(init_latent, device="cpu").to(get_device())
@ -328,31 +298,34 @@ class DDIMSampler:
@torch.no_grad() @torch.no_grad()
def decode( def decode(
self, self,
x_latent, initial_latent,
cond, cond,
t_start, t_start,
schedule,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
img_callback=None, img_callback=None,
score_corrector=None,
temperature=1.0, temperature=1.0,
mask=None, mask=None,
orig_latent=None, orig_latent=None,
): ):
timesteps = self.ddim_timesteps[:t_start] timesteps = schedule.ddim_timesteps[:t_start]
time_range = np.flip(timesteps) time_range = np.flip(timesteps)
total_steps = timesteps.shape[0] total_steps = timesteps.shape[0]
logger.debug(f"Running DDIM Sampling with {total_steps} timesteps") logger.debug(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc="Decoding image", total=total_steps) iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
x_dec = x_latent x_dec = initial_latent
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full( ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long (initial_latent.shape[0],),
step,
device=initial_latent.device,
dtype=torch.long,
) )
if mask is not None: if mask is not None:
@ -374,17 +347,12 @@ class DDIMSampler:
x_dec, x_dec,
cond, cond,
ts, ts,
schedule=schedule,
index=index, index=index,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
temperature=temperature, temperature=temperature,
) )
# original_loss = ((x_dec - x_latent).abs().mean()*70)
# sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())
# x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
# cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
# x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
# x_dec_alt = x_dec + (original_loss * 0.1) ** 2
log_latent(x_dec, f"x_dec {i}") log_latent(x_dec, f"x_dec {i}")
log_latent(pred_x0, f"pred_x0 {i}") log_latent(pred_x0, f"pred_x0 {i}")

View File

@ -8,11 +8,15 @@ from imaginairy.vendored.k_diffusion import sampling as k_sampling
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser from imaginairy.vendored.k_diffusion.external import CompVisDenoiser
class StandardCompVisDenoiser(CompVisDenoiser):
def apply_model(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
class KDiffusionSampler: class KDiffusionSampler:
def __init__(self, model, sampler_name): def __init__(self, model, sampler_name):
self.model = model self.model = model
self.cv_denoiser = CompVisDenoiser(model) self.cv_denoiser = StandardCompVisDenoiser(model)
# self.cfg_denoiser = CompVisDenoiser(self.cv_denoiser)
self.sampler_name = sampler_name self.sampler_name = sampler_name
self.sampler_func = getattr(k_sampling, f"sample_{sampler_name}") self.sampler_func = getattr(k_sampling, f"sample_{sampler_name}")
@ -28,10 +32,8 @@ class KDiffusionSampler:
initial_noise_tensor=None, initial_noise_tensor=None,
img_callback=None, img_callback=None,
): ):
size = (batch_size, *shape)
initial_noise_tensor = ( initial_noise_tensor = (
torch.randn(size, device="cpu").to(get_device()) torch.randn(shape, device="cpu").to(get_device())
if initial_noise_tensor is None if initial_noise_tensor is None
else initial_noise_tensor else initial_noise_tensor
) )

View File

@ -1,5 +1,4 @@
# pylama:ignore=W0613 # pylama:ignore=W0613
"""SAMPLING ONLY."""
import logging import logging
import numpy as np import numpy as np
@ -13,65 +12,52 @@ from imaginairy.modules.diffusion.util import (
make_ddim_timesteps, make_ddim_timesteps,
noise_like, noise_like,
) )
from imaginairy.samplers.base import get_noise_prediction
from imaginairy.utils import get_device from imaginairy.utils import get_device
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PLMSSampler: class PLMSSchedule:
"""probabilistic least-mean-squares""" def __init__(
self,
def __init__(self, model): ddpm_num_timesteps, # 1000?
self.model = model ddim_num_steps, # prompt.steps?
self.ddpm_num_timesteps = model.num_timesteps alphas_cumprod,
self.device_available = get_device() alphas_cumprod_prev,
self.ddim_timesteps = None betas,
ddim_discretize="uniform",
def register_buffer(self, name, attr): ddim_eta=0.0,
if isinstance(attr, torch.Tensor): ):
if attr.device != torch.device(self.device_available):
attr = attr.to(torch.float32).to(torch.device(self.device_available))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0):
if ddim_eta != 0: if ddim_eta != 0:
raise ValueError("ddim_eta must be 0 for PLMS") raise ValueError("ddim_eta must be 0 for PLMS")
self.ddim_timesteps = make_ddim_timesteps( device = get_device()
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
)
alphas_cumprod = self.model.alphas_cumprod
assert ( assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps alphas_cumprod.shape[0] == ddpm_num_timesteps
), "alphas have to be defined for each timestep" ), "alphas have to be defined for each timestep"
def to_torch(x): def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device) return x.clone().detach().to(torch.float32).to(device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
self.betas = to_torch(betas)
self.alphas_cumprod = to_torch(alphas_cumprod)
self.alphas_cumprod_prev = to_torch(alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer( self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) self.sqrt_one_minus_alphas_cumprod = to_torch(
np.sqrt(1.0 - alphas_cumprod.cpu())
) )
self.register_buffer( self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
"sqrt_one_minus_alphas_cumprod", self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), self.sqrt_recipm1_alphas_cumprod = to_torch(
np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
) )
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) self.ddim_timesteps = make_ddim_timesteps(
) ddim_discr_method=ddim_discretize,
self.register_buffer( num_ddim_timesteps=ddim_num_steps,
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) num_ddpm_timesteps=ddpm_num_timesteps,
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
) )
# ddim sampling parameters # ddim sampling parameters
@ -80,19 +66,21 @@ class PLMSSampler:
ddim_timesteps=self.ddim_timesteps, ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta, eta=ddim_eta,
) )
self.register_buffer("ddim_sigmas", ddim_sigmas) self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
self.register_buffer("ddim_alphas", ddim_alphas) self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) self.ddim_alphas_prev = ddim_alphas_prev
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) self.ddim_sqrt_one_minus_alphas = (
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(torch.device(device))
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
) )
class PLMSSampler:
"""probabilistic least-mean-squares"""
def __init__(self, model):
self.model = model
self.device = get_device()
@torch.no_grad() @torch.no_grad()
def sample( def sample(
self, self,
@ -101,145 +89,90 @@ class PLMSSampler:
shape, shape,
conditioning=None, conditioning=None,
callback=None, callback=None,
normals_sequence=None,
img_callback=None, img_callback=None,
quantize_x0=False, quantize_x0=False,
eta=0.0, eta=0.0,
mask=None, mask=None,
x0=None, orig_latent=None,
temperature=1.0, temperature=1.0,
noise_dropout=0.0, noise_dropout=0.0,
score_corrector=None, initial_latent=None,
corrector_kwargs=None,
x_T=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
timesteps=None,
quantize_denoised=False,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs, **kwargs,
): ):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
logger.warning(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size: if conditioning.shape[0] != batch_size:
logger.warning( logger.warning(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
) )
self.make_schedule(ddim_num_steps=num_steps, ddim_eta=eta) schedule = PLMSSchedule(
ddpm_num_timesteps=self.model.num_timesteps,
samples = self.plms_sampling( ddim_num_steps=num_steps,
conditioning, alphas_cumprod=self.model.alphas_cumprod,
(batch_size, *shape), alphas_cumprod_prev=self.model.alphas_cumprod_prev,
callback=callback, betas=self.model.betas,
img_callback=img_callback, ddim_discretize="uniform",
quantize_denoised=quantize_x0, ddim_eta=0.0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
) )
return samples device = self.device
# batch_size = shape[0]
@torch.no_grad() if initial_latent is None:
def plms_sampling( initial_latent = torch.randn(shape, device="cpu").to(device)
self, log_latent(initial_latent, "initial latent")
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device="cpu").to(device)
else:
img = x_T
log_latent(img, "initial img")
if timesteps is None: if timesteps is None:
timesteps = ( timesteps = schedule.ddim_timesteps
self.ddpm_num_timesteps elif timesteps is not None:
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = ( subset_end = (
int( int(
min(timesteps / self.ddim_timesteps.shape[0], 1) min(timesteps / schedule.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0] * schedule.ddim_timesteps.shape[0]
) )
- 1 - 1
) )
timesteps = self.ddim_timesteps[:subset_end] timesteps = schedule.ddim_timesteps[:subset_end]
time_range = ( time_range = np.flip(timesteps)
list(reversed(range(0, timesteps))) total_steps = timesteps.shape[0]
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
logger.debug(f"Running PLMS Sampling with {total_steps} timesteps") logger.debug(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc=" PLMS Sampler", total=total_steps) iterator = tqdm(time_range, desc=" PLMS Sampler", total=total_steps)
old_eps = [] old_eps = []
img = initial_latent
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long) ts = torch.full((batch_size,), step, device=device, dtype=torch.long)
ts_next = torch.full( ts_next = torch.full(
(b,), (batch_size,),
time_range[min(i + 1, len(time_range) - 1)], time_range[min(i + 1, len(time_range) - 1)],
device=device, device=device,
dtype=torch.long, dtype=torch.long,
) )
if mask is not None: if mask is not None:
assert x0 is not None assert orig_latent is not None
img_orig = self.model.q_sample( img_orig = self.model.q_sample(orig_latent, ts)
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img img = img_orig * mask + (1.0 - mask) * img
img, pred_x0, e_t = self.p_sample_plms( img, pred_x0, noise_prediction = self.p_sample_plms(
img, img,
cond, conditioning,
ts, ts,
schedule=schedule,
index=index, index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, quantize_denoised=quantize_denoised,
temperature=temperature, temperature=temperature,
noise_dropout=noise_dropout, noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, old_eps=old_eps,
t_next=ts_next, t_next=ts_next,
) )
old_eps.append(e_t) old_eps.append(noise_prediction)
if len(old_eps) >= 4: if len(old_eps) >= 4:
old_eps.pop(0) old_eps.pop(0)
if callback: if callback:
@ -253,119 +186,108 @@ class PLMSSampler:
@torch.no_grad() @torch.no_grad()
def p_sample_plms( def p_sample_plms(
self, self,
x, noisy_latent,
c, positive_conditioning,
t, time_encoding,
schedule: PLMSSchedule,
index, index,
repeat_noise=False, repeat_noise=False,
use_original_steps=False,
quantize_denoised=False, quantize_denoised=False,
temperature=1.0, temperature=1.0,
noise_dropout=0.0, noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
old_eps=None, old_eps=None,
t_next=None, t_next=None,
): ):
b, *_, device = *x.shape, x.device batch_size = noisy_latent.shape[0]
noise_prediction = get_noise_prediction(
def get_model_output(x, t): denoise_func=self.model.apply_model,
if ( noisy_latent=noisy_latent,
unconditional_conditioning is None time_encoding=time_encoding,
or unconditional_guidance_scale == 1.0 neutral_conditioning=unconditional_conditioning,
): positive_conditioning=positive_conditioning,
e_t = self.model.apply_model(x, t, c) signal_amplification=unconditional_guidance_scale,
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
log_latent(e_t_uncond, "noise pred uncond")
log_latent(e_t, "noise pred cond")
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
log_latent(e_t, "noise pred combined")
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
) )
def get_x_prev_and_pred_x0(e_t, index): def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep # select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) alpha_at_t = torch.full(
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) (batch_size, 1, 1, 1), schedule.ddim_alphas[index], device=self.device
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) )
alpha_prev_at_t = torch.full(
(batch_size, 1, 1, 1),
schedule.ddim_alphas_prev[index],
device=self.device,
)
sigma_t = torch.full(
(batch_size, 1, 1, 1), schedule.ddim_sigmas[index], device=self.device
)
sqrt_one_minus_at = torch.full( sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device (batch_size, 1, 1, 1),
schedule.ddim_sqrt_one_minus_alphas[index],
device=self.device,
) )
# current prediction for x_0 # current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (noisy_latent - sqrt_one_minus_at * e_t) / alpha_at_t.sqrt()
if quantize_denoised: if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t # direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t dir_xt = (1.0 - alpha_prev_at_t - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature noise = (
sigma_t
* noise_like(noisy_latent.shape, self.device, repeat_noise)
* temperature
)
if noise_dropout > 0.0: if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout) noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise x_prev = alpha_prev_at_t.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0 return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0: if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order) # Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) x_prev, pred_x0 = get_x_prev_and_pred_x0(noise_prediction, index)
e_t_next = get_model_output(x_prev, t_next) e_t_next = get_noise_prediction(
e_t_prime = (e_t + e_t_next) / 2 denoise_func=self.model.apply_model,
noisy_latent=x_prev,
time_encoding=t_next,
neutral_conditioning=unconditional_conditioning,
positive_conditioning=positive_conditioning,
signal_amplification=unconditional_guidance_scale,
)
e_t_prime = (noise_prediction + e_t_next) / 2
elif len(old_eps) == 1: elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth) # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2 e_t_prime = (3 * noise_prediction - old_eps[-1]) / 2
elif len(old_eps) == 2: elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth) # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 e_t_prime = (
23 * noise_prediction - 16 * old_eps[-1] + 5 * old_eps[-2]
) / 12
elif len(old_eps) >= 3: elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth) # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = ( e_t_prime = (
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] 55 * noise_prediction
- 59 * old_eps[-1]
+ 37 * old_eps[-2]
- 9 * old_eps[-3]
) / 24 ) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
log_latent(x_prev, "x_prev") log_latent(x_prev, "x_prev")
log_latent(pred_x0, "pred_x0") log_latent(pred_x0, "pred_x0")
return x_prev, pred_x0, e_t return x_prev, pred_x0, noise_prediction
@torch.no_grad() @torch.no_grad()
def stochastic_encode(self, init_latent, t, noise=None): def noise_an_image(self, init_latent, t, schedule, noise=None):
# replace with ddpm.q_sample?
# fast, but does not allow for exact reconstruction # fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas # t serves as an index to gather the correct alphas
t = t.clamp(0, 1000) t = t.clamp(0, 1000)
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas
if noise is None: if noise is None:
noise = torch.randn_like(init_latent, device="cpu").to(get_device()) noise = torch.randn_like(init_latent, device="cpu").to(get_device())
@ -378,26 +300,26 @@ class PLMSSampler:
@torch.no_grad() @torch.no_grad()
def decode( def decode(
self, self,
x_latent,
cond, cond,
t_start, schedule,
initial_latent=None,
t_start=None,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
img_callback=None, img_callback=None,
score_corrector=None,
temperature=1.0, temperature=1.0,
mask=None, mask=None,
orig_latent=None, orig_latent=None,
noise=None, noise=None,
): ):
device = self.device
timesteps = self.ddim_timesteps[:t_start] timesteps = schedule.ddim_timesteps[:t_start]
time_range = np.flip(timesteps) time_range = np.flip(timesteps)
total_steps = timesteps.shape[0] total_steps = timesteps.shape[0]
iterator = tqdm(time_range, desc="PLMS altering image", total=total_steps) iterator = tqdm(time_range, desc="PLMS img2img", total=total_steps)
x_dec = x_latent x_dec = initial_latent
old_eps = [] old_eps = []
log_latent(x_dec, "x_dec") log_latent(x_dec, "x_dec")
@ -411,12 +333,15 @@ class PLMSSampler:
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full( ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long (initial_latent.shape[0],),
step,
device=initial_latent.device,
dtype=torch.long,
) )
ts_next = torch.full( ts_next = torch.full(
(x_latent.shape[0],), (initial_latent.shape[0],),
time_range[min(i + 1, len(time_range) - 1)], time_range[min(i + 1, len(time_range) - 1)],
device=x_latent.device, device=device,
dtype=torch.long, dtype=torch.long,
) )
@ -435,10 +360,11 @@ class PLMSSampler:
x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
log_latent(x_dec, f"x_dec {ts}") log_latent(x_dec, f"x_dec {ts}")
x_dec, pred_x0, e_t = self.p_sample_plms( x_dec, pred_x0, noise_prediction = self.p_sample_plms(
x_dec, x_dec,
cond, cond,
ts, ts,
schedule=schedule,
index=index, index=index,
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning, unconditional_conditioning=unconditional_conditioning,
@ -446,14 +372,8 @@ class PLMSSampler:
old_eps=old_eps, old_eps=old_eps,
t_next=ts_next, t_next=ts_next,
) )
# original_loss = ((x_dec - x_latent).abs().mean()*70)
# sigma_t = torch.full((1, 1, 1, 1), self.ddim_sigmas[index], device=get_device())
# x_dec = x_dec.detach() + (original_loss * 0.1) ** 2
# cond_grad = -torch.autograd.grad(original_loss, x_dec)[0]
# x_dec = x_dec.detach() + cond_grad * sigma_t ** 2
# x_dec_alt = x_dec + (original_loss * 0.1) ** 2
old_eps.append(e_t) old_eps.append(noise_prediction)
if len(old_eps) >= 4: if len(old_eps) >= 4:
old_eps.pop(0) old_eps.pop(0)

View File

@ -47,6 +47,11 @@ def pre_setup():
yield yield
@pytest.fixture(autouse=True)
def reset_get_device():
get_device.cache_clear()
@pytest.fixture() @pytest.fixture()
def filename_base_for_outputs(request): def filename_base_for_outputs(request):
filename_base = f"{TESTS_FOLDER}/test_output/{request.node.name}_{get_device()}_" filename_base = f"{TESTS_FOLDER}/test_output/{request.node.name}_{get_device()}_"

View File

@ -54,7 +54,6 @@ def experiment_step_repeats():
embedder.to(get_device()) embedder.to(get_device())
sampler = DDIMSampler(model) sampler = DDIMSampler(model)
sampler.make_schedule(1000)
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg") img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
init_image, _, _ = pillow_img_to_torch_image( init_image, _, _ = pillow_img_to_torch_image(
@ -89,7 +88,9 @@ def experiment_step_repeats():
# noise_pred = model.apply_model(init_latent, t, neutral_embedding) # noise_pred = model.apply_model(init_latent, t, neutral_embedding)
# log_latent(noise_pred, "noise prediction") # log_latent(noise_pred, "noise prediction")
for _ in range(100): for _ in range(100):
x_prev, pred_x0 = sampler.p_sample_ddim(x_prev, neutral_embedding, t, index) x_prev, pred_x0 = sampler.p_sample_ddim( # noqa
x_prev, neutral_embedding, t, index
)
log_latent(pred_x0, "pred_x0") log_latent(pred_x0, "pred_x0")
x_prev = pred_x0 x_prev = pred_x0

View File

@ -12,7 +12,6 @@ from imaginairy.utils import (
get_hardware_description, get_hardware_description,
get_obj_from_str, get_obj_from_str,
instantiate_from_config, instantiate_from_config,
platform_appropriate_autocast,
) )
@ -79,6 +78,7 @@ def test_instantiate_from_config():
instantiate_from_config(config) instantiate_from_config(config)
def test_platform_appropriate_autocast(): #
with platform_appropriate_autocast("autocast"): # def test_platform_appropriate_autocast():
pass # with platform_appropriate_autocast("autocast"):
# pass