From 8d4b5cb9e1884a95d3a6f54352d61425d36a4181 Mon Sep 17 00:00:00 2001 From: Bryce Date: Wed, 12 Oct 2022 22:32:17 -0700 Subject: [PATCH] refactor: standardize samplers more --- imaginairy/api.py | 20 +- imaginairy/modules/diffusion/util.py | 1 - imaginairy/samplers/base.py | 20 +- imaginairy/samplers/ddim.py | 269 +++++++++++---------------- imaginairy/samplers/kdiff.py | 46 +++-- imaginairy/samplers/plms.py | 177 ++++++++---------- 6 files changed, 233 insertions(+), 300 deletions(-) diff --git a/imaginairy/api.py b/imaginairy/api.py index dd9274a..f32f230 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -170,7 +170,7 @@ def imagine( # torch.set_default_tensor_type(torch.HalfTensor) prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts - _img_callback = None + if get_device() == "cpu": logger.info("Running in CPU mode. it's gonna be slooooooow.") @@ -279,9 +279,9 @@ def imagine( noise = torch.randn_like(init_latent, device="cpu").to(get_device()) # todo: this isn't the right scheduler for everything... schedule = PLMSSchedule( - ddpm_num_timesteps=model.num_timesteps, + model_num_timesteps=model.num_timesteps, ddim_num_steps=prompt.steps, - alphas_cumprod=model.alphas_cumprod, + model_alphas_cumprod=model.alphas_cumprod, ddim_discretize="uniform", ) if generation_strength >= 1: @@ -301,12 +301,11 @@ def imagine( # decode it samples = sampler.decode( initial_latent=z_enc, - cond=c, + positive_conditioning=c, t_start=t_enc, schedule=schedule, - unconditional_guidance_scale=prompt.prompt_strength, - unconditional_conditioning=uc, - img_callback=_img_callback, + guidance_scale=prompt.prompt_strength, + neutral_conditioning=uc, mask=mask, orig_latent=init_latent, ) @@ -314,12 +313,11 @@ def imagine( samples = sampler.sample( num_steps=prompt.steps, - conditioning=c, + positive_conditioning=c, batch_size=1, shape=shape, - unconditional_guidance_scale=prompt.prompt_strength, - unconditional_conditioning=uc, - img_callback=_img_callback, + guidance_scale=prompt.prompt_strength, + neutral_conditioning=uc, ) x_samples = model.decode_first_stage(samples) diff --git a/imaginairy/modules/diffusion/util.py b/imaginairy/modules/diffusion/util.py index ef061f4..b76b978 100644 --- a/imaginairy/modules/diffusion/util.py +++ b/imaginairy/modules/diffusion/util.py @@ -87,7 +87,6 @@ def make_ddim_timesteps( # assert ddim_timesteps.shape[0] == num_ddim_timesteps # add one to get the final alpha values right (the ones from first scale to data during sampling) steps_out = ddim_timesteps + 1 - logger.debug(f"Selected timesteps for ddim sampler: {steps_out}") return steps_out diff --git a/imaginairy/samplers/base.py b/imaginairy/samplers/base.py index 81015aa..b34957d 100644 --- a/imaginairy/samplers/base.py +++ b/imaginairy/samplers/base.py @@ -56,7 +56,7 @@ class CFGDenoiser(nn.Module): noisy_latent_in, time_encoding_in, cond=conditioning_in ) - denoised = get_noise_prediction( + noise_pred = get_noise_prediction( denoise_func=_wrapper, noisy_latent=x, time_encoding=sigma, @@ -68,9 +68,9 @@ class CFGDenoiser(nn.Module): if mask is not None: assert orig_latent is not None mask_inv = 1.0 - mask - denoised = (orig_latent * mask_inv) + (mask * denoised) + noise_pred = (orig_latent * mask_inv) + (mask * noise_pred) - return denoised + return noise_pred def ensure_4_dim(t: torch.Tensor): @@ -93,17 +93,17 @@ def get_noise_prediction( time_encoding_in = torch.cat([time_encoding] * 2) conditioning_in = torch.cat([neutral_conditioning, positive_conditioning]) - pred_noise_neutral, pred_noise_positive = denoise_func( + noise_pred_neutral, noise_pred_positive = denoise_func( noisy_latent_in, time_encoding_in, conditioning_in ).chunk(2) amplified_noise_pred = signal_amplification * ( - pred_noise_positive - pred_noise_neutral + noise_pred_positive - noise_pred_neutral ) - pred_noise = pred_noise_neutral + amplified_noise_pred + noise_pred = noise_pred_neutral + amplified_noise_pred - log_latent(pred_noise_neutral, "neutral noise prediction") - log_latent(pred_noise_positive, "positive noise prediction") - log_latent(pred_noise, "noise prediction") + log_latent(noise_pred_neutral, "noise_pred_neutral") + log_latent(noise_pred_positive, "noise_pred_positive") + log_latent(noise_pred, "noise_pred") - return pred_noise + return noise_pred diff --git a/imaginairy/samplers/ddim.py b/imaginairy/samplers/ddim.py index 4c1782a..5e32b6c 100644 --- a/imaginairy/samplers/ddim.py +++ b/imaginairy/samplers/ddim.py @@ -18,6 +18,10 @@ from imaginairy.utils import get_device logger = logging.getLogger(__name__) +def to_torch(x): + return x.clone().detach().to(torch.float32).to(get_device()) + + class DDIMSchedule: def __init__( self, @@ -26,32 +30,31 @@ class DDIMSchedule: ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, - device=get_device(), ): + device = get_device() + if not model_alphas_cumprod.shape[0] == model_num_timesteps: + raise ValueError("alphas have to be defined for each timestep") + + self.alphas_cumprod = to_torch(model_alphas_cumprod) + ddim_timesteps = make_ddim_timesteps( ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=model_num_timesteps, ) - alphas_cumprod = model_alphas_cumprod - if not alphas_cumprod.shape[0] == model_num_timesteps: - raise ValueError("alphas have to be defined for each timestep") - - def to_torch(x): - return x.clone().detach().to(torch.float32).to(device) # ddim sampling parameters ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), + alphacums=model_alphas_cumprod.cpu(), ddim_timesteps=ddim_timesteps, eta=ddim_eta, ) self.ddim_timesteps = ddim_timesteps - self.alphas_cumprod = to_torch(alphas_cumprod) + # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu())) + self.sqrt_alphas_cumprod = to_torch(np.sqrt(model_alphas_cumprod.cpu())) self.sqrt_one_minus_alphas_cumprod = to_torch( - np.sqrt(1.0 - alphas_cumprod.cpu()) + np.sqrt(1.0 - model_alphas_cumprod.cpu()) ) self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device) self.ddim_alphas = ddim_alphas.to(torch.float32).to(device) @@ -70,186 +73,138 @@ class DDIMSampler: def __init__(self, model): self.model = model + self.device = get_device() @torch.no_grad() def sample( self, num_steps, - batch_size, shape, - conditioning, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0.0, + neutral_conditioning, + positive_conditioning, + guidance_scale=1.0, + batch_size=1, mask=None, - x0=None, + orig_latent=None, temperature=1.0, noise_dropout=0.0, - x_T=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - **kwargs, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + initial_latent=None, + quantize_x0=False, ): - if conditioning.shape[0] != batch_size: - logger.warning( - f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + if positive_conditioning.shape[0] != batch_size: + raise ValueError( + f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}" ) + schedule = DDIMSchedule( model_num_timesteps=self.model.num_timesteps, model_alphas_cumprod=self.model.alphas_cumprod, ddim_num_steps=num_steps, ddim_discretize="uniform", - ddim_eta=0.0, ) - samples = self.ddim_sampling( - conditioning, - shape=shape, - schedule=schedule, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, - x0=x0, - noise_dropout=noise_dropout, - temperature=temperature, - x_T=x_T, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return samples + if initial_latent is None: + initial_latent = torch.randn(shape, device="cpu").to(self.device) - @torch.no_grad() - def ddim_sampling( - self, - cond, - shape, - schedule, - x_T=None, - callback=None, - timesteps=None, - quantize_denoised=False, - mask=None, - x0=None, - img_callback=None, - temperature=1.0, - noise_dropout=0.0, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - ): - device = self.model.betas.device - b = shape[0] - if x_T is None: - # run on CPU for seed consistency. M1/mps runs were not consistent otherwise - img = torch.randn(shape, device="cpu").to(device) - else: - img = x_T - log_latent(img, "initial noise") - - if timesteps is None: - timesteps = schedule.ddim_timesteps - else: - subset_end = ( - int( - min(timesteps / schedule.ddim_timesteps.shape[0], 1) - * schedule.ddim_timesteps.shape[0] - ) - - 1 - ) - timesteps = schedule.ddim_timesteps[:subset_end] + log_latent(initial_latent, "initial latent") + + timesteps = schedule.ddim_timesteps time_range = np.flip(timesteps) total_steps = timesteps.shape[0] - logger.info(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + noisy_latent = initial_latent - for i, step in enumerate(iterator): + for i, step in enumerate(tqdm(time_range, total=total_steps)): index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) + ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long) if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample( - x0, ts - ) # TODO: deterministic forward pass? - img = img_orig * mask + (1.0 - mask) * img - - img, pred_x0 = self.p_sample_ddim( - img, - cond, - ts, + assert orig_latent is not None + img_orig = self.model.q_sample(orig_latent, ts) + noisy_latent = img_orig * mask + (1.0 - mask) * noisy_latent + + noisy_latent, predicted_latent = self.p_sample_ddim( + noisy_latent=noisy_latent, + neutral_conditioning=neutral_conditioning, + positive_conditioning=positive_conditioning, + guidance_scale=guidance_scale, + time_encoding=ts, index=index, schedule=schedule, - quantize_denoised=quantize_denoised, + quantize_denoised=quantize_x0, temperature=temperature, noise_dropout=noise_dropout, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, ) - if callback: - callback(i) - log_latent(img, "img") - log_latent(pred_x0, "pred_x0") + log_latent(noisy_latent, "noisy_latent") + log_latent(predicted_latent, "predicted_latent") - return img + return noisy_latent def p_sample_ddim( self, - x, - c, - t, + noisy_latent, + neutral_conditioning, + positive_conditioning, + guidance_scale, + time_encoding, index, schedule, repeat_noise=False, quantize_denoised=False, temperature=1.0, noise_dropout=0.0, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, loss_function=None, ): - assert unconditional_guidance_scale >= 1 + assert guidance_scale >= 1 noise_pred = get_noise_prediction( denoise_func=self.model.apply_model, - noisy_latent=x, - time_encoding=t, - neutral_conditioning=unconditional_conditioning, - positive_conditioning=c, - signal_amplification=unconditional_guidance_scale, + noisy_latent=noisy_latent, + time_encoding=time_encoding, + neutral_conditioning=neutral_conditioning, + positive_conditioning=positive_conditioning, + signal_amplification=guidance_scale, ) - b = x.shape[0] - log_latent(noise_pred, "noise prediction") + batch_size = noisy_latent.shape[0] # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), schedule.ddim_alphas[index], device=x.device) + a_t = torch.full( + (batch_size, 1, 1, 1), + schedule.ddim_alphas[index], + device=noisy_latent.device, + ) a_prev = torch.full( - (b, 1, 1, 1), schedule.ddim_alphas_prev[index], device=x.device + (batch_size, 1, 1, 1), + schedule.ddim_alphas_prev[index], + device=noisy_latent.device, + ) + sigma_t = torch.full( + (batch_size, 1, 1, 1), + schedule.ddim_sigmas[index], + device=noisy_latent.device, ) - sigma_t = torch.full((b, 1, 1, 1), schedule.ddim_sigmas[index], device=x.device) sqrt_one_minus_at = torch.full( - (b, 1, 1, 1), schedule.ddim_sqrt_one_minus_alphas[index], device=x.device + (batch_size, 1, 1, 1), + schedule.ddim_sqrt_one_minus_alphas[index], + device=noisy_latent.device, ) - return self._p_sample_ddim_formula( - x, - noise_pred, - sqrt_one_minus_at, - a_t, - sigma_t, - a_prev, - noise_dropout, - repeat_noise, - temperature, + noisy_latent, predicted_latent = self._p_sample_ddim_formula( + noisy_latent=noisy_latent, + noise_pred=noise_pred, + sqrt_one_minus_at=sqrt_one_minus_at, + a_t=a_t, + sigma_t=sigma_t, + a_prev=a_prev, + noise_dropout=noise_dropout, + repeat_noise=repeat_noise, + temperature=temperature, ) + return noisy_latent, predicted_latent @staticmethod def _p_sample_ddim_formula( - x, + noisy_latent, noise_pred, sqrt_one_minus_at, a_t, @@ -259,15 +214,18 @@ class DDIMSampler: repeat_noise, temperature, ): - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * noise_pred) / a_t.sqrt() + predicted_latent = (noisy_latent - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # direction pointing to x_t dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * noise_pred - noise = sigma_t * noise_like(x.shape, x.device, repeat_noise) * temperature + noise = ( + sigma_t + * noise_like(noisy_latent.shape, noisy_latent.device, repeat_noise) + * temperature + ) if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 + x_prev = a_prev.sqrt() * predicted_latent + dir_xt + noise + return x_prev, predicted_latent @torch.no_grad() def noise_an_image(self, init_latent, t, schedule, noise=None): @@ -288,12 +246,11 @@ class DDIMSampler: def decode( self, initial_latent, - cond, + neutral_conditioning, + positive_conditioning, + guidance_scale, t_start, schedule, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - img_callback=None, temperature=1.0, mask=None, orig_latent=None, @@ -303,12 +260,10 @@ class DDIMSampler: time_range = np.flip(timesteps) total_steps = timesteps.shape[0] - logger.debug(f"Running DDIM Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc="Decoding image", total=total_steps) - x_dec = initial_latent + noisy_latent = initial_latent - for i, step in enumerate(iterator): + for i, step in enumerate(tqdm(time_range, total=total_steps)): index = total_steps - i - 1 ts = torch.full( (initial_latent.shape[0],), @@ -329,20 +284,20 @@ class DDIMSampler: ) else: xdec_orig_with_hints = xdec_orig - x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec - log_latent(x_dec, "x_dec") + noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent + log_latent(noisy_latent, "noisy_latent") - x_dec, pred_x0 = self.p_sample_ddim( - x_dec, - cond, - ts, + noisy_latent, predicted_latent = self.p_sample_ddim( + noisy_latent=noisy_latent, + positive_conditioning=positive_conditioning, + time_encoding=ts, schedule=schedule, index=index, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, + guidance_scale=guidance_scale, + neutral_conditioning=neutral_conditioning, temperature=temperature, ) - log_latent(x_dec, f"x_dec {i}") - log_latent(pred_x0, f"pred_x0 {i}") - return x_dec + log_latent(noisy_latent, f"noisy_latent {i}") + log_latent(predicted_latent, f"predicted_latent {i}") + return noisy_latent diff --git a/imaginairy/samplers/kdiff.py b/imaginairy/samplers/kdiff.py index 7c0b936..33a906a 100644 --- a/imaginairy/samplers/kdiff.py +++ b/imaginairy/samplers/kdiff.py @@ -19,43 +19,49 @@ class KDiffusionSampler: self.cv_denoiser = StandardCompVisDenoiser(model) self.sampler_name = sampler_name self.sampler_func = getattr(k_sampling, f"sample_{sampler_name}") + self.device = get_device() def sample( self, num_steps, - conditioning, - batch_size, shape, - unconditional_guidance_scale, - unconditional_conditioning, - initial_noise_tensor=None, + neutral_conditioning, + positive_conditioning, + guidance_scale, + batch_size=1, + mask=None, + orig_latent=None, + initial_latent=None, img_callback=None, ): - initial_noise_tensor = ( - torch.randn(shape, device="cpu").to(get_device()) - if initial_noise_tensor is None - else initial_noise_tensor - ) - log_latent(initial_noise_tensor, "initial_noise_tensor") + if positive_conditioning.shape[0] != batch_size: + raise ValueError( + f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + if initial_latent is None: + initial_latent = torch.randn(shape, device="cpu").to(self.device) + + log_latent(initial_latent, "initial_latent") sigmas = self.cv_denoiser.get_sigmas(num_steps) - x = initial_noise_tensor * sigmas[0] + x = initial_latent * sigmas[0] log_latent(x, "initial_sigma_noised_tensor") model_wrap_cfg = CFGDenoiser(self.cv_denoiser) def callback(data): - log_latent(data["x"], "x") - log_latent(data["denoised"], "denoised") + log_latent(data["x"], "noisy_latent") + log_latent(data["denoised"], "noise_pred") samples = self.sampler_func( - model_wrap_cfg, - x, - sigmas, + model=model_wrap_cfg, + x=x, + sigmas=sigmas, extra_args={ - "cond": conditioning, - "uncond": unconditional_conditioning, - "cond_scale": unconditional_guidance_scale, + "cond": positive_conditioning, + "uncond": neutral_conditioning, + "cond_scale": guidance_scale, }, disable=False, callback=callback, diff --git a/imaginairy/samplers/plms.py b/imaginairy/samplers/plms.py index a4c5912..b71e23e 100644 --- a/imaginairy/samplers/plms.py +++ b/imaginairy/samplers/plms.py @@ -18,39 +18,38 @@ from imaginairy.utils import get_device logger = logging.getLogger(__name__) +def to_torch(x): + return x.clone().detach().to(torch.float32).to(get_device()) + + class PLMSSchedule: def __init__( self, - ddpm_num_timesteps, # 1000? + model_num_timesteps, # 1000? + model_alphas_cumprod, ddim_num_steps, # prompt.steps? - alphas_cumprod, ddim_discretize="uniform", ): device = get_device() + if model_alphas_cumprod.shape[0] != model_num_timesteps: + raise ValueError("alphas have to be defined for each timestep") - assert ( - alphas_cumprod.shape[0] == ddpm_num_timesteps - ), "alphas have to be defined for each timestep" - - def to_torch(x): - return x.clone().detach().to(torch.float32).to(device) - - self.alphas_cumprod = to_torch(alphas_cumprod) + self.alphas_cumprod = to_torch(model_alphas_cumprod) # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu())) + self.sqrt_alphas_cumprod = to_torch(np.sqrt(model_alphas_cumprod.cpu())) self.sqrt_one_minus_alphas_cumprod = to_torch( - np.sqrt(1.0 - alphas_cumprod.cpu()) + np.sqrt(1.0 - model_alphas_cumprod.cpu()) ) self.ddim_timesteps = make_ddim_timesteps( ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=ddpm_num_timesteps, + num_ddpm_timesteps=model_num_timesteps, ) # ddim sampling parameters ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( - alphacums=alphas_cumprod.cpu(), + alphacums=model_alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=0.0, ) @@ -63,7 +62,14 @@ class PLMSSchedule: class PLMSSampler: - """probabilistic least-mean-squares""" + """ + probabilistic least-mean-squares + + Provenance: + https://github.com/CompVis/latent-diffusion/commit/f0c4e092c156986e125f48c61a0edd38ba8ad059 + https://arxiv.org/abs/2202.09778 + https://github.com/luping-liu/PNDM + """ def __init__(self, model): self.model = model @@ -73,106 +79,89 @@ class PLMSSampler: def sample( self, num_steps, - batch_size, shape, - conditioning=None, - callback=None, - img_callback=None, - quantize_x0=False, - eta=0.0, + neutral_conditioning, + positive_conditioning, + guidance_scale=1.0, + batch_size=1, mask=None, orig_latent=None, temperature=1.0, noise_dropout=0.0, initial_latent=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - timesteps=None, quantize_denoised=False, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... **kwargs, ): - if conditioning.shape[0] != batch_size: - logger.warning( - f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + if positive_conditioning.shape[0] != batch_size: + raise ValueError( + f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}" ) schedule = PLMSSchedule( - ddpm_num_timesteps=self.model.num_timesteps, + model_num_timesteps=self.model.num_timesteps, ddim_num_steps=num_steps, - alphas_cumprod=self.model.alphas_cumprod, + model_alphas_cumprod=self.model.alphas_cumprod, ddim_discretize="uniform", ) - device = self.device - # batch_size = shape[0] + if initial_latent is None: - initial_latent = torch.randn(shape, device="cpu").to(device) + initial_latent = torch.randn(shape, device="cpu").to(self.device) + log_latent(initial_latent, "initial latent") - if timesteps is None: - timesteps = schedule.ddim_timesteps - elif timesteps is not None: - subset_end = ( - int( - min(timesteps / schedule.ddim_timesteps.shape[0], 1) - * schedule.ddim_timesteps.shape[0] - ) - - 1 - ) - timesteps = schedule.ddim_timesteps[:subset_end] + + timesteps = schedule.ddim_timesteps time_range = np.flip(timesteps) total_steps = timesteps.shape[0] - logger.debug(f"Running PLMS Sampling with {total_steps} timesteps") - iterator = tqdm(time_range, desc=" PLMS Sampler", total=total_steps) old_eps = [] - img = initial_latent + noisy_latent = initial_latent - for i, step in enumerate(iterator): + for i, step in enumerate(tqdm(time_range, total=total_steps)): index = total_steps - i - 1 - ts = torch.full((batch_size,), step, device=device, dtype=torch.long) + ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long) ts_next = torch.full( (batch_size,), time_range[min(i + 1, len(time_range) - 1)], - device=device, + device=self.device, dtype=torch.long, ) if mask is not None: assert orig_latent is not None img_orig = self.model.q_sample(orig_latent, ts) - img = img_orig * mask + (1.0 - mask) * img + noisy_latent = img_orig * mask + (1.0 - mask) * noisy_latent - img, pred_x0, noise_prediction = self.p_sample_plms( - img, - conditioning, - ts, + noisy_latent, predicted_latent, noise_pred = self.p_sample_plms( + noisy_latent=noisy_latent, + neutral_conditioning=neutral_conditioning, + positive_conditioning=positive_conditioning, + guidance_scale=guidance_scale, + time_encoding=ts, schedule=schedule, index=index, quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, old_eps=old_eps, t_next=ts_next, ) - old_eps.append(noise_prediction) + old_eps.append(noise_pred) if len(old_eps) >= 4: old_eps.pop(0) - if callback: - callback(i) - if img_callback: - img_callback(img, "img") - img_callback(pred_x0, "pred_x0") - return img + log_latent(noisy_latent, "noisy_latent") + log_latent(predicted_latent, "predicted_latent") + + return noisy_latent @torch.no_grad() def p_sample_plms( self, noisy_latent, + neutral_conditioning, positive_conditioning, + guidance_scale, time_encoding, schedule: PLMSSchedule, index, @@ -180,20 +169,19 @@ class PLMSSampler: quantize_denoised=False, temperature=1.0, noise_dropout=0.0, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, old_eps=None, t_next=None, ): - batch_size = noisy_latent.shape[0] - noise_prediction = get_noise_prediction( + assert guidance_scale >= 1 + noise_pred = get_noise_prediction( denoise_func=self.model.apply_model, noisy_latent=noisy_latent, time_encoding=time_encoding, - neutral_conditioning=unconditional_conditioning, + neutral_conditioning=neutral_conditioning, positive_conditioning=positive_conditioning, - signal_amplification=unconditional_guidance_scale, + signal_amplification=guidance_scale, ) + batch_size = noisy_latent.shape[0] def get_x_prev_and_pred_x0(e_t, index): # select parameters corresponding to the currently considered timestep @@ -232,38 +220,33 @@ class PLMSSampler: if len(old_eps) == 0: # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(noise_prediction, index) + x_prev, pred_x0 = get_x_prev_and_pred_x0(noise_pred, index) e_t_next = get_noise_prediction( denoise_func=self.model.apply_model, noisy_latent=x_prev, time_encoding=t_next, - neutral_conditioning=unconditional_conditioning, + neutral_conditioning=neutral_conditioning, positive_conditioning=positive_conditioning, - signal_amplification=unconditional_guidance_scale, + signal_amplification=guidance_scale, ) - e_t_prime = (noise_prediction + e_t_next) / 2 + e_t_prime = (noise_pred + e_t_next) / 2 elif len(old_eps) == 1: # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * noise_prediction - old_eps[-1]) / 2 + e_t_prime = (3 * noise_pred - old_eps[-1]) / 2 elif len(old_eps) == 2: # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = ( - 23 * noise_prediction - 16 * old_eps[-1] + 5 * old_eps[-2] - ) / 12 + e_t_prime = (23 * noise_pred - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 elif len(old_eps) >= 3: # 4nd order Pseudo Linear Multistep (Adams-Bashforth) e_t_prime = ( - 55 * noise_prediction - - 59 * old_eps[-1] - + 37 * old_eps[-2] - - 9 * old_eps[-3] + 55 * noise_pred - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] ) / 24 x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) log_latent(x_prev, "x_prev") log_latent(pred_x0, "pred_x0") - return x_prev, pred_x0, noise_prediction + return x_prev, pred_x0, noise_pred @torch.no_grad() def noise_an_image(self, init_latent, t, schedule, noise=None): @@ -285,25 +268,22 @@ class PLMSSampler: @torch.no_grad() def decode( self, - cond, + neutral_conditioning, + positive_conditioning, + guidance_scale, schedule, initial_latent=None, t_start=None, - unconditional_guidance_scale=1.0, - unconditional_conditioning=None, - img_callback=None, temperature=1.0, mask=None, orig_latent=None, noise=None, ): - device = self.device timesteps = schedule.ddim_timesteps[:t_start] time_range = np.flip(timesteps) total_steps = timesteps.shape[0] - iterator = tqdm(time_range, desc="PLMS img2img", total=total_steps) x_dec = initial_latent old_eps = [] log_latent(x_dec, "x_dec") @@ -315,7 +295,7 @@ class PLMSSampler: if noise is None else noise ) - for i, step in enumerate(iterator): + for i, step in enumerate(tqdm(time_range, total=total_steps)): index = total_steps - i - 1 ts = torch.full( (initial_latent.shape[0],), @@ -326,7 +306,7 @@ class PLMSSampler: ts_next = torch.full( (initial_latent.shape[0],), time_range[min(i + 1, len(time_range) - 1)], - device=device, + device=self.device, dtype=torch.long, ) @@ -346,13 +326,13 @@ class PLMSSampler: log_latent(x_dec, f"x_dec {ts}") x_dec, pred_x0, noise_prediction = self.p_sample_plms( - x_dec, - cond, - ts, + noisy_latent=x_dec, + guidance_scale=guidance_scale, + neutral_conditioning=neutral_conditioning, + positive_conditioning=positive_conditioning, + time_encoding=ts, schedule=schedule, index=index, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, temperature=temperature, old_eps=old_eps, t_next=ts_next, @@ -362,11 +342,6 @@ class PLMSSampler: if len(old_eps) >= 4: old_eps.pop(0) - if img_callback: - img_callback(x_dec, "x_dec") - img_callback(pred_x0, "pred_x0") - log_latent(x_dec, f"x_dec {i}") - log_latent(x_dec, f"e_t {i}") log_latent(pred_x0, f"pred_x0 {i}") return x_dec