|
|
|
@ -28,6 +28,15 @@ def get_sigmas_exponential(n, sigma_min, sigma_max, device="cpu"):
|
|
|
|
|
return append_zero(sigmas)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1.0, device="cpu"):
    """Constructs a polyexponential (polynomial in log sigma) noise schedule.

    Args:
        n: Number of noise levels to generate (a trailing zero is appended).
        sigma_min: The lowest noise level in the schedule.
        sigma_max: The highest noise level in the schedule.
        rho: Exponent applied to the descending ramp; rho=1.0 reduces to a
            plain exponential (log-linear) schedule.
        device: Device on which to allocate the schedule tensor.
    """
    log_min, log_max = math.log(sigma_min), math.log(sigma_max)
    # Descending ramp from 1 to 0, curved by rho.
    ramp = torch.linspace(1, 0, n, device=device) ** rho
    sigmas = torch.exp(ramp * (log_max - log_min) + log_min)
    return append_zero(sigmas)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device="cpu"):
|
|
|
|
|
"""Constructs a continuous VP noise schedule."""
|
|
|
|
|
t = torch.linspace(1, eps_s, n, device=device)
|
|
|
|
@ -54,6 +63,68 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
|
|
|
|
|
return sigma_down, sigma_up
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_noise_sampler(x):
    """Returns a noise sampler drawing i.i.d. Gaussian noise shaped like `x`.

    The returned callable takes (sigma, sigma_next) for interface parity with
    the other noise samplers in this file, but ignores both arguments.
    """
    def sampler(sigma, sigma_next):
        return torch.randn_like(x)

    return sampler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get("w0", torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        try:
            # A sequence of seeds: one independent tree per batch item.
            assert len(seed) == x.shape[0]
            seeds, w0, self.batched = seed, w0[0], True
        except TypeError:
            # A single scalar seed: one tree shared by the whole batch.
            seeds, self.batched = [seed], False
        self.trees = [
            torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seeds
        ]

    @staticmethod
    def sort(a, b):
        # Returns (low, high, direction) so queries can be made in either order.
        if a < b:
            return a, b, 1
        return b, a, -1

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        # Flip the increment's sign when the query direction disagrees with
        # the direction the tree was constructed with.
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        if self.batched:
            return w
        return w[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t0 = self.transform(torch.as_tensor(sigma_min))
        t1 = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0 = self.transform(torch.as_tensor(sigma))
        t1 = self.transform(torch.as_tensor(sigma_next))
        # Normalize the Brownian increment to unit variance over the interval.
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def sample_euler(
|
|
|
|
|
model,
|
|
|
|
@ -100,10 +171,19 @@ def sample_euler(
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def sample_euler_ancestral(
|
|
|
|
|
model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0
|
|
|
|
|
model,
|
|
|
|
|
x,
|
|
|
|
|
sigmas,
|
|
|
|
|
extra_args=None,
|
|
|
|
|
callback=None,
|
|
|
|
|
disable=None,
|
|
|
|
|
eta=1.0,
|
|
|
|
|
s_noise=1.0,
|
|
|
|
|
noise_sampler=None,
|
|
|
|
|
):
|
|
|
|
|
"""Ancestral sampling with Euler method steps."""
|
|
|
|
|
extra_args = {} if extra_args is None else extra_args
|
|
|
|
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
|
|
|
|
s_in = x.new_ones([x.shape[0]])
|
|
|
|
|
for i in trange(len(sigmas) - 1, disable=disable):
|
|
|
|
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
|
|
|
@ -122,7 +202,8 @@ def sample_euler_ancestral(
|
|
|
|
|
# Euler method
|
|
|
|
|
dt = sigma_down - sigmas[i]
|
|
|
|
|
x = x + d * dt
|
|
|
|
|
x = x + torch.randn_like(x, device="cpu").to(x.device) * sigma_up
|
|
|
|
|
if sigmas[i + 1] > 0:
|
|
|
|
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -234,10 +315,19 @@ def sample_dpm_2(
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def sample_dpm_2_ancestral(
|
|
|
|
|
model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0
|
|
|
|
|
model,
|
|
|
|
|
x,
|
|
|
|
|
sigmas,
|
|
|
|
|
extra_args=None,
|
|
|
|
|
callback=None,
|
|
|
|
|
disable=None,
|
|
|
|
|
eta=1.0,
|
|
|
|
|
s_noise=1.0,
|
|
|
|
|
noise_sampler=None,
|
|
|
|
|
):
|
|
|
|
|
"""Ancestral sampling with DPM-Solver inspired second-order steps."""
|
|
|
|
|
"""Ancestral sampling with DPM-Solver second-order steps."""
|
|
|
|
|
extra_args = {} if extra_args is None else extra_args
|
|
|
|
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
|
|
|
|
s_in = x.new_ones([x.shape[0]])
|
|
|
|
|
for i in trange(len(sigmas) - 1, disable=disable):
|
|
|
|
|
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
|
|
|
@ -266,7 +356,7 @@ def sample_dpm_2_ancestral(
|
|
|
|
|
denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
|
|
|
|
|
d_2 = to_d(x_2, sigma_mid, denoised_2)
|
|
|
|
|
x = x + d_2 * dt_2
|
|
|
|
|
x = x + torch.randn_like(x, device="cpu").to(x.device) * sigma_up
|
|
|
|
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -452,7 +542,12 @@ class DPMSolver(nn.Module):
|
|
|
|
|
)
|
|
|
|
|
return x_3, eps_cache
|
|
|
|
|
|
|
|
|
|
def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0.0, s_noise=1.0):
|
|
|
|
|
def dpm_solver_fast(
|
|
|
|
|
self, x, t_start, t_end, nfe, eta=0.0, s_noise=1.0, noise_sampler=None
|
|
|
|
|
):
|
|
|
|
|
noise_sampler = (
|
|
|
|
|
default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
|
|
|
|
)
|
|
|
|
|
if not t_end > t_start and eta:
|
|
|
|
|
raise ValueError("eta must be 0 for reverse sampling")
|
|
|
|
|
|
|
|
|
@ -494,7 +589,7 @@ class DPMSolver(nn.Module):
|
|
|
|
|
x, t, t_next_, eps_cache=eps_cache
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
x = x + su * s_noise * torch.randn_like(x)
|
|
|
|
|
x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))
|
|
|
|
|
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
@ -513,7 +608,11 @@ class DPMSolver(nn.Module):
|
|
|
|
|
accept_safety=0.81,
|
|
|
|
|
eta=0.0,
|
|
|
|
|
s_noise=1.0,
|
|
|
|
|
noise_sampler=None,
|
|
|
|
|
):
|
|
|
|
|
noise_sampler = (
|
|
|
|
|
default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
|
|
|
|
)
|
|
|
|
|
if order not in {2, 3}:
|
|
|
|
|
raise ValueError("order should be 2 or 3")
|
|
|
|
|
forward = t_end > t_start
|
|
|
|
@ -564,7 +663,7 @@ class DPMSolver(nn.Module):
|
|
|
|
|
accept = pid.propose_step(error)
|
|
|
|
|
if accept:
|
|
|
|
|
x_prev = x_low
|
|
|
|
|
x = x_high + su * s_noise * torch.randn_like(x_high)
|
|
|
|
|
x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
|
|
|
|
|
s = t
|
|
|
|
|
info["n_accept"] += 1
|
|
|
|
|
else:
|
|
|
|
@ -601,6 +700,7 @@ def sample_dpm_fast(
|
|
|
|
|
disable=None,
|
|
|
|
|
eta=0.0,
|
|
|
|
|
s_noise=1.0,
|
|
|
|
|
noise_sampler=None,
|
|
|
|
|
):
|
|
|
|
|
"""DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
|
|
|
|
|
if sigma_min <= 0 or sigma_max <= 0:
|
|
|
|
@ -622,6 +722,7 @@ def sample_dpm_fast(
|
|
|
|
|
n,
|
|
|
|
|
eta,
|
|
|
|
|
s_noise,
|
|
|
|
|
noise_sampler,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -644,6 +745,7 @@ def sample_dpm_adaptive(
|
|
|
|
|
accept_safety=0.81,
|
|
|
|
|
eta=0.0,
|
|
|
|
|
s_noise=1.0,
|
|
|
|
|
noise_sampler=None,
|
|
|
|
|
return_info=False,
|
|
|
|
|
):
|
|
|
|
|
"""DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
|
|
|
|
@ -673,6 +775,7 @@ def sample_dpm_adaptive(
|
|
|
|
|
accept_safety,
|
|
|
|
|
eta,
|
|
|
|
|
s_noise,
|
|
|
|
|
noise_sampler,
|
|
|
|
|
)
|
|
|
|
|
if return_info:
|
|
|
|
|
return x, info
|
|
|
|
@ -681,10 +784,19 @@ def sample_dpm_adaptive(
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def sample_dpmpp_2s_ancestral(
|
|
|
|
|
model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0
|
|
|
|
|
model,
|
|
|
|
|
x,
|
|
|
|
|
sigmas,
|
|
|
|
|
extra_args=None,
|
|
|
|
|
callback=None,
|
|
|
|
|
disable=None,
|
|
|
|
|
eta=1.0,
|
|
|
|
|
s_noise=1.0,
|
|
|
|
|
noise_sampler=None,
|
|
|
|
|
):
|
|
|
|
|
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
|
|
|
|
extra_args = {} if extra_args is None else extra_args
|
|
|
|
|
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
|
|
|
|
s_in = x.new_ones([x.shape[0]])
|
|
|
|
|
sigma_fn = lambda t: t.neg().exp()
|
|
|
|
|
t_fn = lambda sigma: sigma.to("cpu").log().neg().to(x.device)
|
|
|
|
@ -708,7 +820,7 @@ def sample_dpmpp_2s_ancestral(
|
|
|
|
|
dt = sigma_down - sigmas[i]
|
|
|
|
|
x = x + d * dt
|
|
|
|
|
else:
|
|
|
|
|
# DPM-Solver-2++(2S)
|
|
|
|
|
# DPM-Solver++(2S)
|
|
|
|
|
t, t_next = t_fn(sigmas[i]), t_fn(sigma_down)
|
|
|
|
|
r = 1 / 2
|
|
|
|
|
h = t_next - t
|
|
|
|
@ -717,7 +829,75 @@ def sample_dpmpp_2s_ancestral(
|
|
|
|
|
denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
|
|
|
|
|
x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_2
|
|
|
|
|
# Noise addition
|
|
|
|
|
x = x + torch.randn_like(x) * s_noise * sigma_up
|
|
|
|
|
if sigmas[i + 1] > 0:
|
|
|
|
|
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def sample_dpmpp_sde(
    model,
    x,
    sigmas,
    extra_args=None,
    callback=None,
    disable=None,
    eta=1.0,
    s_noise=1.0,
    noise_sampler=None,
    r=1 / 2,
):
    """DPM-Solver++ (stochastic).

    Args:
        model: The denoiser; called as ``model(x, sigma * s_in, **extra_args)``
            and expected to return a denoised estimate shaped like ``x``.
        x: The initial noised sample.
        sigmas: The noise schedule; iterated pairwise, one solver step per pair.
        extra_args: Optional dict of extra keyword arguments passed to ``model``.
        callback: Optional callable invoked each step with a dict containing
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Passed to ``trange`` to disable the progress bar.
        eta: Amount of ancestral noise injected per step (0 = deterministic).
        s_noise: Multiplier on the injected noise.
        noise_sampler: Optional noise sampler; defaults to a
            BrownianTreeNoiseSampler over the positive sigma range.
        r: Relative position of the intermediate evaluation within each step.

    Returns:
        The final sample after stepping through the whole schedule.
    """
    # Brownian-tree noise is only defined over a positive sigma interval, so
    # the range excludes any trailing zero in the schedule.
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = (
        BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
        if noise_sampler is None
        else noise_sampler
    )
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Conversions between sigma and the solver's log-time t = -log(sigma).
    sigma_fn = lambda t: t.neg().exp()
    # NOTE(review): the round-trip through "cpu" looks like a workaround for a
    # backend lacking log support (presumably MPS) — confirm before removing.
    t_fn = lambda sigma: sigma.to("cpu").log().neg().to(x.device)

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigmas[i],
                    "denoised": denoised,
                }
            )
        if sigmas[i + 1] == 0:
            # Euler method: t_fn is undefined at sigma = 0, so the final step
            # to zero noise is a plain deterministic Euler step.
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            # Weight of the midpoint estimate in the second-order combination.
            fac = 1 / (2 * r)

            # Step 1: half-step to the intermediate time s, with ancestral
            # noise scaled by su added after the deterministic update.
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Step 2: full step using the blended denoised estimate.
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (
                t - t_next_
            ).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
    return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -728,7 +908,6 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|
|
|
|
s_in = x.new_ones([x.shape[0]])
|
|
|
|
|
sigma_fn = lambda t: t.neg().exp()
|
|
|
|
|
t_fn = lambda sigma: sigma.to("cpu").log().neg().to(x.device)
|
|
|
|
|
|
|
|
|
|
old_denoised = None
|
|
|
|
|
|
|
|
|
|
for i in trange(len(sigmas) - 1, disable=disable):
|
|
|
|
|