refactor: remove unused parameters

This commit is contained in:
Bryce 2022-10-10 21:06:52 -07:00 committed by Bryce Drennan
parent 7ae77faf07
commit b0123a8f43
4 changed files with 2 additions and 39 deletions

View File

@ -280,14 +280,12 @@ def imagine(
# encode (scaled latent) # encode (scaled latent)
seed_everything(prompt.seed) seed_everything(prompt.seed)
noise = torch.randn_like(init_latent, device="cpu").to(get_device()) noise = torch.randn_like(init_latent, device="cpu").to(get_device())
# todo: this isn't the right scheduler for everything...
schedule = PLMSSchedule( schedule = PLMSSchedule(
ddpm_num_timesteps=model.num_timesteps, ddpm_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps, ddim_num_steps=prompt.steps,
alphas_cumprod=model.alphas_cumprod, alphas_cumprod=model.alphas_cumprod,
alphas_cumprod_prev=model.alphas_cumprod_prev,
betas=model.betas,
ddim_discretize="uniform", ddim_discretize="uniform",
ddim_eta=0.0,
) )
if generation_strength >= 1: if generation_strength >= 1:
# prompt strength gets converted to time encodings, # prompt strength gets converted to time encodings,

View File

@ -169,15 +169,6 @@ class DDPM(pl.LightningModule):
self.register_buffer( self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
) )
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
)
# calculations for posterior q(x_{t-1} | x_t, x_0) # calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * ( posterior_variance = (1 - self.v_posterior) * betas * (

View File

@ -23,8 +23,6 @@ class DDIMSchedule:
self, self,
model_num_timesteps, model_num_timesteps,
model_alphas_cumprod, model_alphas_cumprod,
model_alphas_cumprod_prev,
model_betas,
ddim_num_steps, ddim_num_steps,
ddim_discretize="uniform", ddim_discretize="uniform",
ddim_eta=0.0, ddim_eta=0.0,
@ -49,19 +47,12 @@ class DDIMSchedule:
eta=ddim_eta, eta=ddim_eta,
) )
self.ddim_timesteps = ddim_timesteps self.ddim_timesteps = ddim_timesteps
self.betas = to_torch(model_betas)
self.alphas_cumprod = to_torch(alphas_cumprod) self.alphas_cumprod = to_torch(alphas_cumprod)
self.alphas_cumprod_prev = to_torch(model_alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu())) self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
self.sqrt_one_minus_alphas_cumprod = to_torch( self.sqrt_one_minus_alphas_cumprod = to_torch(
np.sqrt(1.0 - alphas_cumprod.cpu()) np.sqrt(1.0 - alphas_cumprod.cpu())
) )
self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
self.sqrt_recipm1_alphas_cumprod = to_torch(
np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
)
self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device) self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device)
self.ddim_alphas = ddim_alphas.to(torch.float32).to(device) self.ddim_alphas = ddim_alphas.to(torch.float32).to(device)
self.ddim_alphas_prev = ddim_alphas_prev self.ddim_alphas_prev = ddim_alphas_prev
@ -109,8 +100,6 @@ class DDIMSampler:
schedule = DDIMSchedule( schedule = DDIMSchedule(
model_num_timesteps=self.model.num_timesteps, model_num_timesteps=self.model.num_timesteps,
model_alphas_cumprod=self.model.alphas_cumprod, model_alphas_cumprod=self.model.alphas_cumprod,
model_alphas_cumprod_prev=self.model.alphas_cumprod_prev,
model_betas=self.model.betas,
ddim_num_steps=num_steps, ddim_num_steps=num_steps,
ddim_discretize="uniform", ddim_discretize="uniform",
ddim_eta=0.0, ddim_eta=0.0,

View File

@ -24,13 +24,8 @@ class PLMSSchedule:
ddpm_num_timesteps, # 1000? ddpm_num_timesteps, # 1000?
ddim_num_steps, # prompt.steps? ddim_num_steps, # prompt.steps?
alphas_cumprod, alphas_cumprod,
alphas_cumprod_prev,
betas,
ddim_discretize="uniform", ddim_discretize="uniform",
ddim_eta=0.0,
): ):
if ddim_eta != 0:
raise ValueError("ddim_eta must be 0 for PLMS")
device = get_device() device = get_device()
assert ( assert (
@ -40,19 +35,12 @@ class PLMSSchedule:
def to_torch(x): def to_torch(x):
return x.clone().detach().to(torch.float32).to(device) return x.clone().detach().to(torch.float32).to(device)
self.betas = to_torch(betas)
self.alphas_cumprod = to_torch(alphas_cumprod) self.alphas_cumprod = to_torch(alphas_cumprod)
self.alphas_cumprod_prev = to_torch(alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others # calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu())) self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
self.sqrt_one_minus_alphas_cumprod = to_torch( self.sqrt_one_minus_alphas_cumprod = to_torch(
np.sqrt(1.0 - alphas_cumprod.cpu()) np.sqrt(1.0 - alphas_cumprod.cpu())
) )
self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
self.sqrt_recipm1_alphas_cumprod = to_torch(
np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
)
self.ddim_timesteps = make_ddim_timesteps( self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize, ddim_discr_method=ddim_discretize,
@ -64,7 +52,7 @@ class PLMSSchedule:
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(), alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps, ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta, eta=0.0,
) )
self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device)) self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device)) self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))
@ -113,10 +101,7 @@ class PLMSSampler:
ddpm_num_timesteps=self.model.num_timesteps, ddpm_num_timesteps=self.model.num_timesteps,
ddim_num_steps=num_steps, ddim_num_steps=num_steps,
alphas_cumprod=self.model.alphas_cumprod, alphas_cumprod=self.model.alphas_cumprod,
alphas_cumprod_prev=self.model.alphas_cumprod_prev,
betas=self.model.betas,
ddim_discretize="uniform", ddim_discretize="uniform",
ddim_eta=0.0,
) )
device = self.device device = self.device
# batch_size = shape[0] # batch_size = shape[0]