diff --git a/imaginairy/api.py b/imaginairy/api.py
index 675f4e4..c5fab28 100755
--- a/imaginairy/api.py
+++ b/imaginairy/api.py
@@ -186,20 +186,20 @@ def imagine(
         seed_everything(prompt.seed)
         model.tile_mode(prompt.tile_mode)
 
-        uc = None
+        neutral_conditioning = None
         if prompt.prompt_strength != 1.0:
-            uc = model.get_learned_conditioning(1 * [""])
-            log_conditioning(uc, "neutral conditioning")
+            neutral_conditioning = model.get_learned_conditioning(1 * [""])
+            log_conditioning(neutral_conditioning, "neutral conditioning")
 
         if prompt.conditioning is not None:
-            c = prompt.conditioning
+            positive_conditioning = prompt.conditioning
         else:
             total_weight = sum(wp.weight for wp in prompt.prompts)
-            c = sum(
+            positive_conditioning = sum(
                 model.get_learned_conditioning(wp.text) * (wp.weight / total_weight)
                 for wp in prompt.prompts
             )
-        log_conditioning(c, "positive conditioning")
+        log_conditioning(positive_conditioning, "positive conditioning")
 
         shape = [
             1,
@@ -209,7 +209,7 @@ def imagine(
         ]
         if prompt.init_image and prompt.sampler_type not in ("ddim", "plms"):
             sampler_type = "plms"
-            logger.info(" Sampler type switched to plms for img2img")
+            logger.info("Sampler type switched to plms for img2img")
         else:
             sampler_type = prompt.sampler_type
 
@@ -287,36 +287,36 @@ def imagine(
                 # prompt strength gets converted to time encodings,
                 # which means you can't get to true 0 without this hack
                 # (or setting steps=1000)
-                z_enc = noise
+                init_latent_noised = noise
             else:
-                z_enc = sampler.noise_an_image(
+                init_latent_noised = sampler.noise_an_image(
                     init_latent,
                     torch.tensor([t_enc - 1]).to(get_device()),
                     schedule=schedule,
                     noise=noise,
                 )
-            log_latent(z_enc, "z_enc")
+            log_latent(init_latent_noised, "init_latent_noised")
 
-            # decode it
-            samples = sampler.decode(
-                initial_latent=z_enc,
-                positive_conditioning=c,
-                t_start=t_enc,
-                schedule=schedule,
+            samples = sampler.sample(
+                num_steps=prompt.steps,
+                initial_latent=init_latent_noised,
+                positive_conditioning=positive_conditioning,
+                neutral_conditioning=neutral_conditioning,
                 guidance_scale=prompt.prompt_strength,
-                neutral_conditioning=uc,
+                t_start=t_enc,
                 mask=mask,
                 orig_latent=init_latent,
+                shape=shape,
+                batch_size=1,
             )
         else:
             samples = sampler.sample(
                 num_steps=prompt.steps,
-                positive_conditioning=c,
+                neutral_conditioning=neutral_conditioning,
+                positive_conditioning=positive_conditioning,
+                guidance_scale=prompt.prompt_strength,
                 batch_size=1,
                 shape=shape,
-                guidance_scale=prompt.prompt_strength,
-                neutral_conditioning=uc,
             )
 
         x_samples = model.decode_first_stage(samples)
diff --git a/imaginairy/samplers/ddim.py b/imaginairy/samplers/ddim.py
index 8b0d25f..a7e73df 100644
--- a/imaginairy/samplers/ddim.py
+++ b/imaginairy/samplers/ddim.py
@@ -38,6 +38,7 @@ class DDIMSampler:
         temperature=1.0,
         noise_dropout=0.0,
         initial_latent=None,
+        t_start=None,
         quantize_x0=False,
     ):
         if positive_conditioning.shape[0] != batch_size:
@@ -57,7 +58,7 @@ class DDIMSampler:
 
         log_latent(initial_latent, "initial latent")
 
-        timesteps = schedule.ddim_timesteps
+        timesteps = schedule.ddim_timesteps[:t_start]
         time_range = np.flip(timesteps)
         total_steps = timesteps.shape[0]
 
@@ -69,8 +70,18 @@ class DDIMSampler:
 
             if mask is not None:
                 assert orig_latent is not None
-                img_orig = self.model.q_sample(orig_latent, ts)
-                noisy_latent = img_orig * mask + (1.0 - mask) * noisy_latent
+                xdec_orig = self.model.q_sample(orig_latent, ts)
+                log_latent(xdec_orig, "xdec_orig")
+                # this helps prevent the weird disjointed images that can happen with masking
+                hint_strength = 0.8
+                if i < 2:
+                    xdec_orig_with_hints = (
+                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
+                    )
+                else:
+                    xdec_orig_with_hints = xdec_orig
+                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
+                log_latent(noisy_latent, "noisy_latent")
 
             noisy_latent, predicted_latent = self.p_sample_ddim(
                 noisy_latent=noisy_latent,
@@ -190,63 +201,3 @@ class DDIMSampler:
             + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
             * noise
         )
-
-    @torch.no_grad()
-    def decode(
-        self,
-        initial_latent,
-        neutral_conditioning,
-        positive_conditioning,
-        guidance_scale,
-        t_start,
-        schedule,
-        temperature=1.0,
-        mask=None,
-        orig_latent=None,
-    ):
-
-        timesteps = schedule.ddim_timesteps[:t_start]
-
-        time_range = np.flip(timesteps)
-        total_steps = timesteps.shape[0]
-
-        noisy_latent = initial_latent
-
-        for i, step in enumerate(tqdm(time_range, total=total_steps)):
-            index = total_steps - i - 1
-            ts = torch.full(
-                (initial_latent.shape[0],),
-                step,
-                device=initial_latent.device,
-                dtype=torch.long,
-            )
-
-            if mask is not None:
-                assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts)
-                log_latent(xdec_orig, "xdec_orig")
-                # this helps prevent the weird disjointed images that can happen with masking
-                hint_strength = 0.8
-                if i < 2:
-                    xdec_orig_with_hints = (
-                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
-                    )
-                else:
-                    xdec_orig_with_hints = xdec_orig
-                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
-                log_latent(noisy_latent, "noisy_latent")
-
-            noisy_latent, predicted_latent = self.p_sample_ddim(
-                noisy_latent=noisy_latent,
-                positive_conditioning=positive_conditioning,
-                time_encoding=ts,
-                schedule=schedule,
-                index=index,
-                guidance_scale=guidance_scale,
-                neutral_conditioning=neutral_conditioning,
-                temperature=temperature,
-            )
-
-            log_latent(noisy_latent, f"noisy_latent {i}")
-            log_latent(predicted_latent, f"predicted_latent {i}")
-        return noisy_latent
diff --git a/imaginairy/samplers/plms.py b/imaginairy/samplers/plms.py
index 4bc4332..9cc24b6 100644
--- a/imaginairy/samplers/plms.py
+++ b/imaginairy/samplers/plms.py
@@ -41,6 +41,7 @@ class PLMSSampler:
         temperature=1.0,
         noise_dropout=0.0,
         initial_latent=None,
+        t_start=None,
         quantize_denoised=False,
         **kwargs,
     ):
@@ -61,13 +62,18 @@ class PLMSSampler:
 
         log_latent(initial_latent, "initial latent")
 
-        timesteps = schedule.ddim_timesteps
+        timesteps = schedule.ddim_timesteps[:t_start]
         time_range = np.flip(timesteps)
         total_steps = timesteps.shape[0]
 
         old_eps = []
         noisy_latent = initial_latent
+        mask_noise = None
+        if mask is not None:
+            mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
+                noisy_latent.device
+            )
 
         for i, step in enumerate(tqdm(time_range, total=total_steps)):
             index = total_steps - i - 1
@@ -81,8 +87,18 @@ class PLMSSampler:
 
             if mask is not None:
                 assert orig_latent is not None
-                img_orig = self.model.q_sample(orig_latent, ts)
-                noisy_latent = img_orig * mask + (1.0 - mask) * noisy_latent
+                xdec_orig = self.model.q_sample(orig_latent, ts, mask_noise)
+                log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
+                # this helps prevent the weird disjointed images that can happen with masking
+                hint_strength = 0.8
+                if i < 2:
+                    xdec_orig_with_hints = (
+                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
+                    )
+                else:
+                    xdec_orig_with_hints = xdec_orig
+                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
+                log_latent(noisy_latent, f"x_dec {ts}")
 
             noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
                 noisy_latent=noisy_latent,
@@ -202,7 +218,6 @@ class PLMSSampler:
 
     @torch.no_grad()
     def noise_an_image(self, init_latent, t, schedule, noise=None):
-        # replace with ddpm.q_sample?
         # fast, but does not allow for exact reconstruction
         # t serves as an index to gather the correct alphas
         t = t.clamp(0, 1000)
@@ -216,84 +231,3 @@ class PLMSSampler:
             + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
             * noise
         )
-
-    @torch.no_grad()
-    def decode(
-        self,
-        neutral_conditioning,
-        positive_conditioning,
-        guidance_scale,
-        schedule,
-        initial_latent=None,
-        t_start=None,
-        temperature=1.0,
-        mask=None,
-        orig_latent=None,
-        noise=None,
-    ):
-        timesteps = schedule.ddim_timesteps[:t_start]
-
-        time_range = np.flip(timesteps)
-        total_steps = timesteps.shape[0]
-
-        x_dec = initial_latent
-        old_eps = []
-        log_latent(x_dec, "x_dec")
-
-        # not sure what the downside of using the same noise throughout the process would be...
-        # seems to work fine. maybe it runs faster?
-        noise = (
-            torch.randn_like(x_dec, device="cpu").to(x_dec.device)
-            if noise is None
-            else noise
-        )
-        for i, step in enumerate(tqdm(time_range, total=total_steps)):
-            index = total_steps - i - 1
-            ts = torch.full(
-                (initial_latent.shape[0],),
-                step,
-                device=initial_latent.device,
-                dtype=torch.long,
-            )
-            ts_next = torch.full(
-                (initial_latent.shape[0],),
-                time_range[min(i + 1, len(time_range) - 1)],
-                device=self.device,
-                dtype=torch.long,
-            )
-
-            if mask is not None:
-                assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts, noise)
-                log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
-                # this helps prevent the weird disjointed images that can happen with masking
-                hint_strength = 0.8
-                if i < 2:
-                    xdec_orig_with_hints = (
-                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
-                    )
-                else:
-                    xdec_orig_with_hints = xdec_orig
-                x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
-                log_latent(x_dec, f"x_dec {ts}")
-
-            x_dec, pred_x0, noise_prediction = self.p_sample_plms(
-                noisy_latent=x_dec,
-                guidance_scale=guidance_scale,
-                neutral_conditioning=neutral_conditioning,
-                positive_conditioning=positive_conditioning,
-                time_encoding=ts,
-                schedule=schedule,
-                index=index,
-                temperature=temperature,
-                old_eps=old_eps,
-                t_next=ts_next,
-            )
-
-            old_eps.append(noise_prediction)
-            if len(old_eps) >= 4:
-                old_eps.pop(0)
-
-            log_latent(x_dec, f"x_dec {i}")
-            log_latent(pred_x0, f"pred_x0 {i}")
-        return x_dec
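
Both samplers now share the masking trick that the old `decode()` methods carried: for the first couple of steps, the re-noised original latent (the `q_sample` output) is blended back toward the clean `orig_latent` with `hint_strength = 0.8` before being composited under the mask, which is what the "disjointed images" comment refers to. The snippet below is a minimal, self-contained sketch of just that blend-and-composite step, not imaginairy's actual code path; the tensor shapes, the loop, and the stand-in for `model.q_sample` are made up for illustration.

```python
import torch

# Toy tensors standing in for the real latents (shapes are illustrative only).
hint_strength = 0.8
orig_latent = torch.randn(1, 4, 64, 64)                         # clean latent of the source image
xdec_orig = orig_latent + 0.3 * torch.randn_like(orig_latent)   # stand-in for model.q_sample(orig_latent, ts)
mask = (torch.rand(1, 1, 64, 64) > 0.5).float()                 # 1.0 where the original latent is kept
noisy_latent = torch.randn_like(orig_latent)                    # the in-progress sample

for i in range(3):
    # For the first two (highest-noise) steps, pull the re-noised original
    # strongly toward the clean original so the masked region stays anchored
    # to the source image.
    if i < 2:
        xdec_orig_with_hints = xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
    else:
        xdec_orig_with_hints = xdec_orig
    # Composite: keep the (hinted) original where mask == 1.0 and keep the
    # generated latent where mask == 0.0.
    noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
```

Applying the hint only while `i < 2` nudges the masked region toward the source content early on without stopping the sampler from refining it in later steps; the 0.8 weight is the value used in the diff.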