Mirror of https://github.com/brycedrennan/imaginAIry, synced 2024-11-05 12:00:15 +00:00

refactor: remove unused parameters
commit b0123a8f43
parent 7ae77faf07

@@ -280,14 +280,12 @@ def imagine(
         # encode (scaled latent)
         seed_everything(prompt.seed)
         noise = torch.randn_like(init_latent, device="cpu").to(get_device())
+        # todo: this isn't the right scheduler for everything...
         schedule = PLMSSchedule(
             ddpm_num_timesteps=model.num_timesteps,
             ddim_num_steps=prompt.steps,
             alphas_cumprod=model.alphas_cumprod,
-            alphas_cumprod_prev=model.alphas_cumprod_prev,
-            betas=model.betas,
             ddim_discretize="uniform",
-            ddim_eta=0.0,
         )
         if generation_strength >= 1:
             # prompt strength gets converted to time encodings,

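A note on the context lines above: the noise is drawn on the CPU and only then moved to the target device, presumably so that a given prompt.seed reproduces the same latents regardless of backend. A minimal standalone sketch of that property (the seed value and shapes here are illustrative, not from the repo):

    import torch

    torch.manual_seed(1234)  # stand-in for seed_everything(prompt.seed)
    noise_a = torch.randn(2, 3, device="cpu")
    torch.manual_seed(1234)
    noise_b = torch.randn(2, 3, device="cpu")
    assert torch.equal(noise_a, noise_b)  # same seed -> identical CPU noise
    # moving the tensor afterwards preserves the sampled values, as in:
    # noise = torch.randn_like(init_latent, device="cpu").to(get_device())
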
@@ -169,15 +169,6 @@ class DDPM(pl.LightningModule):
         self.register_buffer(
             "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
         )
-        self.register_buffer(
-            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
-        )
-        self.register_buffer(
-            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
-        )
-        self.register_buffer(
-            "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
-        )

         # calculations for posterior q(x_{t-1} | x_t, x_0)
         posterior_variance = (1 - self.v_posterior) * betas * (

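The three buffers deleted from DDPM above are pure functions of alphas_cumprod; the formulas below are copied verbatim from the removed lines, so the values remain recoverable on demand if ever needed. A hedged sketch (the helper name is ours, not the repo's):

    import numpy as np
    import torch

    def removed_ddpm_terms(alphas_cumprod: torch.Tensor) -> dict:
        # same expressions as the deleted register_buffer calls
        ac = alphas_cumprod.cpu().numpy()
        return {
            "log_one_minus_alphas_cumprod": np.log(1.0 - ac),
            "sqrt_recip_alphas_cumprod": np.sqrt(1.0 / ac),
            "sqrt_recipm1_alphas_cumprod": np.sqrt(1.0 / ac - 1),
        }
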
@@ -23,8 +23,6 @@ class DDIMSchedule:
         self,
         model_num_timesteps,
         model_alphas_cumprod,
-        model_alphas_cumprod_prev,
-        model_betas,
         ddim_num_steps,
         ddim_discretize="uniform",
         ddim_eta=0.0,

@@ -49,19 +47,12 @@ class DDIMSchedule:
             eta=ddim_eta,
         )
         self.ddim_timesteps = ddim_timesteps
-        self.betas = to_torch(model_betas)
         self.alphas_cumprod = to_torch(alphas_cumprod)
-        self.alphas_cumprod_prev = to_torch(model_alphas_cumprod_prev)
         # calculations for diffusion q(x_t | x_{t-1}) and others
         self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
         self.sqrt_one_minus_alphas_cumprod = to_torch(
             np.sqrt(1.0 - alphas_cumprod.cpu())
         )
-        self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
-        self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
-        self.sqrt_recipm1_alphas_cumprod = to_torch(
-            np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
-        )
         self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(device)
         self.ddim_alphas = ddim_alphas.to(torch.float32).to(device)
         self.ddim_alphas_prev = ddim_alphas_prev

@@ -109,8 +100,6 @@ class DDIMSampler:
         schedule = DDIMSchedule(
             model_num_timesteps=self.model.num_timesteps,
             model_alphas_cumprod=self.model.alphas_cumprod,
-            model_alphas_cumprod_prev=self.model.alphas_cumprod_prev,
-            model_betas=self.model.betas,
             ddim_num_steps=num_steps,
             ddim_discretize="uniform",
             ddim_eta=0.0,

@@ -24,13 +24,8 @@ class PLMSSchedule:
         ddpm_num_timesteps,  # 1000?
         ddim_num_steps,  # prompt.steps?
         alphas_cumprod,
-        alphas_cumprod_prev,
-        betas,
         ddim_discretize="uniform",
-        ddim_eta=0.0,
     ):
-        if ddim_eta != 0:
-            raise ValueError("ddim_eta must be 0 for PLMS")
         device = get_device()

         assert (

@@ -40,19 +35,12 @@ class PLMSSchedule:
         def to_torch(x):
             return x.clone().detach().to(torch.float32).to(device)

-        self.betas = to_torch(betas)
         self.alphas_cumprod = to_torch(alphas_cumprod)
-        self.alphas_cumprod_prev = to_torch(alphas_cumprod_prev)
         # calculations for diffusion q(x_t | x_{t-1}) and others
         self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
         self.sqrt_one_minus_alphas_cumprod = to_torch(
             np.sqrt(1.0 - alphas_cumprod.cpu())
         )
-        self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
-        self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
-        self.sqrt_recipm1_alphas_cumprod = to_torch(
-            np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
-        )

         self.ddim_timesteps = make_ddim_timesteps(
             ddim_discr_method=ddim_discretize,

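The to_torch helper kept in context above normalizes every schedule tensor to float32 on the target device. A small self-contained usage sketch; the linear beta schedule here is illustrative only, not the repo's:

    import torch

    def to_torch(x, device="cpu"):
        # same pattern as the context line above: detach from autograd,
        # cast to float32, move to the target device
        return x.clone().detach().to(torch.float32).to(device)

    betas = torch.linspace(1e-4, 2e-2, 1000, dtype=torch.float64)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    print(to_torch(alphas_cumprod).dtype)  # torch.float32
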
@@ -64,7 +52,7 @@ class PLMSSchedule:
         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
             alphacums=alphas_cumprod.cpu(),
             ddim_timesteps=self.ddim_timesteps,
-            eta=ddim_eta,
+            eta=0.0,
         )
         self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
         self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))

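Hard-coding eta=0.0 above (together with dropping the ddim_eta parameter and its runtime guard in the earlier PLMSSchedule hunk) matches the constraint that PLMS sampling is deterministic: in the standard DDIM formulation the injected-noise scale is sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev), so eta = 0 zeroes every sigma. A hedged sketch of what make_ddim_sampling_parameters plausibly computes (the repo's actual helper may differ in detail):

    import numpy as np

    def ddim_sampling_parameters_sketch(alphacums, ddim_timesteps, eta):
        # cumulative alphas at the selected DDIM steps
        alphas = alphacums[ddim_timesteps]
        # previous-step cumulative alphas, shifted by one DDIM step
        alphas_prev = np.concatenate(([alphacums[0]], alphas[:-1]))
        sigmas = eta * np.sqrt(
            (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
        )
        return sigmas, alphas, alphas_prev  # eta == 0 -> sigmas are all zero
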
@@ -113,10 +101,7 @@ class PLMSSampler:
             ddpm_num_timesteps=self.model.num_timesteps,
             ddim_num_steps=num_steps,
             alphas_cumprod=self.model.alphas_cumprod,
-            alphas_cumprod_prev=self.model.alphas_cumprod_prev,
-            betas=self.model.betas,
             ddim_discretize="uniform",
-            ddim_eta=0.0,
         )
         device = self.device
         # batch_size = shape[0]