# pylama:ignore=W0613
"""Wrapper that runs the vendored k-diffusion samplers against CompVis-style models."""
import torch

from imaginairy.log_utils import log_latent
from imaginairy.samplers.base import CFGDenoiser
from imaginairy.utils import get_device
from imaginairy.vendored.k_diffusion import sampling as k_sampling
from imaginairy.vendored.k_diffusion.external import CompVisDenoiser


class StandardCompVisDenoiser(CompVisDenoiser):
    """CompVisDenoiser that delegates apply_model calls straight to the wrapped model."""

    def apply_model(self, *args, **kwargs):
        return self.inner_model.apply_model(*args, **kwargs)
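

# Conceptually, k-diffusion's CompVisDenoiser adapts an eps-prediction
# latent-diffusion model to the "predict the denoised latent" interface the
# k-diffusion samplers expect. A rough sketch of the idea, kept as a comment
# because the exact scaling lives in the vendored k_diffusion.external module:
#
#     eps = inner_model.apply_model(x * c_in, sigma_to_t(sigma), cond)
#     denoised = x + eps * c_out  # c_in / c_out are derived from sigma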


class KDiffusionSampler:
    """Generates latents by driving a k-diffusion sampling function."""

    def __init__(self, model, sampler_name):
        self.model = model
        self.cv_denoiser = StandardCompVisDenoiser(model)
        self.sampler_name = sampler_name
        # e.g. sampler_name="lms" resolves to k_sampling.sample_lms
        self.sampler_func = getattr(k_sampling, f"sample_{sampler_name}")
        self.device = get_device()

    def sample(
        self,
        num_steps,
        shape,
        neutral_conditioning,
        positive_conditioning,
        guidance_scale,
        batch_size=1,
        mask=None,
        orig_latent=None,
        initial_latent=None,
        img_callback=None,
    ):
        # mask, orig_latent, and img_callback are accepted but unused here
        # (hence the pylama W0613 ignore at the top of the file).
        if positive_conditioning.shape[0] != batch_size:
            raise ValueError(
                f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
            )

        if initial_latent is None:
            # Noise is drawn on the CPU and then moved to the sampling device,
            # presumably so seeded results stay reproducible across device types.
            initial_latent = torch.randn(shape, device="cpu").to(self.device)

        log_latent(initial_latent, "initial_latent")

        # sigmas is the noise schedule for the requested number of steps,
        # ordered from the highest noise level down.
        sigmas = self.cv_denoiser.get_sigmas(num_steps)

        # Scale the starting latent up to the first (largest) sigma.
        x = initial_latent * sigmas[0]
        log_latent(x, "initial_sigma_noised_tensor")

        model_wrap_cfg = CFGDenoiser(self.cv_denoiser)
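
        # CFGDenoiser applies classifier-free guidance. Conceptually it blends
        # the conditional and unconditional predictions at each step; a
        # simplified sketch, not the exact implementation in samplers.base:
        #
        #     uncond = denoiser(x, sigma, cond=neutral_conditioning)
        #     cond = denoiser(x, sigma, cond=positive_conditioning)
        #     denoised = uncond + (cond - uncond) * guidance_scale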

        def callback(data):
            # Invoked by the sampler after each step with the current noisy
            # latent and the model's denoised prediction.
            log_latent(data["x"], "noisy_latent")
            log_latent(data["denoised"], "noise_pred")

        samples = self.sampler_func(
            model=model_wrap_cfg,
            x=x,
            sigmas=sigmas,
            extra_args={
                "cond": positive_conditioning,
                "uncond": neutral_conditioning,
                "cond_scale": guidance_scale,
            },
            disable=False,  # keep the progress bar enabled
            callback=callback,
        )

        return samples
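

# Hypothetical usage sketch. The `model` object and the conditioning tensors
# below are placeholders; producing them is outside the scope of this module:
#
#     sampler = KDiffusionSampler(model, "lms")
#     samples = sampler.sample(
#         num_steps=40,
#         shape=(1, 4, 64, 64),
#         neutral_conditioning=uncond_embedding,   # e.g. embedding of ""
#         positive_conditioning=prompt_embedding,  # embedding of the prompt
#         guidance_scale=7.5,
#     )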