@@ -6,61 +6,13 @@ import torch
 from tqdm import tqdm
 
 from imaginairy.log_utils import log_latent
-from imaginairy.modules.diffusion.util import (
-    extract_into_tensor,
-    make_ddim_sampling_parameters,
-    make_ddim_timesteps,
-    noise_like,
-)
-from imaginairy.samplers.base import get_noise_prediction
+from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
+from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
 from imaginairy.utils import get_device
 
 logger = logging.getLogger(__name__)
 
 
-def to_torch(x):
-    return x.clone().detach().to(torch.float32).to(get_device())
-
-
-class PLMSSchedule:
-    def __init__(
-        self,
-        model_num_timesteps,  # 1000?
-        model_alphas_cumprod,
-        ddim_num_steps,  # prompt.steps?
-        ddim_discretize="uniform",
-    ):
-        device = get_device()
-        if model_alphas_cumprod.shape[0] != model_num_timesteps:
-            raise ValueError("alphas have to be defined for each timestep")
-
-        self.alphas_cumprod = to_torch(model_alphas_cumprod)
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.sqrt_alphas_cumprod = to_torch(np.sqrt(model_alphas_cumprod.cpu()))
-        self.sqrt_one_minus_alphas_cumprod = to_torch(
-            np.sqrt(1.0 - model_alphas_cumprod.cpu())
-        )
-
-        self.ddim_timesteps = make_ddim_timesteps(
-            ddim_discr_method=ddim_discretize,
-            num_ddim_timesteps=ddim_num_steps,
-            num_ddpm_timesteps=model_num_timesteps,
-        )
-
-        # ddim sampling parameters
-        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
-            alphacums=model_alphas_cumprod.cpu(),
-            ddim_timesteps=self.ddim_timesteps,
-            eta=0.0,
-        )
-        self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
-        self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))
-        self.ddim_alphas_prev = ddim_alphas_prev
-        self.ddim_sqrt_one_minus_alphas = (
-            np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(torch.device(device))
-        )
-
-
 class PLMSSampler:
     """
     probabilistic least-mean-squares
@@ -97,7 +49,7 @@ class PLMSSampler:
                 f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
             )
 
-        schedule = PLMSSchedule(
+        schedule = NoiseSchedule(
             model_num_timesteps=self.model.num_timesteps,
             ddim_num_steps=num_steps,
             model_alphas_cumprod=self.model.alphas_cumprod,
@@ -163,7 +115,7 @@ class PLMSSampler:
         positive_conditioning,
         guidance_scale,
         time_encoding,
-        schedule: PLMSSchedule,
+        schedule: NoiseSchedule,
         index,
         repeat_noise=False,
         quantize_denoised=False,