refactor: combine identical schedules

pull/60/head
Bryce 2 years ago committed by Bryce Drennan
parent 8d4b5cb9e1
commit a105dadbc4

@ -25,8 +25,7 @@ from imaginairy.log_utils import (
log_latent,
)
from imaginairy.safety import SafetyMode, create_safety_score
from imaginairy.samplers.base import get_sampler
from imaginairy.samplers.plms import PLMSSchedule
from imaginairy.samplers.base import NoiseSchedule, get_sampler
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
fix_torch_group_norm,
@ -277,8 +276,8 @@ def imagine(
# encode (scaled latent)
seed_everything(prompt.seed)
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
# todo: this isn't the right scheduler for everything...
schedule = PLMSSchedule(
schedule = NoiseSchedule(
model_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps,
model_alphas_cumprod=model.alphas_cumprod,

@ -1,8 +1,14 @@
# pylama:ignore=W0613
import numpy as np
import torch
from torch import nn
from imaginairy.log_utils import log_latent
from imaginairy.modules.diffusion.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
)
from imaginairy.utils import get_device
SAMPLER_TYPE_OPTIONS = [
"plms",
@ -107,3 +113,47 @@ def get_noise_prediction(
log_latent(noise_pred, "noise_pred")
return noise_pred
def to_torch(x):
    """Return *x* as a detached float32 tensor on the default compute device."""
    detached_copy = x.clone().detach()
    return detached_copy.to(torch.float32).to(get_device())
class NoiseSchedule:
    """Precomputed noise-schedule tensors shared by the samplers.

    Instantiated by both the DDIM and PLMS samplers (see the call sites in
    this diff), replacing the previously duplicated DDIMSchedule and
    PLMSSchedule classes.
    """

    def __init__(
        self,
        model_num_timesteps,  # number of DDPM timesteps the model defines alphas for
        model_alphas_cumprod,  # per-timestep cumulative alpha products (tensor)
        ddim_num_steps,  # number of DDIM sampling steps to schedule
        ddim_discretize="uniform",  # timestep discretization method name
        ddim_eta=0.0,  # eta forwarded to make_ddim_sampling_parameters
    ):
        device = get_device()
        # The schedule is only meaningful if one alpha exists per model timestep.
        if model_alphas_cumprod.shape[0] != model_num_timesteps:
            raise ValueError("alphas have to be defined for each timestep")
        self.alphas_cumprod = to_torch(model_alphas_cumprod)
        # calculations for diffusion q(x_t | x_{t-1}) and others
        # .cpu() first so np.sqrt operates on host data before to_torch moves it back.
        self.sqrt_alphas_cumprod = to_torch(np.sqrt(model_alphas_cumprod.cpu()))
        self.sqrt_one_minus_alphas_cumprod = to_torch(
            np.sqrt(1.0 - model_alphas_cumprod.cpu())
        )
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=model_num_timesteps,
        )
        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=model_alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
        )
        self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device)
        self.ddim_alphas = ddim_alphas.to(torch.float32).to(device)
        self.ddim_alphas_prev = ddim_alphas_prev
        # NOTE(review): np.sqrt applied to a torch tensor returns a torch tensor
        # here (the subsequent .to() call requires it) -- confirm ddim_alphas is
        # a tensor, not an ndarray.
        self.ddim_sqrt_one_minus_alphas = (
            np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(device)
        )

@ -6,64 +6,13 @@ import torch
from tqdm import tqdm
from imaginairy.log_utils import log_latent
from imaginairy.modules.diffusion.util import (
extract_into_tensor,
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
from imaginairy.samplers.base import get_noise_prediction
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
def to_torch(x):
    """Clone *x*, detach it from autograd, and move it to the default device as float32."""
    result = x.clone().detach().to(torch.float32)
    return result.to(get_device())
class DDIMSchedule:
    """Precomputed schedule tensors for the DDIM sampler.

    NOTE(review): near-duplicate of PLMSSchedule; the surrounding commit
    merges both into a single NoiseSchedule class.
    """

    def __init__(
        self,
        model_num_timesteps,  # number of DDPM timesteps the model defines alphas for
        model_alphas_cumprod,  # per-timestep cumulative alpha products (tensor)
        ddim_num_steps,  # number of DDIM sampling steps to schedule
        ddim_discretize="uniform",  # timestep discretization method name
        ddim_eta=0.0,  # eta forwarded to make_ddim_sampling_parameters
    ):
        device = get_device()
        # One cumulative alpha must exist for every model timestep.
        if not model_alphas_cumprod.shape[0] == model_num_timesteps:
            raise ValueError("alphas have to be defined for each timestep")
        self.alphas_cumprod = to_torch(model_alphas_cumprod)
        ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=model_num_timesteps,
        )
        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=model_alphas_cumprod.cpu(),
            ddim_timesteps=ddim_timesteps,
            eta=ddim_eta,
        )
        self.ddim_timesteps = ddim_timesteps
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = to_torch(np.sqrt(model_alphas_cumprod.cpu()))
        self.sqrt_one_minus_alphas_cumprod = to_torch(
            np.sqrt(1.0 - model_alphas_cumprod.cpu())
        )
        self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device)
        self.ddim_alphas = ddim_alphas.to(torch.float32).to(device)
        self.ddim_alphas_prev = ddim_alphas_prev
        # NOTE(review): relies on np.sqrt of a torch tensor returning a torch
        # tensor (the .to() chain requires it) -- confirm ddim_alphas' type.
        self.ddim_sqrt_one_minus_alphas = (
            np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(device)
        )
