Mirror of https://github.com/brycedrennan/imaginAIry
Synced 2024-11-19 03:25:41 +00:00

commit e8bb3cf5fd (parent 72026c8c90)

    refactor: consolidate masking logic
@@ -860,11 +860,9 @@ class LatentDiffusion(DDPM):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     def q_sample(self, x_start, t, noise=None):
-        noise = (
-            noise
-            if noise is not None
-            else torch.randn_like(x_start, device="cpu").to(x_start.device)
-        )
+        if noise is None:
+            noise = torch.randn_like(x_start, device="cpu").to(x_start.device)
         return (
             extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
             + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
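Note: q_sample here is the standard forward-diffusion step, x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps; the hunk above only replaces the conditional expression with an early `if`, with no behavior change. A minimal self-contained sketch of the same computation (the toy beta schedule and tensor shapes are illustrative assumptions, not values from the repo):

    import torch

    # A toy linear beta schedule; the real model precomputes these buffers at init.
    betas = torch.linspace(1e-4, 0.02, 1000)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    def q_sample_sketch(x_start, t, noise=None):
        # Forward diffusion: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
        if noise is None:
            noise = torch.randn_like(x_start)
        a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)  # broadcast over (C, H, W)
        return a_bar.sqrt() * x_start + (1.0 - a_bar).sqrt() * noise

    x0 = torch.randn(2, 4, 8, 8)   # a batch of two 4-channel latents
    t = torch.tensor([999, 100])   # per-sample timesteps
    print(q_sample_sketch(x0, t).shape)  # torch.Size([2, 4, 8, 8])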
@@ -115,6 +115,30 @@ def get_noise_prediction(
     return noise_pred
 
 
+def mask_blend(noisy_latent, orig_latent, mask, mask_noise, ts, model):
+    """
+    Apply a mask to the noisy_latent.
+
+    ts is a decreasing value between 1000 and 1
+    """
+    assert orig_latent is not None
+    noised_orig_latent = model.q_sample(orig_latent, ts, mask_noise)
+
+    # this helps prevent the weird disjointed images that can happen with masking
+    hint_strength = 0.8
+    # if we're in the first 10% of the steps then don't fully noise the parts
+    # of the image we're not changing so that the algorithm can learn from the context
+    if ts > 900:
+        xdec_orig_with_hints = (
+            noised_orig_latent * (1 - hint_strength) + orig_latent * hint_strength
+        )
+    else:
+        xdec_orig_with_hints = noised_orig_latent
+    noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
+    log_latent(noisy_latent, f"mask-blended noisy_latent {ts}")
+    return noisy_latent
+
+
 def to_torch(x):
     return x.clone().detach().to(torch.float32).to(get_device())
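The core of mask_blend is a per-pixel linear blend in latent space: where mask is 1 the re-noised original latent is kept, and where it is 0 the sampler's current latent passes through untouched. The ts > 900 branch additionally leans the kept region 80% toward the clean original during roughly the first 10% of the 1000-step range, so early steps see nearly un-noised context. A toy demonstration of the blend with plain tensors (the shapes and the half-image mask are illustrative assumptions):

    import torch

    noisy_latent = torch.zeros(1, 4, 8, 8)  # what the sampler is denoising
    noised_orig = torch.ones(1, 4, 8, 8)    # original latent, re-noised to step ts
    mask = torch.zeros(1, 1, 8, 8)
    mask[..., :4] = 1.0                     # keep the left half of the image

    blended = noised_orig * mask + (1.0 - mask) * noisy_latent
    assert torch.all(blended[..., :4] == 1.0)  # preserved region
    assert torch.all(blended[..., 4:] == 0.0)  # region being regenerated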
@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 from imaginairy.log_utils import log_latent
 from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
-from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
+from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
 from imaginairy.utils import get_device
 
 logger = logging.getLogger(__name__)
@@ -64,24 +64,25 @@ class DDIMSampler:
         total_steps = timesteps.shape[0]
         noisy_latent = initial_latent
 
+        mask_noise = None
+        if mask is not None:
+            mask_noise = torch.randn_like(noisy_latent, device="cpu").to(
+                noisy_latent.device
+            )
+
         for i, step in enumerate(tqdm(time_range, total=total_steps)):
             index = total_steps - i - 1
             ts = torch.full((batch_size,), step, device=self.device, dtype=torch.long)
 
             if mask is not None:
-                assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts)
-                log_latent(xdec_orig, "xdec_orig")
-                # this helps prevent the weird disjointed images that can happen with masking
-                hint_strength = 0.8
-                if i < 2:
-                    xdec_orig_with_hints = (
-                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
-                    )
-                else:
-                    xdec_orig_with_hints = xdec_orig
-                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
-                log_latent(noisy_latent, "noisy_latent")
+                noisy_latent = mask_blend(
+                    noisy_latent=noisy_latent,
+                    orig_latent=orig_latent,
+                    mask=mask,
+                    mask_noise=mask_noise,
+                    ts=ts,
+                    model=self.model,
+                )
 
             noisy_latent, predicted_latent = self.p_sample_ddim(
                 noisy_latent=noisy_latent,
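Generating mask_noise once before the loop and threading it through every mask_blend call is the behavioral piece of this refactor for DDIM: the preserved region is re-noised with the same epsilon at every step instead of a fresh draw per iteration (the old DDIM branch called self.model.q_sample(orig_latent, ts) with no noise argument; the PLMS path below already passed mask_noise). A small sketch of the difference, using an assumed toy noising function:

    import torch

    def noised(x0, a_bar, eps):
        # same form as q_sample, for a single scalar a_bar
        return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps

    x0 = torch.randn(4, 8, 8)
    a_bar = torch.tensor(0.5)
    fixed_eps = torch.randn_like(x0)  # drawn once, like mask_noise

    # Fixed eps: re-noising the kept region is identical across iterations.
    assert torch.equal(noised(x0, a_bar, fixed_eps), noised(x0, a_bar, fixed_eps))

    # Fresh eps each step: the kept region jitters from step to step,
    # which is what the consolidated code avoids.
    a = noised(x0, a_bar, torch.randn_like(x0))
    b = noised(x0, a_bar, torch.randn_like(x0))
    assert not torch.equal(a, b)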
@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 from imaginairy.log_utils import log_latent
 from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
-from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
+from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
 from imaginairy.utils import get_device
 
 logger = logging.getLogger(__name__)
@@ -86,19 +86,14 @@ class PLMSSampler:
             )
 
             if mask is not None:
-                assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts, mask_noise)
-                log_latent(xdec_orig, f"xdec_orig i={i} index-{index}")
-                # this helps prevent the weird disjointed images that can happen with masking
-                hint_strength = 0.8
-                if i < 2:
-                    xdec_orig_with_hints = (
-                        xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
-                    )
-                else:
-                    xdec_orig_with_hints = xdec_orig
-                noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
-                log_latent(noisy_latent, f"x_dec {ts}")
+                noisy_latent = mask_blend(
+                    noisy_latent=noisy_latent,
+                    orig_latent=orig_latent,
+                    mask=mask,
+                    mask_noise=mask_noise,
+                    ts=ts,
+                    model=self.model,
+                )
 
             noisy_latent, predicted_latent, noise_pred = self.p_sample_plms(
                 noisy_latent=noisy_latent,
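One subtle behavior change worth noting in both samplers: the old inline blocks gated the context hint on i < 2 (the first two sampler steps regardless of schedule), while mask_blend gates on ts > 900, so the hint window now scales with where the steps land on the 1000-step training range. Under an assumed uniform 50-step DDIM schedule (imaginAIry's actual schedule may be built differently), that is five hinted steps rather than two:

    import torch

    # An assumed uniform 50-step schedule over the 1000-step training range.
    ddim_timesteps = torch.arange(1, 1000, 1000 // 50)  # 1, 21, ..., 981
    time_range = ddim_timesteps.flip(0)                 # sampling runs high -> low

    hinted = [int(ts) for ts in time_range if ts > 900]
    print(hinted)  # [981, 961, 941, 921, 901] -> 5 hinted steps, not a fixed 2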