|
|
|
@ -7,7 +7,7 @@ from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
from imaginairy.log_utils import log_latent
|
|
|
|
|
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
|
|
|
|
|
from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
|
|
|
|
|
from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
|
|
|
|
|
from imaginairy.utils import get_device
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
@ -64,24 +64,25 @@ class DDIMSampler:
|
|
|
|
|
total_steps = timesteps.shape[0]
|
|
|
|
|
noisy_latent = initial_latent
|
|
|
|
|
|
|
|
|
|
mask_noise = None
|
|
|
|
|
if mask is not None:
|
|
|
|
|
mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
|
|
|
|
|
noisy_latent.device
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
for i, step in enumerate(tqdm(time_range, total=total_steps)):
|
|
|
|
|
index = total_steps - i - 1
|
|
|
|
|
ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long)
|
|
|
|
|
|
|
|
|
|
if mask is not None:
|
|
|
|
|
assert orig_latent is not None
|
|
|
|
|
xdec_orig = self.model.q_sample(orig_latent, ts)
|
|
|
|
|
log_latent(xdec_orig, "xdec_orig")
|
|
|
|
|
# this helps prevent the weird disjointed images that can happen with masking
|
|
|
|
|
hint_strength = 0.8
|
|
|
|
|
if i < 2:
|
|
|
|
|
xdec_orig_with_hints = (
|
|
|
|
|
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
xdec_orig_with_hints = xdec_orig
|
|
|
|
|
noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
|
|
|
|
|
log_latent(noisy_latent, "noisy_latent")
|
|
|
|
|
noisy_latent = mask_blend(
|
|
|
|
|
noisy_latent=noisy_latent,
|
|
|
|
|
orig_latent=orig_latent,
|
|
|
|
|
mask=mask,
|
|
|
|
|
mask_noise=mask_noise,
|
|
|
|
|
ts=ts,
|
|
|
|
|
model=self.model,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
noisy_latent, predicted_latent = self.p_sample_ddim(
|
|
|
|
|
noisy_latent=noisy_latent,
|
|
|
|
|