imaginAIry/imaginairy/samplers/base.py

110 lines
2.9 KiB
Python
Raw Normal View History

2022-09-24 07:29:45 +00:00
# pylama:ignore=W0613
import torch
from torch import nn
2022-10-11 02:50:11 +00:00
from imaginairy.log_utils import log_latent
# Sampler names accepted by get_sampler(). Names with the "k_" prefix are
# translated through _k_sampler_type_lookup before being handed to
# KDiffusionSampler.
SAMPLER_TYPE_OPTIONS = "plms ddim k_lms k_dpm_2 k_dpm_2_a k_euler k_euler_a k_heun".split()

# Public "k_*" sampler name -> sampler-type string passed to KDiffusionSampler.
_k_sampler_type_lookup = dict(
    k_dpm_2="dpm_2",
    k_dpm_2_a="dpm_2_ancestral",
    k_euler="euler",
    k_euler_a="euler_ancestral",
    k_heun="heun",
    k_lms="lms",
)
def get_sampler(sampler_type, model):
    """
    Instantiate the sampler named by ``sampler_type`` for ``model``.

    :param sampler_type: sampler name, matched case-insensitively; one of
        ``SAMPLER_TYPE_OPTIONS``.
    :param model: passed unchanged to the sampler constructor.
    :return: a ``PLMSSampler``, ``DDIMSampler`` or ``KDiffusionSampler``.
    :raises ValueError: if ``sampler_type`` is not a known sampler name,
        including unknown ``"k_"``-prefixed names.
    """
    sampler_type = sampler_type.lower()

    # Validate before the lazy imports below so a bad name fails fast and an
    # unknown "k_" name raises ValueError rather than a bare KeyError from the
    # lookup dict.
    if sampler_type not in ("plms", "ddim") and sampler_type not in _k_sampler_type_lookup:
        raise ValueError(f"invalid sampler_type: {sampler_type!r}")

    # Sampler modules are imported lazily; presumably this avoids an import
    # cycle with this base module (not verifiable from this file).
    from imaginairy.samplers.ddim import DDIMSampler  # noqa
    from imaginairy.samplers.kdiff import KDiffusionSampler  # noqa
    from imaginairy.samplers.plms import PLMSSampler  # noqa

    if sampler_type == "plms":
        return PLMSSampler(model)
    if sampler_type == "ddim":
        return DDIMSampler(model)
    # Only validated "k_*" names reach this point.
    return KDiffusionSampler(model, _k_sampler_type_lookup[sampler_type])
class CFGDenoiser(nn.Module):
    """
    Classifier-free guidance wrapper around a denoising model.

    Each forward call runs the wrapped model on both the unconditional and the
    conditional input (via ``get_noise_prediction``) and combines the two
    predictions as ``uncond_pred + cond_scale * (cond_pred - uncond_pred)``.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale, mask=None, orig_latent=None):
        """
        Return the guided prediction for latent ``x`` at noise level ``sigma``.

        :param x: current noisy latent.
        :param sigma: time / noise-level encoding forwarded to the model.
        :param uncond: unconditional (neutral) conditioning.
        :param cond: conditional (positive) conditioning.
        :param cond_scale: guidance strength applied to ``cond - uncond``.
        :param mask: optional blend weights; where given, the result is
            ``orig_latent * (1 - mask) + mask * prediction``.
        :param orig_latent: latent to blend in outside the mask; required
            whenever ``mask`` is given.
        :raises ValueError: if ``mask`` is given without ``orig_latent``.
        """
        # Explicit check (not ``assert``) so it survives ``python -O``; done
        # before running the model so misuse fails without wasted compute.
        if mask is not None and orig_latent is None:
            raise ValueError("orig_latent is required when mask is given")

        def _wrapper(noisy_latent_in, time_encoding_in, conditioning_in):
            # Adapt get_noise_prediction's positional call to the inner
            # model's ``cond=`` keyword argument.
            return self.inner_model(
                noisy_latent_in, time_encoding_in, cond=conditioning_in
            )

        noise_pred = get_noise_prediction(
            denoise_func=_wrapper,
            noisy_latent=x,
            time_encoding=sigma,
            neutral_conditioning=uncond,
            positive_conditioning=cond,
            signal_amplification=cond_scale,
        )

        if mask is not None:
            # Linear blend: mask weights the model prediction, (1 - mask)
            # weights the original latent.
            mask_inv = 1.0 - mask
            noise_pred = (orig_latent * mask_inv) + (mask * noise_pred)

        return noise_pred
def ensure_4_dim(t: torch.Tensor):
    """
    Prepend a size-1 leading axis to a 3-D tensor (presumably the batch axis).

    Tensors of any other rank are returned unchanged, as the same object.
    """
    if t.ndim != 3:
        return t
    return t.unsqueeze(0)
def get_noise_prediction(
    denoise_func,
    noisy_latent,
    time_encoding,
    neutral_conditioning,
    positive_conditioning,
    signal_amplification=7.5,
):
    """
    Apply classifier-free guidance using a single doubled-batch model call.

    The latent and time encoding are each concatenated with themselves and
    paired with ``[neutral_conditioning, positive_conditioning]``, so one call
    to ``denoise_func`` yields both predictions. The guided result is
    ``neutral + signal_amplification * (positive - neutral)``.

    :param denoise_func: callable ``(latent, time, conditioning) -> tensor``;
        its output is split in two along dim 0, first half neutral, second
        half positive.
    :param noisy_latent: latent to denoise; a 3-D input gets a leading axis
        added via ``ensure_4_dim``.
    :param time_encoding: time / noise-level encoding, duplicated to match.
    :param neutral_conditioning: unconditional conditioning.
    :param positive_conditioning: conditional conditioning.
    :param signal_amplification: guidance scale.
    :return: the guided prediction (also logged via ``log_latent``).
    """
    latent = ensure_4_dim(noisy_latent)
    batched_latent = torch.cat((latent, latent))
    batched_time = torch.cat((time_encoding, time_encoding))
    batched_conditioning = torch.cat((neutral_conditioning, positive_conditioning))

    raw_prediction = denoise_func(batched_latent, batched_time, batched_conditioning)
    neutral_pred, positive_pred = raw_prediction.chunk(2)

    guided = neutral_pred + signal_amplification * (positive_pred - neutral_pred)

    for tensor, label in (
        (neutral_pred, "noise_pred_neutral"),
        (positive_pred, "noise_pred_positive"),
        (guided, "noise_pred"),
    ):
        log_latent(tensor, label)

    return guided