From b0123a8f43c53440826e5709b2bd08297fd978db Mon Sep 17 00:00:00 2001
From: Bryce
Date: Mon, 10 Oct 2022 21:06:52 -0700
Subject: [PATCH] refactor: remove unused parameters

---
 imaginairy/api.py                    |  4 +---
 imaginairy/modules/diffusion/ddpm.py |  9 ---------
 imaginairy/samplers/ddim.py          | 11 -----------
 imaginairy/samplers/plms.py          | 17 +----------------
 4 files changed, 2 insertions(+), 39 deletions(-)

diff --git a/imaginairy/api.py b/imaginairy/api.py
index 704cb53..eced708 100755
--- a/imaginairy/api.py
+++ b/imaginairy/api.py
@@ -280,14 +280,12 @@ def imagine(
                     # encode (scaled latent)
                     seed_everything(prompt.seed)
                     noise = torch.randn_like(init_latent, device="cpu").to(get_device())
+                    # todo: this isn't the right scheduler for everything...
                     schedule = PLMSSchedule(
                         ddpm_num_timesteps=model.num_timesteps,
                         ddim_num_steps=prompt.steps,
                         alphas_cumprod=model.alphas_cumprod,
-                        alphas_cumprod_prev=model.alphas_cumprod_prev,
-                        betas=model.betas,
                         ddim_discretize="uniform",
-                        ddim_eta=0.0,
                     )
                     if generation_strength >= 1:
                         # prompt strength gets converted to time encodings,
diff --git a/imaginairy/modules/diffusion/ddpm.py b/imaginairy/modules/diffusion/ddpm.py
index 456f5e9..d6c8c69 100644
--- a/imaginairy/modules/diffusion/ddpm.py
+++ b/imaginairy/modules/diffusion/ddpm.py
@@ -169,15 +169,6 @@ class DDPM(pl.LightningModule):
         self.register_buffer(
             "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
         )
-        self.register_buffer(
-            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
-        )
-        self.register_buffer(
-            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
-        )
-        self.register_buffer(
-            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
-        )
 
         # calculations for posterior q(x_{t-1} | x_t, x_0)
         posterior_variance = (1 - self.v_posterior) * betas * (
diff --git a/imaginairy/samplers/ddim.py b/imaginairy/samplers/ddim.py
index 63e4bb4..4c1782a 100644
--- a/imaginairy/samplers/ddim.py
+++ b/imaginairy/samplers/ddim.py
@@ -23,8 +23,6 @@ class DDIMSchedule:
         self,
         model_num_timesteps,
         model_alphas_cumprod,
-        model_alphas_cumprod_prev,
-        model_betas,
         ddim_num_steps,
         ddim_discretize="uniform",
         ddim_eta=0.0,
@@ -49,19 +47,12 @@ class DDIMSchedule:
             eta=ddim_eta,
         )
         self.ddim_timesteps = ddim_timesteps
-        self.betas = to_torch(model_betas)
         self.alphas_cumprod = to_torch(alphas_cumprod)
-        self.alphas_cumprod_prev = to_torch(model_alphas_cumprod_prev)
         # calculations for diffusion q(x_t | x_{t-1}) and others
         self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
         self.sqrt_one_minus_alphas_cumprod = to_torch(
             np.sqrt(1.0 - alphas_cumprod.cpu())
         )
-        self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
-        self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
-        self.sqrt_recipm1_alphas_cumprod = to_torch(
-            np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
-        )
         self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device)
         self.ddim_alphas = ddim_alphas.to(torch.float32).to(device)
         self.ddim_alphas_prev = ddim_alphas_prev
@@ -109,8 +100,6 @@ class DDIMSampler:
         schedule = DDIMSchedule(
             model_num_timesteps=self.model.num_timesteps,
             model_alphas_cumprod=self.model.alphas_cumprod,
-            model_alphas_cumprod_prev=self.model.alphas_cumprod_prev,
-            model_betas=self.model.betas,
             ddim_num_steps=num_steps,
             ddim_discretize="uniform",
             ddim_eta=0.0,
diff --git a/imaginairy/samplers/plms.py b/imaginairy/samplers/plms.py
index 8423964..a4c5912 100644
--- a/imaginairy/samplers/plms.py
+++ b/imaginairy/samplers/plms.py
@@ -24,13 +24,8 @@ class PLMSSchedule:
         ddpm_num_timesteps,  # 1000?
         ddim_num_steps,  # prompt.steps?
         alphas_cumprod,
-        alphas_cumprod_prev,
-        betas,
         ddim_discretize="uniform",
-        ddim_eta=0.0,
     ):
-        if ddim_eta != 0:
-            raise ValueError("ddim_eta must be 0 for PLMS")
         device = get_device()
 
         assert (
@@ -40,19 +35,12 @@ class PLMSSchedule:
         def to_torch(x):
             return x.clone().detach().to(torch.float32).to(device)
 
-        self.betas = to_torch(betas)
         self.alphas_cumprod = to_torch(alphas_cumprod)
-        self.alphas_cumprod_prev = to_torch(alphas_cumprod_prev)
         # calculations for diffusion q(x_t | x_{t-1}) and others
         self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
         self.sqrt_one_minus_alphas_cumprod = to_torch(
             np.sqrt(1.0 - alphas_cumprod.cpu())
         )
-        self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
-        self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
-        self.sqrt_recipm1_alphas_cumprod = to_torch(
-            np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
-        )
 
         self.ddim_timesteps = make_ddim_timesteps(
             ddim_discr_method=ddim_discretize,
@@ -64,7 +52,7 @@ class PLMSSchedule:
         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
             alphacums=alphas_cumprod.cpu(),
             ddim_timesteps=self.ddim_timesteps,
-            eta=ddim_eta,
+            eta=0.0,
         )
         self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
         self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))
@@ -113,10 +101,7 @@ class PLMSSampler:
             ddpm_num_timesteps=self.model.num_timesteps,
             ddim_num_steps=num_steps,
             alphas_cumprod=self.model.alphas_cumprod,
-            alphas_cumprod_prev=self.model.alphas_cumprod_prev,
-            betas=self.model.betas,
            ddim_discretize="uniform",
-            ddim_eta=0.0,
        )
        device = self.device
        # batch_size = shape[0]
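
For reference, a minimal sketch (not part of the patch) of how the slimmed-down PLMSSchedule is constructed after this change; `model` is assumed to be a loaded diffusion model exposing `num_timesteps` and `alphas_cumprod`, as the call sites in the patch do, and the step count is an arbitrary example value.

# sketch only: mirrors the trimmed constructor signature shown in the patch
from imaginairy.samplers.plms import PLMSSchedule

schedule = PLMSSchedule(
    ddpm_num_timesteps=model.num_timesteps,  # total DDPM training timesteps, e.g. 1000
    ddim_num_steps=40,                       # sampler steps (prompt.steps); example value
    alphas_cumprod=model.alphas_cumprod,
    ddim_discretize="uniform",
)
# alphas_cumprod_prev, betas, and ddim_eta are no longer accepted;
# eta is fixed to 0.0 inside PLMSSchedule.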