class DDIMSampler:
"""
Denoising Diffusion Implicit Models
@ -96,7 +45,7 @@ class DDIMSampler:
f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
schedule = DDIMSchedule(
schedule = NoiseSchedule(
model_num_timesteps=self.model.num_timesteps,
model_alphas_cumprod=self.model.alphas_cumprod,
ddim_num_steps=num_steps,

@ -6,61 +6,13 @@ import torch
from tqdm import tqdm
from imaginairy.log_utils import log_latent
from imaginairy.modules.diffusion.util import (
extract_into_tensor,
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
)
from imaginairy.samplers.base import get_noise_prediction
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction
from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
def to_torch(x):
    """Return a detached float32 copy of *x* placed on the default device."""
    target_device = get_device()
    return x.clone().detach().to(torch.float32).to(target_device)
class PLMSSchedule:
    """Precomputed schedule tensors for the PLMS sampler.

    NOTE(review): near-duplicate of DDIMSchedule (eta fixed at 0.0); the
    surrounding commit merges both into a single NoiseSchedule class.
    """

    def __init__(
        self,
        model_num_timesteps,  # 1000?
        model_alphas_cumprod,  # per-timestep cumulative alpha products (tensor)
        ddim_num_steps,  # prompt.steps?
        ddim_discretize="uniform",  # timestep discretization method name
    ):
        device = get_device()
        # One cumulative alpha must exist for every model timestep.
        if model_alphas_cumprod.shape[0] != model_num_timesteps:
            raise ValueError("alphas have to be defined for each timestep")
        self.alphas_cumprod = to_torch(model_alphas_cumprod)
        # calculations for diffusion q(x_t | x_{t-1}) and others
        # .cpu() first so np.sqrt operates on host data before to_torch moves it back.
        self.sqrt_alphas_cumprod = to_torch(np.sqrt(model_alphas_cumprod.cpu()))
        self.sqrt_one_minus_alphas_cumprod = to_torch(
            np.sqrt(1.0 - model_alphas_cumprod.cpu())
        )
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=model_num_timesteps,
        )
        # ddim sampling parameters
        # eta is hard-coded to 0.0 here (deterministic DDIM-style sampling).
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=model_alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=0.0,
        )
        self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
        self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))
        self.ddim_alphas_prev = ddim_alphas_prev
        # NOTE(review): relies on np.sqrt of a torch tensor returning a torch
        # tensor (the .to() chain requires it) -- confirm ddim_alphas' type.
        self.ddim_sqrt_one_minus_alphas = (
            np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(torch.device(device))
        )
class PLMSSampler:
"""
probabilistic least-mean-squares
@ -97,7 +49,7 @@ class PLMSSampler:
f"Got {positive_conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
schedule = PLMSSchedule(
schedule = NoiseSchedule(
model_num_timesteps=self.model.num_timesteps,
ddim_num_steps=num_steps,
model_alphas_cumprod=self.model.alphas_cumprod,
@ -163,7 +115,7 @@ class PLMSSampler:
positive_conditioning,
guidance_scale,
time_encoding,
schedule: PLMSSchedule,
schedule: NoiseSchedule,
index,
repeat_noise=False,
quantize_denoised=False,

Loading…
Cancel
Save