|
|
|
@ -24,6 +24,7 @@ class DPMSolver(Scheduler):
|
|
|
|
|
final_diffusion_rate: float = 1.2e-2,
|
|
|
|
|
last_step_first_order: bool = False,
|
|
|
|
|
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
|
|
|
|
|
first_inference_step: int = 0,
|
|
|
|
|
device: Device | str = "cpu",
|
|
|
|
|
dtype: Dtype = float32,
|
|
|
|
|
):
|
|
|
|
@ -33,12 +34,12 @@ class DPMSolver(Scheduler):
|
|
|
|
|
initial_diffusion_rate=initial_diffusion_rate,
|
|
|
|
|
final_diffusion_rate=final_diffusion_rate,
|
|
|
|
|
noise_schedule=noise_schedule,
|
|
|
|
|
first_inference_step=first_inference_step,
|
|
|
|
|
device=device,
|
|
|
|
|
dtype=dtype,
|
|
|
|
|
)
|
|
|
|
|
self.estimated_data = deque([tensor([])] * 2, maxlen=2)
|
|
|
|
|
self.last_step_first_order = last_step_first_order
|
|
|
|
|
self._first_step_has_been_run = False
|
|
|
|
|
|
|
|
|
|
def _generate_timesteps(self) -> Tensor:
|
|
|
|
|
# We need to use numpy here because:
|
|
|
|
@ -81,6 +82,7 @@ class DPMSolver(Scheduler):
|
|
|
|
|
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
|
|
|
|
|
previous_noise_std = self.noise_std[previous_timestep]
|
|
|
|
|
current_noise_std = self.noise_std[current_timestep]
|
|
|
|
|
|
|
|
|
|
estimation_delta = (current_data_estimation - next_data_estimation) / (
|
|
|
|
|
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
|
|
|
|
|
)
|
|
|
|
@ -100,13 +102,14 @@ class DPMSolver(Scheduler):
|
|
|
|
|
backward Euler update, which is a numerical method commonly used to solve ordinary differential equations
|
|
|
|
|
(ODEs).
|
|
|
|
|
"""
|
|
|
|
|
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
|
|
|
|
|
|
|
|
|
current_timestep = self.timesteps[step]
|
|
|
|
|
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
|
|
|
|
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
|
|
|
|
|
self.estimated_data.append(estimated_denoised_data)
|
|
|
|
|
|
|
|
|
|
if step == 0 or (self.last_step_first_order and step == self.num_inference_steps - 1) or not self._first_step_has_been_run:
|
|
|
|
|
self._first_step_has_been_run = True
|
|
|
|
|
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
|
|
|
|
|
return self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
|
|
|
|
|
|
|
|
|
|
return self.multistep_dpm_solver_second_order_update(x=x, step=step)
|
|
|
|
|