mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
cf8a44b317
better handles img2img (partial diffusion runs)
89 lines
3.2 KiB
Python
89 lines
3.2 KiB
Python
import numpy as np
|
|
import torch
|
|
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor
|
|
|
|
from imaginairy.vendored.refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
|
|
|
|
|
|
class EulerScheduler(Scheduler):
|
|
def __init__(
|
|
self,
|
|
num_inference_steps: int,
|
|
num_train_timesteps: int = 1_000,
|
|
initial_diffusion_rate: float = 8.5e-4,
|
|
final_diffusion_rate: float = 1.2e-2,
|
|
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
|
first_inference_step: int = 0,
|
|
device: Device | str = "cpu",
|
|
dtype: Dtype = float32,
|
|
):
|
|
if noise_schedule != NoiseSchedule.QUADRATIC:
|
|
raise NotImplementedError
|
|
super().__init__(
|
|
num_inference_steps=num_inference_steps,
|
|
num_train_timesteps=num_train_timesteps,
|
|
initial_diffusion_rate=initial_diffusion_rate,
|
|
final_diffusion_rate=final_diffusion_rate,
|
|
noise_schedule=noise_schedule,
|
|
first_inference_step=first_inference_step,
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
self.sigmas = self._generate_sigmas()
|
|
|
|
@property
|
|
def init_noise_sigma(self) -> Tensor:
|
|
return self.sigmas.max()
|
|
|
|
def _generate_timesteps(self) -> Tensor:
|
|
# We need to use numpy here because:
|
|
# numpy.linspace(0,999,31)[15] is 499.49999999999994
|
|
# torch.linspace(0,999,31)[15] is 499.5
|
|
# ...and we want the same result as the original codebase.
|
|
timesteps = torch.tensor(
|
|
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
|
|
).flip(0)
|
|
return timesteps
|
|
|
|
def _generate_sigmas(self) -> Tensor:
|
|
sigmas = self.noise_std / self.cumulative_scale_factors
|
|
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
|
|
sigmas = torch.cat([sigmas, tensor([0.0])])
|
|
return sigmas.to(device=self.device, dtype=self.dtype)
|
|
|
|
def scale_model_input(self, x: Tensor, step: int) -> Tensor:
|
|
sigma = self.sigmas[step]
|
|
return x / ((sigma**2 + 1) ** 0.5)
|
|
|
|
def __call__(
|
|
self,
|
|
x: Tensor,
|
|
noise: Tensor,
|
|
step: int,
|
|
generator: Generator | None = None,
|
|
s_churn: float = 0.0,
|
|
s_tmin: float = 0.0,
|
|
s_tmax: float = float("inf"),
|
|
s_noise: float = 1.0,
|
|
) -> Tensor:
|
|
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
|
|
|
sigma = self.sigmas[step]
|
|
|
|
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
|
|
|
|
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
|
|
eps = alt_noise * s_noise
|
|
sigma_hat = sigma * (gamma + 1)
|
|
if gamma > 0:
|
|
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
|
|
|
predicted_x = x - sigma_hat * noise
|
|
|
|
# 1st order Euler
|
|
derivative = (x - predicted_x) / sigma_hat
|
|
dt = self.sigmas[step + 1] - sigma_hat
|
|
denoised_x = x + derivative * dt
|
|
|
|
return denoised_x
|