2022-09-24 07:29:45 +00:00
|
|
|
# pylama:ignore=W0613
|
2022-09-14 07:40:25 +00:00
|
|
|
import torch
|
|
|
|
|
2022-10-11 02:50:11 +00:00
|
|
|
from imaginairy.log_utils import log_latent
|
2022-09-17 19:24:27 +00:00
|
|
|
from imaginairy.samplers.base import CFGDenoiser
|
2022-09-14 07:40:25 +00:00
|
|
|
from imaginairy.utils import get_device
|
|
|
|
from imaginairy.vendored.k_diffusion import sampling as k_sampling
|
|
|
|
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser
|
|
|
|
|
|
|
|
|
2022-10-06 04:43:00 +00:00
|
|
|
class StandardCompVisDenoiser(CompVisDenoiser):
    """CompVisDenoiser variant whose ``apply_model`` defers to the wrapped model."""

    def apply_model(self, *args, **kwargs):
        """Call ``self.inner_model.apply_model`` with the given arguments unchanged."""
        inner_apply = self.inner_model.apply_model
        return inner_apply(*args, **kwargs)
|
|
|
|
|
|
|
|
|
2022-10-15 00:21:38 +00:00
|
|
|
def sample_dpm_adaptive(
    model, x, sigmas, extra_args=None, disable=False, callback=None
):
    """Call ``k_sampling.sample_dpm_adaptive`` using bounds taken from ``sigmas``.

    The wrapped sampler takes a ``sigma_min``/``sigma_max`` pair instead of a
    full schedule, so this adapter keeps the same ``sigmas``-based signature
    as the other entries in ``KDiffusionSampler.sampler_lookup``.

    Args:
        model: Forwarded as ``model``.
        x: Forwarded as ``x``.
        sigmas: Indexable sigma schedule; needs at least two entries because
            ``sigmas[-2]`` is read.
        extra_args: Forwarded unchanged.
        disable: Forwarded unchanged.
        callback: Forwarded unchanged.

    Returns:
        Whatever ``k_sampling.sample_dpm_adaptive`` returns.
    """
    # The second-to-last entry is the lower bound and the first entry the
    # upper bound.
    # NOTE(review): presumably the schedule ends with a trailing 0, which
    # would make [-2] the smallest non-zero sigma -- confirm against get_sigmas.
    lowest_sigma, highest_sigma = sigmas[-2], sigmas[0]
    return k_sampling.sample_dpm_adaptive(
        model=model,
        x=x,
        sigma_min=lowest_sigma,
        sigma_max=highest_sigma,
        extra_args=extra_args,
        disable=disable,
        callback=callback,
    )
|
|
|
|
|
|
|
|
|
|
|
|
def sample_dpm_fast(model, x, sigmas, extra_args=None, disable=False, callback=None):
    """Call ``k_sampling.sample_dpm_fast`` using bounds and a count taken from ``sigmas``.

    Args:
        model: Forwarded as ``model``.
        x: Forwarded as ``x``.
        sigmas: Indexable sigma schedule; needs at least two entries because
            ``sigmas[-2]`` is read. Its length is forwarded as ``n``.
        extra_args: Forwarded unchanged.
        disable: Forwarded unchanged.
        callback: Forwarded unchanged.

    Returns:
        Whatever ``k_sampling.sample_dpm_fast`` returns.
    """
    solver_kwargs = {
        "model": model,
        "x": x,
        # Second-to-last entry is the lower bound, first entry the upper bound.
        "sigma_min": sigmas[-2],
        "sigma_max": sigmas[0],
        # NOTE(review): len(sigmas) counts every schedule entry, including the
        # final one that the sigma_min bound above skips -- confirm this is the
        # intended value for ``n``.
        "n": len(sigmas),
        "extra_args": extra_args,
        "disable": disable,
        "callback": callback,
    }
    return k_sampling.sample_dpm_fast(**solver_kwargs)
