# pylama:ignore=W0613
import torch
from torch import nn
from imaginairy.log_utils import log_latent
# All sampler names accepted by `get_sampler` (compared case-insensitively).
# "plms" and "ddim" map to dedicated sampler classes; the "k_*" names are
# translated via `_k_sampler_type_lookup` and handled by KDiffusionSampler.
SAMPLER_TYPE_OPTIONS = [
    "plms",
    "ddim",
    "k_lms",
    "k_dpm_2",
    "k_dpm_2_a",
    "k_euler",
    "k_euler_a",
    "k_heun",
]
# Maps the public "k_*" sampler names (as listed in SAMPLER_TYPE_OPTIONS)
# to the internal sampler names passed to KDiffusionSampler.
_k_sampler_type_lookup = {
    "k_dpm_2": "dpm_2",
    "k_dpm_2_a": "dpm_2_ancestral",
    "k_euler": "euler",
    "k_euler_a": "euler_ancestral",
    "k_heun": "heun",
    "k_lms": "lms",
}
def get_sampler(sampler_type, model):
    """Return a sampler instance of the requested type, wrapping `model`.

    Args:
        sampler_type: one of SAMPLER_TYPE_OPTIONS (compared case-insensitively).
        model: the diffusion model the sampler will drive.

    Raises:
        ValueError: if `sampler_type` is not a recognized sampler name.
    """
    # Imported lazily to avoid a circular import between this module and the
    # individual sampler modules.
    from imaginairy.samplers.ddim import DDIMSampler  # noqa
    from imaginairy.samplers.kdiff import KDiffusionSampler  # noqa
    from imaginairy.samplers.plms import PLMSSampler  # noqa

    sampler_type = sampler_type.lower()
    if sampler_type == "plms":
        return PLMSSampler(model)
    if sampler_type == "ddim":
        return DDIMSampler(model)
    if sampler_type.startswith("k_"):
        # Bug fix: an unknown "k_*" name previously escaped as a bare
        # KeyError from the lookup; normalize it to the same ValueError the
        # other branches raise, and include the offending value.
        try:
            k_sampler_name = _k_sampler_type_lookup[sampler_type]
        except KeyError:
            raise ValueError(f"invalid sampler_type: {sampler_type!r}") from None
        return KDiffusionSampler(model, k_sampler_name)
    raise ValueError(f"invalid sampler_type: {sampler_type!r}")
class CFGDenoiser(nn.Module):
    """
    Conditional (classifier-free) forward guidance wrapper.

    Wraps a denoising model so that each forward pass combines the neutral
    and positive conditionings via `get_noise_prediction`, optionally
    blending the result with an original latent through a mask.
    """

    def __init__(self, model):
        super().__init__()
        # The wrapped denoising model; invoked once per forward pass.
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale, mask=None, orig_latent=None):
        # Adapter so get_noise_prediction can call the wrapped model with a
        # uniform (latent, time, conditioning) signature.
        def denoise(latent_in, t_in, conditioning_in):
            return self.inner_model(latent_in, t_in, cond=conditioning_in)

        noise_pred = get_noise_prediction(
            denoise_func=denoise,
            noisy_latent=x,
            time_encoding=sigma,
            neutral_conditioning=uncond,
            positive_conditioning=cond,
            signal_amplification=cond_scale,
        )

        if mask is not None:
            # Masked regions take the new prediction; the rest is kept from
            # the original latent (inpainting-style blend).
            assert orig_latent is not None
            noise_pred = orig_latent * (1.0 - mask) + mask * noise_pred

        return noise_pred
def ensure_4_dim(t: torch.Tensor) -> torch.Tensor:
    """Return `t` with a leading batch dimension added if it is 3-d.

    Tensors of any other rank are returned unchanged (same object).
    """
    return t.unsqueeze(dim=0) if len(t.shape) == 3 else t
def get_noise_prediction(
    denoise_func,
    noisy_latent,
    time_encoding,
    neutral_conditioning,
    positive_conditioning,
    signal_amplification=7.5,
):
    """Compute a classifier-free-guided noise prediction.

    Runs `denoise_func` once on a doubled batch (neutral + positive
    conditioning) and amplifies the difference between the two predictions
    by `signal_amplification`, steering the result toward the positive
    conditioning.
    """
    noisy_latent = ensure_4_dim(noisy_latent)

    # Double the batch so both conditionings are denoised in one model call.
    latent_pair = torch.cat([noisy_latent, noisy_latent])
    time_pair = torch.cat([time_encoding, time_encoding])
    conditioning_pair = torch.cat([neutral_conditioning, positive_conditioning])

    denoised = denoise_func(latent_pair, time_pair, conditioning_pair)
    noise_pred_neutral, noise_pred_positive = denoised.chunk(2)

    # neutral + scale * (positive - neutral): standard CFG combination.
    noise_pred = noise_pred_neutral + signal_amplification * (
        noise_pred_positive - noise_pred_neutral
    )

    log_latent(noise_pred_neutral, "noise_pred_neutral")
    log_latent(noise_pred_positive, "noise_pred_positive")
    log_latent(noise_pred, "noise_pred")

    return noise_pred