fix: make k-diffusion samplers deterministic

- add a test that checks image hashes on MPS. Images generated on CUDA look the same but differ slightly at the pixel level, so their hashes don't match.
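
The fix works because PyTorch's CPU generator produces the same sequence everywhere, while CUDA and MPS each use their own RNG streams, so noise drawn directly on the device can never match across backends. A minimal sketch of the pattern used throughout this commit (the helper name is illustrative, not part of the change):

import torch

def reproducible_randn_like(x, seed=None):
    # draw on the CPU RNG (identical on every machine), then move to x's device
    if seed is not None:
        torch.manual_seed(seed)
    return torch.randn_like(x, device="cpu").to(x.device)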
Bryce 2022-09-14 09:37:45 -07:00
parent b4a3b8c2b3
commit bb665b9eb6
12 changed files with 70 additions and 257 deletions

View File

@@ -79,13 +79,14 @@ vendor_openai_clip:

revendorize:
	make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip COMMIT=d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
	make vendorize REPO=git@github.com:crowsonkb/k-diffusion.git PKG=k_diffusion COMMIT=1a0703dfb7d24d8806267c3e7ccc4caf67fd1331
	#sed -i'' -e 's/^import\sclip/from\simaginairy.vendored\simport\sclip/g' imaginairy/vendored/k_diffusion/evaluation.py
	#sed -i '' -e 's/^import\sclip/from\simaginairy.vendored\simport\sclip/g' imaginairy/vendored/k_diffusion/evaluation.py
	rm imaginairy/vendored/k_diffusion/evaluation.py
	touch imaginairy/vendored/k_diffusion/evaluation.py
	rm imaginairy/vendored/k_diffusion/config.py
	touch imaginairy/vendored/k_diffusion/config.py
	# without this most of the k-diffusion samplers didn't work
	sed -i'' -e 's#return (x - denoised) / utils.append_dims(sigma, x.ndim)#return (x - denoised) / sigma#g' imaginairy/vendored/k_diffusion/sampling.py
	sed -i '' -e 's#return (x - denoised) / utils.append_dims(sigma, x.ndim)#return (x - denoised) / sigma#g' imaginairy/vendored/k_diffusion/sampling.py
	sed -i '' -e 's#x = x + torch.randn_like(x) \* sigma_up#x = x + torch.randn_like(x, device="cpu").to(x.device) \* sigma_up#g' imaginairy/vendored/k_diffusion/sampling.py

	make af
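
The first sed patch simplifies to_d by dropping utils.append_dims; the second applies the CPU-noise pattern to the ancestral samplers. A sketch of what append_dims does and why the plain division is equivalent when sigma behaves like a scalar, which appears to be how these samplers are invoked here (the shapes below are illustrative):

import torch

def append_dims(x, target_dims):
    # mirrors the upstream k-diffusion helper: append trailing singleton dims
    return x[(...,) + (None,) * (target_dims - x.ndim)]

x = torch.randn(1, 4, 32, 32)
sigma = torch.tensor(7.0)  # 0-dim tensor, broadcasts like a scalar
assert torch.equal(x / sigma, x / append_dims(sigma, x.ndim))
# a per-sample sigma of shape [batch] with batch > 1 would mis-broadcast in
# the plain division, so the simplification assumes scalar-like sigmas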

View File

@@ -142,6 +142,7 @@ imagine_image_files(prompts, outdir="./my-art")
- prompt expansion
- Image Generation Features
  - ✅ add k-diffusion sampling methods
    - why is k-diffusion so slow compared to plms? 2 it/s vs 8 it/s
  - upscaling
    - ✅ realesrgan
    - ldm

View File

@@ -17,7 +17,7 @@ from transformers import cached_path
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.img_log import LatentLoggingContext, log_latent
from imaginairy.safety import is_nsfw, safety_models
from imaginairy.safety import is_nsfw
from imaginairy.samplers.base import get_sampler
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
@@ -147,9 +147,6 @@ def imagine(
):
    model = load_model(tile_mode=tile_mode)
    if not IMAGINAIRY_ALLOW_NSFW:
        # needs to be loaded before we set the default tensor type to half
        safety_models()
    # half-mode only works on cuda; enable it by default there
    half_mode = half_mode is None and get_device() == "cuda"
    if half_mode:
@@ -199,17 +196,17 @@ def imagine(
                ddim_steps = int(prompt.steps / generation_strength)
                sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)

                t_enc = int(generation_strength * ddim_steps)
                init_image, w, h = img_path_to_torch_image(prompt.init_image)
                init_image = init_image.to(get_device())
                init_latent = model.get_first_stage_encoding(
                    model.encode_first_stage(init_image)
                )
                log_latent(init_latent, "init_latent")
                log_latent(init_latent, "init_latent")

                # encode (scaled latent)
                z_enc = sampler.stochastic_encode(
                    init_latent, torch.tensor([t_enc]).to(get_device())
                    init_latent,
                    torch.tensor([prompt.steps]).to(get_device()),
                )
                log_latent(z_enc, "z_enc")
@@ -217,7 +214,7 @@ def imagine(
                samples = sampler.decode(
                    z_enc,
                    c,
                    t_enc,
                    prompt.steps,
                    unconditional_guidance_scale=prompt.prompt_strength,
                    unconditional_conditioning=uc,
                    img_callback=_img_callback,

View File

@@ -3,6 +3,7 @@ import logging.config

import click

from imaginairy.api import load_model
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS

logger = logging.getLogger(__name__)
@@ -101,18 +102,7 @@ def configure_logging(level="INFO"):
@click.option(
    "--sampler-type",
    default="plms",
    type=click.Choice(
        [
            "plms",
            "ddim",
            "k_lms",
            "k_dpm_2",
            "k_dpm_2_a",
            "k_euler",
            "k_euler_a",
            "k_heun",
        ]
    ),
    type=click.Choice(SAMPLER_TYPE_OPTIONS),
    help="What sampling strategy to use",
)
@click.option("--ddim-eta", default=0.0, type=float)

View File

@@ -36,7 +36,12 @@ def uniform_on_device(r1, r2, shape, device):

class DDPM(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    """
    classic DDPM with Gaussian diffusion, in image space

    Denoising diffusion probabilistic models
    """

    def __init__(
        self,
        unet_config,
View File

@@ -43,7 +43,7 @@ def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=
def find_noise_for_latent(
    model, img_latent, prompt, steps=50, cond_scale=1.0, half=True
):
    import k_diffusion as K
    from imaginairy.vendored import k_diffusion as K

    x = img_latent
@@ -89,6 +89,7 @@ def find_noise_for_latent(
    del eps, denoised_uncond, denoised_cond, denoised, d, dt
    # collect_and_empty()
    # return (x / x.std())
    return (x / x.std()) * sigmas[-1]

View File

@@ -6,6 +6,17 @@ from imaginairy.samplers.kdiff import KDiffusionSampler
from imaginairy.samplers.plms import PLMSSampler
from imaginairy.utils import get_device

SAMPLER_TYPE_OPTIONS = [
    "plms",
    "ddim",
    "k_lms",
    "k_dpm_2",
    "k_dpm_2_a",
    "k_euler",
    "k_euler_a",
    "k_heun",
]

_k_sampler_type_lookup = {
    "k_dpm_2": "dpm_2",
    "k_dpm_2_a": "dpm_2_ancestral",

View File

@@ -18,6 +18,12 @@ logger = logging.getLogger(__name__)

class DDIMSampler:
    """
    Denoising Diffusion Implicit Models

    https://arxiv.org/abs/2010.02502
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
@@ -314,7 +320,7 @@ class DDIMSampler:
        return x_prev, pred_x0

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
    def stochastic_encode(self, init_latent, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
@@ -325,10 +331,11 @@ class DDIMSampler:
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
            noise = torch.randn_like(init_latent, device="cpu").to(get_device())
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
            extract_into_tensor(sqrt_alphas_cumprod, t, init_latent.shape) * init_latent
            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
            * noise
        )

    @torch.no_grad()
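
stochastic_encode is the standard DDPM forward-diffusion step ("q-sample"), now with its default noise drawn on the CPU. A minimal sketch of the formula it computes, with illustrative names:

import torch

def q_sample(init_latent, sqrt_ac_t, sqrt_one_minus_ac_t, noise=None):
    # x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise
    if noise is None:
        # the determinism fix: CPU noise, then move to the compute device
        noise = torch.randn_like(init_latent, device="cpu").to(init_latent.device)
    return sqrt_ac_t * init_latent + sqrt_one_minus_ac_t * noise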

View File

@@ -15,7 +15,9 @@ from imaginairy.utils import get_device

logger = logging.getLogger(__name__)

class PLMSSampler(object):
class PLMSSampler:
    """pseudo linear multi-step (PNDM) sampler"""

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model

View File

@@ -119,7 +119,7 @@ def sample_euler_ancestral(
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        x = x + torch.randn_like(x) * sigma_up
        x = x + torch.randn_like(x, device="cpu").to(x.device) * sigma_up
    return x
@@ -253,7 +253,7 @@ def sample_dpm_2_ancestral(
        denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
        x = x + torch.randn_like(x) * sigma_up
        x = x + torch.randn_like(x, device="cpu").to(x.device) * sigma_up
    return x
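
Only the ancestral samplers draw fresh noise inside the sampling loop, which is why these are the two call sites patched. The amount of noise added per step comes from get_ancestral_step (shown in the sampling module below); a quick check of its variance-preserving property, with illustrative sigma values:

import torch

def get_ancestral_step(sigma_from, sigma_to):
    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up

sigma_down, sigma_up = get_ancestral_step(torch.tensor(10.0), torch.tensor(7.0))
# the deterministic step targets sigma_down; adding noise with std sigma_up
# restores the scheduled level: sigma_down**2 + sigma_up**2 == sigma_to**2
assert torch.isclose(sigma_down ** 2 + sigma_up ** 2, torch.tensor(49.0))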

View File

@@ -1,221 +0,0 @@
import math

import torch
from scipy import integrate
from torchdiffeq import odeint
from tqdm.auto import tqdm, trange

from . import utils


def append_zero(x):
    return torch.cat([x, x.new_zeros([1])])


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022)."""
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)


def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule."""
    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
    return append_zero(sigmas)


def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule."""
    t = torch.linspace(1, eps_s, n, device=device)
    sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
    return append_zero(sigmas)


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / utils.append_dims(sigma, x.ndim)


def get_ancestral_step(sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up


@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        # Euler method
        x = x + d * dt
    return x


@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampling with Euler method steps."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        x = x + torch.randn_like(x) * sigma_up
    return x


@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x


@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigmas[i + 1] - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
    return x


@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampling with DPM-Solver inspired second-order steps."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        d = to_d(x, sigmas[i], denoised)
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
        dt_1 = sigma_mid - sigmas[i]
        dt_2 = sigma_down - sigmas[i]
        x_2 = x + d * dt_1
        denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
        x = x + torch.randn_like(x) * sigma_up
    return x


def linear_multistep_coeff(order, t, i, j):
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')
    def fn(tau):
        prod = 1.
        for k in range(order):
            if j == k:
                continue
            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return prod
    return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]


@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigmas_cpu = sigmas.detach().cpu().numpy()
    ds = []
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        ds.append(d)
        if len(ds) > order:
            ds.pop(0)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
    return x


@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    v = torch.randint_like(x, 2) * 2 - 1
    fevals = 0
    def ode_fn(sigma, x):
        nonlocal fevals
        with torch.enable_grad():
            x = x[0].detach().requires_grad_()
            denoised = model(x, sigma * s_in, **extra_args)
            d = to_d(x, sigma, denoised)
            fevals += 1
            grad = torch.autograd.grad((d * v).sum(), x)[0]
            d_ll = (v * grad).flatten(1).sum(1)
        return d.detach(), d_ll
    x_min = x, x.new_zeros([x.shape[0]])
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
    latent, delta_ll = sol[0][-1], sol[1][-1]
    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {'fevals': fevals}
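
For context, every sampler in this module shares one calling convention: model(x, sigma * s_in) returns the denoised prediction, and sigmas is a decreasing schedule ending at zero. A hypothetical usage sketch against the vendored copy — the sigma bounds and the denoiser wrapper are assumptions, not part of this commit:

import torch
from imaginairy.vendored import k_diffusion as K

def sample_with_lms(denoiser, latent_shape, steps=40, device="mps"):
    # denoiser(x, sigma) -> denoised latent (a CompVis-style model wrapper)
    sigmas = K.sampling.get_sigmas_karras(steps, sigma_min=0.03, sigma_max=14.6, device=device)
    # CPU noise keeps the starting latent identical across CUDA/MPS/CPU
    x = torch.randn(latent_shape, device="cpu").to(device) * sigmas[0]
    return K.sampling.sample_lms(denoiser, x, sigmas)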

View File

@@ -1,23 +1,42 @@
import pytest

from imaginairy.api import imagine, imagine_image_files
from imaginairy.schema import ImaginePrompt
from imaginairy.utils import get_device

from . import TESTS_FOLDER

mps_sampler_type_test_cases = [
    ("plms", "3f211329796277a1870378288769fcde"),
    ("ddim", "70dbf2acce2c052e4e7f37412ae0366e"),
    ("k_lms", "3585c10c8f27bf091c15e761dca4d578"),
    ("k_dpm_2", "29b07125c9879540f8efac317ae33aea"),
    ("k_dpm_2_a", "4fd6767980444ca72e97cba2d0491eb4"),
    ("k_euler", "50609b279cff756db42ab9d2c85328ed"),
    ("k_euler_a", "ae7ac199c10f303e5ebd675109e59b23"),
    ("k_heun", "3668fe66770538337ac8c0b7ac210892"),
]


def test_imagine():
@pytest.mark.skipif(get_device() != "mps", reason="mps hashes")
@pytest.mark.parametrize("sampler_type,expected_md5", mps_sampler_type_test_cases)
def test_imagine(sampler_type, expected_md5):
    prompt_text = "a scenic landscape"
    prompt = ImaginePrompt(
        "a scenic landscape", width=512, height=256, steps=20, seed=1
        prompt_text, width=512, height=256, steps=10, seed=1, sampler_type=sampler_type
    )
    result = next(imagine(prompt))
    assert result.md5() == "4c5957c498881d365cfcf13014812af0"
    result.img.save(f"{TESTS_FOLDER}/test_output/scenic_landscape.png")
    result.img.save(
        f"{TESTS_FOLDER}/test_output/sampler_type_{sampler_type.upper()}.jpg"
    )
    assert result.md5() == expected_md5


def test_img_to_img():
    prompt = ImaginePrompt(
        "a photo of a beach",
        init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
        init_image_strength=0.5,
        init_image_strength=0.8,
        width=512,
        height=512,
        steps=50,