refactor: remove unused parameters

pull/57/head
Bryce 2 years ago committed by Bryce Drennan
parent 7ae77faf07
commit b0123a8f43

@ -280,14 +280,12 @@ def imagine(
# encode (scaled latent)
seed_everything(prompt.seed)
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
# todo: this isn't the right scheduler for everything...
schedule = PLMSSchedule(
ddpm_num_timesteps=model.num_timesteps,
ddim_num_steps=prompt.steps,
alphas_cumprod=model.alphas_cumprod,
alphas_cumprod_prev=model.alphas_cumprod_prev,
betas=model.betas,
ddim_discretize="uniform",
ddim_eta=0.0,
)
if generation_strength >= 1:
# prompt strength gets converted to time encodings,

@ -169,15 +169,6 @@ class DDPM(pl.LightningModule):
self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (

@ -23,8 +23,6 @@ class DDIMSchedule:
self,
model_num_timesteps,
model_alphas_cumprod,
model_alphas_cumprod_prev,
model_betas,
ddim_num_steps,
ddim_discretize="uniform",
ddim_eta=0.0,
@ -49,19 +47,12 @@ class DDIMSchedule:
eta=ddim_eta,
)
self.ddim_timesteps = ddim_timesteps
self.betas = to_torch(model_betas)
self.alphas_cumprod = to_torch(alphas_cumprod)
self.alphas_cumprod_prev = to_torch(model_alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
self.sqrt_one_minus_alphas_cumprod = to_torch(
np.sqrt(1.0 - alphas_cumprod.cpu())
)
self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
self.sqrt_recipm1_alphas_cumprod = to_torch(
np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
)
self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device)
self.ddim_alphas = ddim_alphas.to(torch.float32).to(device)
self.ddim_alphas_prev = ddim_alphas_prev
@ -109,8 +100,6 @@ class DDIMSampler:
schedule = DDIMSchedule(
model_num_timesteps=self.model.num_timesteps,
model_alphas_cumprod=self.model.alphas_cumprod,
model_alphas_cumprod_prev=self.model.alphas_cumprod_prev,
model_betas=self.model.betas,
ddim_num_steps=num_steps,
ddim_discretize="uniform",
ddim_eta=0.0,

@ -24,13 +24,8 @@ class PLMSSchedule:
ddpm_num_timesteps, # 1000?
ddim_num_steps, # prompt.steps?
alphas_cumprod,
alphas_cumprod_prev,
betas,
ddim_discretize="uniform",
ddim_eta=0.0,
):
if ddim_eta != 0:
raise ValueError("ddim_eta must be 0 for PLMS")
device = get_device()
assert (
@ -40,19 +35,12 @@ class PLMSSchedule:
def to_torch(x):
return x.clone().detach().to(torch.float32).to(device)
self.betas = to_torch(betas)
self.alphas_cumprod = to_torch(alphas_cumprod)
self.alphas_cumprod_prev = to_torch(alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
self.sqrt_one_minus_alphas_cumprod = to_torch(
np.sqrt(1.0 - alphas_cumprod.cpu())
)
self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
self.sqrt_recipm1_alphas_cumprod = to_torch(
np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
)
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
@ -64,7 +52,7 @@ class PLMSSchedule:
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
eta=0.0,
)
self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))
@ -113,10 +101,7 @@ class PLMSSampler:
ddpm_num_timesteps=self.model.num_timesteps,
ddim_num_steps=num_steps,
alphas_cumprod=self.model.alphas_cumprod,
alphas_cumprod_prev=self.model.alphas_cumprod_prev,
betas=self.model.betas,
ddim_discretize="uniform",
ddim_eta=0.0,
)
device = self.device
# batch_size = shape[0]

Loading…
Cancel
Save