|
|
@ -24,13 +24,8 @@ class PLMSSchedule:
|
|
|
|
ddpm_num_timesteps, # 1000?
|
|
|
|
ddpm_num_timesteps, # 1000?
|
|
|
|
ddim_num_steps, # prompt.steps?
|
|
|
|
ddim_num_steps, # prompt.steps?
|
|
|
|
alphas_cumprod,
|
|
|
|
alphas_cumprod,
|
|
|
|
alphas_cumprod_prev,
|
|
|
|
|
|
|
|
betas,
|
|
|
|
|
|
|
|
ddim_discretize="uniform",
|
|
|
|
ddim_discretize="uniform",
|
|
|
|
ddim_eta=0.0,
|
|
|
|
|
|
|
|
):
|
|
|
|
):
|
|
|
|
if ddim_eta != 0:
|
|
|
|
|
|
|
|
raise ValueError("ddim_eta must be 0 for PLMS")
|
|
|
|
|
|
|
|
device = get_device()
|
|
|
|
device = get_device()
|
|
|
|
|
|
|
|
|
|
|
|
assert (
|
|
|
|
assert (
|
|
|
@ -40,19 +35,12 @@ class PLMSSchedule:
|
|
|
|
def to_torch(x):
|
|
|
|
def to_torch(x):
|
|
|
|
return x.clone().detach().to(torch.float32).to(device)
|
|
|
|
return x.clone().detach().to(torch.float32).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
self.betas = to_torch(betas)
|
|
|
|
|
|
|
|
self.alphas_cumprod = to_torch(alphas_cumprod)
|
|
|
|
self.alphas_cumprod = to_torch(alphas_cumprod)
|
|
|
|
self.alphas_cumprod_prev = to_torch(alphas_cumprod_prev)
|
|
|
|
|
|
|
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
|
|
|
# calculations for diffusion q(x_t | x_{t-1}) and others
|
|
|
|
self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
|
|
|
|
self.sqrt_alphas_cumprod = to_torch(np.sqrt(alphas_cumprod.cpu()))
|
|
|
|
self.sqrt_one_minus_alphas_cumprod = to_torch(
|
|
|
|
self.sqrt_one_minus_alphas_cumprod = to_torch(
|
|
|
|
np.sqrt(1.0 - alphas_cumprod.cpu())
|
|
|
|
np.sqrt(1.0 - alphas_cumprod.cpu())
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.log_one_minus_alphas_cumprod = to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
|
|
|
|
|
|
|
self.sqrt_recip_alphas_cumprod = to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
|
|
|
|
|
|
|
self.sqrt_recipm1_alphas_cumprod = to_torch(
|
|
|
|
|
|
|
|
np.sqrt(1.0 / alphas_cumprod.cpu() - 1)
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.ddim_timesteps = make_ddim_timesteps(
|
|
|
|
self.ddim_timesteps = make_ddim_timesteps(
|
|
|
|
ddim_discr_method=ddim_discretize,
|
|
|
|
ddim_discr_method=ddim_discretize,
|
|
|
@ -64,7 +52,7 @@ class PLMSSchedule:
|
|
|
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
|
|
|
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
|
|
|
alphacums=alphas_cumprod.cpu(),
|
|
|
|
alphacums=alphas_cumprod.cpu(),
|
|
|
|
ddim_timesteps=self.ddim_timesteps,
|
|
|
|
ddim_timesteps=self.ddim_timesteps,
|
|
|
|
eta=ddim_eta,
|
|
|
|
eta=0.0,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
|
|
|
|
self.ddim_sigmas = ddim_sigmas.to(torch.float32).to(torch.device(device))
|
|
|
|
self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))
|
|
|
|
self.ddim_alphas = ddim_alphas.to(torch.float32).to(torch.device(device))
|
|
|
@ -113,10 +101,7 @@ class PLMSSampler:
|
|
|
|
ddpm_num_timesteps=self.model.num_timesteps,
|
|
|
|
ddpm_num_timesteps=self.model.num_timesteps,
|
|
|
|
ddim_num_steps=num_steps,
|
|
|
|
ddim_num_steps=num_steps,
|
|
|
|
alphas_cumprod=self.model.alphas_cumprod,
|
|
|
|
alphas_cumprod=self.model.alphas_cumprod,
|
|
|
|
alphas_cumprod_prev=self.model.alphas_cumprod_prev,
|
|
|
|
|
|
|
|
betas=self.model.betas,
|
|
|
|
|
|
|
|
ddim_discretize="uniform",
|
|
|
|
ddim_discretize="uniform",
|
|
|
|
ddim_eta=0.0,
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
device = self.device
|
|
|
|
device = self.device
|
|
|
|
# batch_size = shape[0]
|
|
|
|
# batch_size = shape[0]
|
|
|
|