You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
imaginAIry/imaginairy/vendored/refiners/foundationals/latent_diffusion/schedulers/euler.py

89 lines
3.2 KiB
Python

import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
class EulerScheduler(Scheduler):
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
first_inference_step: int = 0,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
if noise_schedule != NoiseSchedule.QUADRATIC:
raise NotImplementedError
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
first_inference_step=first_inference_step,
device=device,
dtype=dtype,
)
self.sigmas = self._generate_sigmas()
@property
def init_noise_sigma(self) -> Tensor:
return self.sigmas.max()
def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
timesteps = torch.tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
).flip(0)
return timesteps
def _generate_sigmas(self) -> Tensor:
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
sigmas = torch.cat([sigmas, tensor([0.0])])
return sigmas.to(device=self.device, dtype=self.dtype)
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
sigma = self.sigmas[step]
return x / ((sigma**2 + 1) ** 0.5)
def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
generator: Generator | None = None,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
) -> Tensor:
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
sigma = self.sigmas[step]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
eps = alt_noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
predicted_x = x - sigma_hat * noise
# 1st order Euler
derivative = (x - predicted_x) / sigma_hat
dt = self.sigmas[step + 1] - sigma_hat
denoised_x = x + derivative * dt
return denoised_x