refactor: consolidate masking logic

Bryce 2022-10-13 01:31:25 -07:00 committed by Bryce Drennan
parent 72026c8c90
commit e8bb3cf5fd
4 changed files with 51 additions and 33 deletions

imaginairy/modules/diffusion/ddpm.py

@@ -860,11 +860,9 @@ class LatentDiffusion(DDPM):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     def q_sample(self, x_start, t, noise=None):
-        noise = (
-            noise
-            if noise is not None
-            else torch.randn_like(x_start, device="cpu").to(x_start.device)
-        )
+        if noise is None:
+            noise = torch.randn_like(x_start, device="cpu").to(x_start.device)
         return (
             extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
             + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
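For context, q_sample is the standard DDPM forward-diffusion step, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, with the cumulative-product schedule stored in the model's buffers. A minimal self-contained sketch of the refactored logic, using a made-up linear beta schedule in place of imaginAIry's registered buffers:

import torch

# sketch only: a standalone q_sample with an illustrative schedule
def q_sample_sketch(x_start, t, alphas_cumprod, noise=None):
    if noise is None:  # the early-out this refactor introduces
        noise = torch.randn_like(x_start, device="cpu").to(x_start.device)
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    return (
        alphas_cumprod[t].sqrt() * x_start
        + (1 - alphas_cumprod[t]).sqrt() * noise
    )

betas = torch.linspace(1e-4, 0.02, 1000)            # assumed DDPM-style schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
x_t = q_sample_sketch(torch.zeros(1, 4, 8, 8), 500, alphas_cumprod)

Generating the noise on CPU and then moving it to the target device keeps results reproducible across backends for a fixed seed; the refactor preserves that behavior and only simplifies the control flow.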

imaginairy/samplers/base.py

@@ -115,6 +115,30 @@ def get_noise_prediction(
     return noise_pred
 
 
+def mask_blend(noisy_latent, orig_latent, mask, mask_noise, ts, model):
+    """
+    Apply a mask to the noisy_latent.
+
+    ts is a decreasing value between 1000 and 1
+    """
+    assert orig_latent is not None
+    noised_orig_latent = model.q_sample(orig_latent, ts, mask_noise)
+
+    # this helps prevent the weird disjointed images that can happen with masking
+    hint_strength = 0.8
+    # if we're in the first 10% of the steps then don't fully noise the parts
+    # of the image we're not changing so that the algorithm can learn from the context
+    if ts > 900:
+        xdec_orig_with_hints = (
+            noised_orig_latent * (1 - hint_strength) + orig_latent * hint_strength
+        )
+    else:
+        xdec_orig_with_hints = noised_orig_latent
+    noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
+    log_latent(noisy_latent, f"mask-blended noisy_latent {ts}")
+    return noisy_latent
+
+
 def to_torch(x):
     return x.clone().detach().to(torch.float32).to(get_device())
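A toy illustration of the blend this function performs, with made-up tensors: in the final line, mask == 1 keeps the (re-noised) original latent and mask == 0 keeps the sampler's evolving latent, and the ts > 900 branch leans 80% toward the clean original during roughly the first 10% of the schedule:

import torch

# hypothetical stand-ins for real latents (1 sample, 4 channels, 8x8)
noisy_latent = torch.randn(1, 4, 8, 8)  # the sampler's evolving latent
orig_latent = torch.randn(1, 4, 8, 8)   # encoding of the original image
noised_orig = torch.randn(1, 4, 8, 8)   # stands in for model.q_sample(orig_latent, ts, mask_noise)
mask = torch.zeros(1, 4, 8, 8)
mask[..., :4] = 1.0                     # preserve the left half of the image

ts = 950  # ts counts down from 1000, so this is an early step
hint_strength = 0.8
if ts > 900:
    # early on, lean mostly on the un-noised original to give the model context
    hinted = noised_orig * (1 - hint_strength) + orig_latent * hint_strength
else:
    hinted = noised_orig
blended = hinted * mask + (1.0 - mask) * noisy_latent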

imaginairy/samplers/ddim.py

@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 from imaginairy.log_utils import log_latent
 from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
-from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
+from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
 from imaginairy.utils import get_device
 
 logger = logging.getLogger(__name__)

@@ -64,24 +64,25 @@ class DDIMSampler:
         total_steps = timesteps.shape[0]
         noisy_latent = initial_latent
 
+        mask_noise = None
+        if mask is not None:
+            mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
+                noisy_latent.device
+            )
+
         for i, step in enumerate(tqdm(time_range, total=total_steps)):
             index = total_steps - i - 1
             ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long)
 
             if mask is not None:
-                assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts)
-                log_latent(xdec_orig, "xdec_orig")
-                # this helps prevent the weird disjointed images that can happen with masking
-                hint_strength = 0.8
-                if i < 2:
-                    xdec_orig_with_hints = (
-                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
-                    )
-                else:
-                    xdec_orig_with_hints = xdec_orig
-                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
-                log_latent(noisy_latent, "noisy_latent")
+                noisy_latent = mask_blend(
+                    noisy_latent=noisy_latent,
+                    orig_latent=orig_latent,
+                    mask=mask,
+                    mask_noise=mask_noise,
+                    ts=ts,
+                    model=self.model,
+                )
 
             noisy_latent, predicted_latent = self.p_sample_ddim(
                 noisy_latent=noisy_latent,
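Two things in this hunk are worth noting. mask_noise is now drawn once, before the sampling loop, so every call to mask_blend re-noises the original latent with the same noise tensor; the old DDIM path called q_sample(orig_latent, ts) with no noise argument, drawing fresh noise on every step. And the noise is created on CPU before being moved to the latent's device, matching the q_sample pattern above. A short sketch of the pattern, with an assumed latent shape:

import torch

noisy_latent = torch.zeros(1, 4, 64, 64)  # assumed shape for illustration

# drawn once, outside the loop: reusing one tensor keeps the preserved
# (masked) region coherent from step to step instead of re-randomizing it;
# created on CPU first for cross-backend reproducibility, then moved
mask_noise = torch.randn_like(noisy_latent, device="cpu").to(noisy_latent.device)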

imaginairy/samplers/plms.py

@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 from imaginairy.log_utils import log_latent
 from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
-from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
+from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
 from imaginairy.utils import get_device
 
 logger = logging.getLogger(__name__)

@@ -86,19 +86,14 @@ class PLMSSampler:
             )
 
             if mask is not None:
-                assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts, mask_noise)
-                log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
-                # this helps prevent the weird disjointed images that can happen with masking
-                hint_strength = 0.8
-                if i < 2:
-                    xdec_orig_with_hints = (
-                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
-                    )
-                else:
-                    xdec_orig_with_hints = xdec_orig
-                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
-                log_latent(noisy_latent, f"x_dec {ts}")
+                noisy_latent = mask_blend(
+                    noisy_latent=noisy_latent,
+                    orig_latent=orig_latent,
+                    mask=mask,
+                    mask_noise=mask_noise,
+                    ts=ts,
+                    model=self.model,
+                )
 
             noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
                 noisy_latent=noisy_latent,
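Design note: both samplers previously carried their own copy of this blend, and the copies had already drifted (the PLMS version passed mask_noise into q_sample while the DDIM version did not). Consolidating into mask_blend removes that drift. It also changes the hint condition from the step index (i < 2) to the timestep value (ts > 900), which expresses the "first ~10% of the schedule" intent independently of how many sampler steps are used.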