refactor: consolidate masking logic

pull/60/head
Bryce 2 years ago committed by Bryce Drennan
parent 72026c8c90
commit e8bb3cf5fd

@ -860,11 +860,9 @@ class LatentDiffusion(DDPM):
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
def q_sample(self, x_start, t, noise=None):
noise = (
noise
if noise is not None
else torch.randn_like(x_start, device="cpu").to(x_start.device)
)
if noise is None:
noise = torch.randn_like(x_start, device="cpu").to(x_start.device)
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)

@ -115,6 +115,30 @@ def get_noise_prediction(
return noise_pred
def mask_blend(noisy_latent, orig_latent, mask, mask_noise, ts, model):
    """
    Blend the original image latent back into the masked-out region of noisy_latent.

    The unmasked region keeps `noisy_latent`; the masked region is replaced by a
    noised copy of `orig_latent` at timestep `ts` (via `model.q_sample`).
    ts is a decreasing value between 1000 and 1.
    """
    assert orig_latent is not None
    noised_orig_latent = model.q_sample(orig_latent, ts, mask_noise)
    # this helps prevent the weird disjointed images that can happen with masking
    hint_strength = 0.8
    # `ts` is typically a per-batch tensor (torch.full((batch_size,), step));
    # truth-testing a multi-element tensor raises "Boolean value of Tensor is
    # ambiguous", so reduce to a scalar first. All elements are the same step
    # value, so max() is just a convenient reduction.
    ts_scalar = float(ts.max()) if hasattr(ts, "max") else ts
    # if we're in the first 10% of the steps then don't fully noise the parts
    # of the image we're not changing so that the algorithm can learn from the context
    if ts_scalar > 900:
        xdec_orig_with_hints = (
            noised_orig_latent * (1 - hint_strength) + orig_latent * hint_strength
        )
    else:
        xdec_orig_with_hints = noised_orig_latent
    noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
    log_latent(noisy_latent, f"mask-blended noisy_latent {ts}")
    return noisy_latent
def to_torch(x):
    """Return a detached float32 copy of ``x`` moved onto the active device."""
    detached_copy = x.clone().detach()
    return detached_copy.to(torch.float32).to(get_device())

@ -7,7 +7,7 @@ from tqdm import tqdm
from imaginairy.log_utils import log_latent
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
@ -64,24 +64,25 @@ class DDIMSampler:
total_steps = timesteps.shape[0]
noisy_latent = initial_latent
mask_noise = None
if mask is not None:
mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
noisy_latent.device
)
for i, step in enumerate(tqdm(time_range, total=total_steps)):
index = total_steps - i - 1
ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long)
if mask is not None:
assert orig_latent is not None
xdec_orig = self.model.q_sample(orig_latent, ts)
log_latent(xdec_orig, "xdec_orig")
# this helps prevent the weird disjointed images that can happen with masking
hint_strength = 0.8
if i < 2:
xdec_orig_with_hints = (
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
)
else:
xdec_orig_with_hints = xdec_orig
noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
log_latent(noisy_latent, "noisy_latent")
noisy_latent = mask_blend(
noisy_latent=noisy_latent,
orig_latent=orig_latent,
mask=mask,
mask_noise=mask_noise,
ts=ts,
model=self.model,
)
noisy_latent, predicted_latent = self.p_sample_ddim(
noisy_latent=noisy_latent,

@ -7,7 +7,7 @@ from tqdm import tqdm
from imaginairy.log_utils import log_latent
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
@ -86,19 +86,14 @@ class PLMSSampler:
)
if mask is not None:
assert orig_latent is not None
xdec_orig = self.model.q_sample(orig_latent, ts, mask_noise)
log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
# this helps prevent the weird disjointed images that can happen with masking
hint_strength = 0.8
if i < 2:
xdec_orig_with_hints = (
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
)
else:
xdec_orig_with_hints = xdec_orig
noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
log_latent(noisy_latent, f"x_dec {ts}")
noisy_latent = mask_blend(
noisy_latent=noisy_latent,
orig_latent=orig_latent,
mask=mask,
mask_noise=mask_noise,
ts=ts,
model=self.model,
)
noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
noisy_latent=noisy_latent,

Loading…
Cancel
Save