diff --git a/imaginairy/modules/diffusion/ddpm.py b/imaginairy/modules/diffusion/ddpm.py
index d6c8c69..4b0223b 100644
--- a/imaginairy/modules/diffusion/ddpm.py
+++ b/imaginairy/modules/diffusion/ddpm.py
@@ -860,11 +860,9 @@ class LatentDiffusion(DDPM):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     def q_sample(self, x_start, t, noise=None):
-        noise = (
-            noise
-            if noise is not None
-            else torch.randn_like(x_start, device="cpu").to(x_start.device)
-        )
+        if noise is None:
+            noise = torch.randn_like(x_start, device="cpu").to(x_start.device)
+
         return (
             extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
             + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
diff --git a/imaginairy/samplers/base.py b/imaginairy/samplers/base.py
index 89c253c..07284c6 100644
--- a/imaginairy/samplers/base.py
+++ b/imaginairy/samplers/base.py
@@ -115,6 +115,30 @@ def get_noise_prediction(
     return noise_pred
 
 
+def mask_blend(noisy_latent, orig_latent, mask, mask_noise, ts, model):
+    """
+    Apply a mask to the noisy_latent.
+
+    ts is a decreasing value between 1000 and 1
+    """
+    assert orig_latent is not None
+    noised_orig_latent = model.q_sample(orig_latent, ts, mask_noise)
+
+    # this helps prevent the weird disjointed images that can happen with masking
+    hint_strength = 0.8
+    # if we're in the first 10% of the steps then don't fully noise the parts
+    # of the image we're not changing so that the algorithm can learn from the context
+    if ts > 900:
+        xdec_orig_with_hints = (
+            noised_orig_latent * (1 - hint_strength) + orig_latent * hint_strength
+        )
+    else:
+        xdec_orig_with_hints = noised_orig_latent
+    noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
+    log_latent(noisy_latent, f"mask-blended noisy_latent {ts}")
+    return noisy_latent
+
+
 def to_torch(x):
     return x.clone().detach().to(torch.float32).to(get_device())
 
diff --git a/imaginairy/samplers/ddim.py b/imaginairy/samplers/ddim.py
index a7e73df..0086600 100644
--- a/imaginairy/samplers/ddim.py
+++ b/imaginairy/samplers/ddim.py
@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 from imaginairy.log_utils import log_latent
 from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
-from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
+from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
 from imaginairy.utils import get_device
 
 logger = logging.getLogger(__name__)
@@ -64,24 +64,25 @@ class DDIMSampler:
         total_steps = timesteps.shape[0]
 
         noisy_latent = initial_latent
+        mask_noise = None
+        if mask is not None:
+            mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
+                noisy_latent.device
+            )
+
         for i, step in enumerate(tqdm(time_range, total=total_steps)):
             index = total_steps - i - 1
             ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long)
 
             if mask is not None:
-                assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts)
-                log_latent(xdec_orig, "xdec_orig")
-                # this helps prevent the weird disjointed images that can happen with masking
-                hint_strength = 0.8
-                if i < 2:
-                    xdec_orig_with_hints = (
-                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
-                    )
-                else:
-                    xdec_orig_with_hints = xdec_orig
-                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
-                log_latent(noisy_latent, "noisy_latent")
+                noisy_latent = mask_blend(
+                    noisy_latent=noisy_latent,
+                    orig_latent=orig_latent,
+                    mask=mask,
+                    mask_noise=mask_noise,
+                    ts=ts,
+                    model=self.model,
+                )
 
             noisy_latent, predicted_latent = self.p_sample_ddim(
                 noisy_latent=noisy_latent,
diff --git a/imaginairy/samplers/plms.py b/imaginairy/samplers/plms.py
index 9cc24b6..0fd4838 100644
--- a/imaginairy/samplers/plms.py
+++ b/imaginairy/samplers/plms.py
@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 from imaginairy.log_utils import log_latent
 from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
-from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
+from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
 from imaginairy.utils import get_device
 
 logger = logging.getLogger(__name__)
@@ -86,19 +86,14 @@ class PLMSSampler:
             )
 
             if mask is not None:
-                assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts, mask_noise)
-                log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
-                # this helps prevent the weird disjointed images that can happen with masking
-                hint_strength = 0.8
-                if i < 2:
-                    xdec_orig_with_hints = (
-                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
-                    )
-                else:
-                    xdec_orig_with_hints = xdec_orig
-                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
-                log_latent(noisy_latent, f"x_dec {ts}")
+                noisy_latent = mask_blend(
+                    noisy_latent=noisy_latent,
+                    orig_latent=orig_latent,
+                    mask=mask,
+                    mask_noise=mask_noise,
+                    ts=ts,
+                    model=self.model,
+                )
 
             noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
                 noisy_latent=noisy_latent,