imaginAIry/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/scheduler.py
Bryce 55e27160f5 build: vendorize refiners
so we can still work in conda envs
2024-01-02 22:02:31 -08:00

129 lines
4.8 KiB
Python

from abc import ABC, abstractmethod
from enum import Enum
from typing import TypeVar
from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt
# Type variable bound to `Scheduler`; `Scheduler.to` annotates both `self` and its
# return value with it, so the call is typed as returning the caller's own class.
T = TypeVar("T", bound="Scheduler")
class NoiseSchedule(str, Enum):
    """
    Names of the noise schedules a `Scheduler` can sample.

    Members are also `str` instances, so each one compares equal to its plain
    string value (e.g. `NoiseSchedule.UNIFORM == "uniform"`).
    """

    # Sampled with power 1 by `Scheduler.sample_noise_schedule`.
    UNIFORM = "uniform"
    # Sampled with power 2 by `Scheduler.sample_noise_schedule`.
    QUADRATIC = "quadratic"
    # Sampled with power 7 by `Scheduler.sample_noise_schedule`.
    KARRAS = "karras"
class Scheduler(ABC):
"""
A base class for creating a diffusion model scheduler.
The Scheduler creates a sequence of noise and scaling factors used in the diffusion process,
which gradually transforms the original data distribution into a Gaussian one.
This process is described using several parameters such as initial and final diffusion rates,
and is encapsulated into a `__call__` method that applies a step of the diffusion process.
"""
timesteps: Tensor
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: DType = float32,
):
self.device: Device = Device(device)
self.dtype: DType = dtype
self.num_inference_steps = num_inference_steps
self.num_train_timesteps = num_train_timesteps
self.initial_diffusion_rate = initial_diffusion_rate
self.final_diffusion_rate = final_diffusion_rate
self.noise_schedule = noise_schedule
self.scale_factors = self.sample_noise_schedule()
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
self.timesteps = self._generate_timesteps()
@abstractmethod
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
"""
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
This method should be overridden by subclasses to implement the specific diffusion process.
"""
...
@abstractmethod
def _generate_timesteps(self) -> Tensor:
"""
Generates a tensor of timesteps.
This method should be overridden by subclasses to provide the specific timesteps for the diffusion process.
"""
...
@property
def steps(self) -> list[int]:
return list(range(self.num_inference_steps))
def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(
start=self.initial_diffusion_rate ** (1 / power),
end=self.final_diffusion_rate ** (1 / power),
steps=self.num_train_timesteps,
device=self.device,
dtype=self.dtype,
)
** power
)
def sample_noise_schedule(self) -> Tensor:
match self.noise_schedule:
case "uniform":
return 1 - self.sample_power_distribution(1)
case "quadratic":
return 1 - self.sample_power_distribution(2)
case "karras":
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
def add_noise(
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep]
noised_x = cumulative_scale_factors * x + noise_stds * noise
return noised_x
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep]
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf. Useful to preview progress or for guidance like
# in https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
return denoised_x
def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
if device is not None:
self.device = Device(device)
self.timesteps = self.timesteps.to(device)
if dtype is not None:
self.dtype = dtype
self.scale_factors = self.scale_factors.to(device, dtype=dtype)
self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
self.noise_std = self.noise_std.to(device, dtype=dtype)
self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
return self