imaginAIry/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/ddpm.py
jaydrennan 1bf53e47cf
feature: updates refiners vendored library (#458)
* feature: updates refiners vendored library

has a small bugfix that will soon be replaced by a better fix from upstream refiners

Co-authored-by: Bryce <github20210803@accounts.brycedrennan.com>
2024-01-19 08:45:23 -08:00

35 lines
1.3 KiB
Python

from torch import Generator, Tensor, arange, device as Device
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
class DDPM(Scheduler):
"""
The Denoising Diffusion Probabilistic Models (DDPM) is a specific type of diffusion model,
which uses a specific strategy to generate the timesteps and applies the diffusion process in a specific way.
"""
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
device: Device | str = "cpu",
) -> None:
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
device=device,
)
def _generate_timesteps(self) -> Tensor:
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
raise NotImplementedError