# imaginairy/samplers/kdiff.py
# pylama:ignore=W0613
from abc import ABC

import torch
from torch import nn

from imaginairy.log_utils import increment_step, log_latent
from imaginairy.samplers.base import (
    ImageSampler,
    SamplerName,
    get_noise_prediction,
    mask_blend,
)
from imaginairy.utils import get_device
from imaginairy.vendored.k_diffusion import sampling as k_sampling
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser, CompVisVDenoiser


class StandardCompVisDenoiser(CompVisDenoiser):
    def apply_model(self, *args, **kwargs):
        return self.inner_model.apply_model(*args, **kwargs)


class StandardCompVisVDenoiser(CompVisVDenoiser):
    def apply_model(self, *args, **kwargs):
        return self.inner_model.apply_model(*args, **kwargs)
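
# These subclasses expose apply_model() on the k-diffusion denoiser wrapper
# itself, delegating straight to the wrapped CompVis model, so code written
# against the CompVis interface can call the wrapper directly.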


def sample_dpm_adaptive(
    model, x, sigmas, extra_args=None, disable=False, callback=None
):
    sigma_min = sigmas[-2]
    sigma_max = sigmas[0]
    return k_sampling.sample_dpm_adaptive(
        model=model,
        x=x,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        extra_args=extra_args,
        disable=disable,
        callback=callback,
    )


def sample_dpm_fast(model, x, sigmas, extra_args=None, disable=False, callback=None):
    sigma_min = sigmas[-2]
    sigma_max = sigmas[0]
    return k_sampling.sample_dpm_fast(
        model=model,
        x=x,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        n=len(sigmas),
        extra_args=extra_args,
        disable=disable,
        callback=callback,
    )
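
# The two adapters above exist because k-diffusion's dpm_fast/dpm_adaptive
# solvers take (sigma_min, sigma_max) rather than an explicit sigma schedule.
# Schedules produced by CompVisDenoiser.get_sigmas() end in 0.0, which is not
# a valid sigma_min for these solvers, hence sigmas[-2]. A sketch with
# made-up values:
#
#     sigmas = torch.tensor([14.61, 6.98, 2.92, 1.03, 0.31, 0.0])
#     sigma_max, sigma_min = sigmas[0], sigmas[-2]  # 14.61, 0.31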


class KDiffusionSampler(ImageSampler, ABC):
    sampler_func: callable

    def __init__(self, model):
        super().__init__(model)
        # v-parameterized models predict "v" rather than noise and need the
        # matching k-diffusion wrapper.
        denoiser_cls = (
            StandardCompVisVDenoiser
            if model.parameterization == "v"
            else StandardCompVisDenoiser
        )
        self.cv_denoiser = denoiser_cls(model)

    def sample(
        self,
        num_steps,
        shape,
        neutral_conditioning,
        positive_conditioning,
        guidance_scale,
        batch_size=1,
        mask=None,
        orig_latent=None,
        initial_latent=None,
        t_start=None,
        denoiser_cls=None,
    ):
        # if positive_conditioning.shape[0] != batch_size:
        #     raise ValueError(
        #         f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
        #     )
        if initial_latent is None:
            # Noise is created on the CPU and then moved, so seeded results
            # stay consistent across backends.
            initial_latent = torch.randn(shape, device="cpu").to(self.device)
        log_latent(initial_latent, "initial_latent")

        # For img2img, t_start skips the early (high-noise) part of the schedule.
        if t_start is not None:
            t_start = num_steps - t_start + 1
        sigmas = self.cv_denoiser.get_sigmas(num_steps)[t_start:]

        # if our number of steps is zero, just return the initial latent
        if sigmas.nelement() == 0:
            if orig_latent is not None:
                return orig_latent
            return initial_latent

        x = initial_latent * sigmas[0]
        log_latent(x, "initial_sigma_noised_tensor")

        if denoiser_cls is None:
            denoiser_cls = CFGDenoiser
        model_wrap_cfg = denoiser_cls(self.cv_denoiser)

        mask_noise = None
        if mask is not None:
            mask_noise = torch.randn_like(initial_latent, device="cpu").to(
                initial_latent.device
            )

        def callback(data):
            log_latent(data["x"], "noisy_latent")
            log_latent(data["denoised"], "predicted_latent")
            increment_step()

        # k-diffusion solvers call the model as `model(x, sigma, **extra_args)`,
        # so each key below arrives as a keyword argument of CFGDenoiser.forward.
        samples = self.sampler_func(
            model=model_wrap_cfg,
            x=x,
            sigmas=sigmas,
            extra_args={
                "cond": positive_conditioning,
                "uncond": neutral_conditioning,
                "cond_scale": guidance_scale,
                "mask": mask,
                "mask_noise": mask_noise,
                "orig_latent": orig_latent,
            },
            disable=False,
            callback=callback,
        )

        return samples


class DPMFastSampler(KDiffusionSampler):
    short_name = SamplerName.K_DPM_FAST
    name = "Diffusion probabilistic models - fast"
    default_steps = 15
    sampler_func = staticmethod(sample_dpm_fast)


class DPMAdaptiveSampler(KDiffusionSampler):
    short_name = SamplerName.K_DPM_ADAPTIVE
    name = "Diffusion probabilistic models - adaptive"
    default_steps = 40
    sampler_func = staticmethod(sample_dpm_adaptive)


class DPM2Sampler(KDiffusionSampler):
    short_name = SamplerName.K_DPM_2
    name = "Diffusion probabilistic models - 2"
    default_steps = 40
    sampler_func = staticmethod(k_sampling.sample_dpm_2)


class DPM2AncestralSampler(KDiffusionSampler):
    short_name = SamplerName.K_DPM_2_ANCESTRAL
    name = "Diffusion probabilistic models - 2 ancestral"
    default_steps = 40
    sampler_func = staticmethod(k_sampling.sample_dpm_2_ancestral)


class DPMPP2MSampler(KDiffusionSampler):
    short_name = SamplerName.K_DPMPP_2M
    name = "Diffusion probabilistic models - 2m"
    default_steps = 15
    sampler_func = staticmethod(k_sampling.sample_dpmpp_2m)


class DPMPP2SAncestralSampler(KDiffusionSampler):
    short_name = SamplerName.K_DPMPP_2S_ANCESTRAL
    name = "Ancestral sampling with DPM-Solver++(2S) second-order steps."
    default_steps = 15
    sampler_func = staticmethod(k_sampling.sample_dpmpp_2s_ancestral)


class EulerSampler(KDiffusionSampler):
    short_name = SamplerName.K_EULER
    name = "Algorithm 2 (Euler steps) from Karras et al. (2022)"
    default_steps = 40
    sampler_func = staticmethod(k_sampling.sample_euler)


class EulerAncestralSampler(KDiffusionSampler):
    short_name = SamplerName.K_EULER_ANCESTRAL
    name = "Euler ancestral"
    default_steps = 40
    sampler_func = staticmethod(k_sampling.sample_euler_ancestral)


class HeunSampler(KDiffusionSampler):
    short_name = SamplerName.K_HEUN
    name = "Algorithm 2 (Heun steps) from Karras et al. (2022)."
    default_steps = 40
    sampler_func = staticmethod(k_sampling.sample_heun)


class LMSSampler(KDiffusionSampler):
    short_name = SamplerName.K_LMS
    name = "LMS"
    default_steps = 40
    sampler_func = staticmethod(k_sampling.sample_lms)
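
# A minimal usage sketch (hypothetical values; `model` is assumed to be a
# loaded LatentDiffusion checkpoint and the conditioning tensors to come from
# its text encoder):
#
#     sampler = DPMPP2MSampler(model)
#     samples = sampler.sample(
#         num_steps=15,
#         shape=(1, 4, 64, 64),
#         neutral_conditioning=uncond_embedding,
#         positive_conditioning=prompt_embedding,
#         guidance_scale=7.5,
#     )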


class CFGDenoiser(nn.Module):
    """
    Classifier-free guidance wrapper.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model
        self.device = get_device()

    def forward(
        self,
        x,
        sigma,
        uncond,
        cond,
        cond_scale,
        mask=None,
        mask_noise=None,
        orig_latent=None,
    ):
        def _wrapper(noisy_latent_in, time_encoding_in, conditioning_in):
            return self.inner_model(
                noisy_latent_in, time_encoding_in, cond=conditioning_in
            )

        if mask is not None:
            assert orig_latent is not None
            t = self.inner_model.sigma_to_t(sigma, quantize=True)
            big_sigma = max(sigma, 1)
            # For inpainting: blend the scaled original latent into the noisy
            # latent according to the mask before denoising.
            x = mask_blend(
                noisy_latent=x,
                orig_latent=orig_latent * big_sigma,
                mask=mask,
                mask_noise=mask_noise * big_sigma,
                ts=t,
                model=self.inner_model.inner_model,
            )

        noise_pred = get_noise_prediction(
            denoise_func=_wrapper,
            noisy_latent=x,
            time_encoding=sigma,
            neutral_conditioning=uncond,
            positive_conditioning=cond,
            signal_amplification=cond_scale,
        )

        return noise_pred
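

# For reference: get_noise_prediction (imaginairy.samplers.base) is expected
# to combine the two model outputs with the standard classifier-free guidance
# blend, roughly (a sketch, not the verbatim implementation):
#
#     eps_uncond = denoise_func(x, sigma, uncond)
#     eps_cond = denoise_func(x, sigma, cond)
#     eps = eps_uncond + cond_scale * (eps_cond - eps_uncond)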