|
|
|
|
|
|
|
|
|
2022-09-14 07:40:25 +00:00
|
|
|
class KDiffusionSampler:
    """Run one of the k-diffusion sampling functions against a wrapped model.

    The sampling function is chosen by name from ``sampler_lookup`` and the
    model is wrapped in ``StandardCompVisDenoiser`` at construction time.
    """

    # Maps a sampler name to a callable taking
    # (model, x, sigmas, extra_args, disable, callback).
    sampler_lookup = {
        "dpm_fast": sample_dpm_fast,
        "dpm_adaptive": sample_dpm_adaptive,
        "dpm_2": k_sampling.sample_dpm_2,
        "dpm_2_ancestral": k_sampling.sample_dpm_2_ancestral,
        "euler": k_sampling.sample_euler,
        "euler_ancestral": k_sampling.sample_euler_ancestral,
        "heun": k_sampling.sample_heun,
        "lms": k_sampling.sample_lms,
    }

    def __init__(self, model, sampler_name):
        """Store the model, wrap it for k-diffusion and resolve the sampler.

        Args:
            model: Model handed to ``StandardCompVisDenoiser``.
            sampler_name: Key into ``sampler_lookup``.

        Raises:
            KeyError: If ``sampler_name`` is not a key of ``sampler_lookup``.
        """
        self.model = model
        self.cv_denoiser = StandardCompVisDenoiser(model)
        self.sampler_name = sampler_name
        self.sampler_func = self.sampler_lookup[sampler_name]
        self.device = get_device()

    def sample(
        self,
        num_steps,
        shape,
        neutral_conditioning,
        positive_conditioning,
        guidance_scale,
        batch_size=1,
        mask=None,
        orig_latent=None,
        initial_latent=None,
        t_start=None,
    ):
        """Run the configured sampler and return its result.

        Args:
            num_steps: Passed to ``self.cv_denoiser.get_sigmas`` to build the
                sigma schedule.
            shape: Shape of the random initial latent, used only when
                ``initial_latent`` is None.
            neutral_conditioning: Passed to the sampler as ``uncond``.
            positive_conditioning: Passed to the sampler as ``cond``; its
                first dimension must equal ``batch_size``.
            guidance_scale: Passed to the sampler as ``cond_scale``.
            batch_size: Expected size of ``positive_conditioning``'s first
                dimension.
            mask: Passed through in ``extra_args``; when not None, a noise
                tensor shaped like the initial latent is passed along as
                ``mask_noise``.
            orig_latent: Passed through in ``extra_args``; also returned as-is
                when the sliced schedule is empty.
            initial_latent: Starting latent; drawn from ``torch.randn`` when
                None.
            t_start: When given, the schedule is sliced from index
                ``num_steps - t_start + 1`` onward.

        Returns:
            The sampler function's result. When the sliced schedule has no
            elements, ``orig_latent`` if it is not None, else the initial
            latent.

        Raises:
            ValueError: If ``positive_conditioning.shape[0] != batch_size``.
        """
        if positive_conditioning.shape[0] != batch_size:
            raise ValueError(
                f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
            )

        if initial_latent is None:
            # Noise is drawn on the CPU and then moved to the target device.
            # NOTE(review): presumably so seeded output matches across
            # devices -- confirm.
            cpu_noise = torch.randn(shape, device="cpu")
            initial_latent = cpu_noise.to(self.device)

        log_latent(initial_latent, "initial_latent")

        # A None offset leaves the schedule unsliced.
        schedule_offset = None if t_start is None else num_steps - t_start + 1
        full_schedule = self.cv_denoiser.get_sigmas(num_steps)
        sigmas = full_schedule[schedule_offset:]

        # if our number of steps is zero, just return the initial latent
        if sigmas.nelement() == 0:
            return initial_latent if orig_latent is None else orig_latent

        x = initial_latent * sigmas[0]
        log_latent(x, "initial_sigma_noised_tensor")
        model_wrap_cfg = CFGDenoiser(self.cv_denoiser)

        mask_noise = None
        if mask is not None:
            cpu_mask_noise = torch.randn_like(initial_latent, device="cpu")
            mask_noise = cpu_mask_noise.to(initial_latent.device)

        def log_step(data):
            """Log the ``x`` and ``denoised`` entries the sampler reports."""
            log_latent(data["x"], "noisy_latent")
            log_latent(data["denoised"], "noise_pred c")

        sampler_extra_args = {
            "cond": positive_conditioning,
            "uncond": neutral_conditioning,
            "cond_scale": guidance_scale,
            "mask": mask,
            "mask_noise": mask_noise,
            "orig_latent": orig_latent,
        }

        return self.sampler_func(
            model=model_wrap_cfg,
            x=x,
            sigmas=sigmas,
            extra_args=sampler_extra_args,
            disable=False,
            callback=log_step,
        )
|