Mirror of https://github.com/brycedrennan/imaginAIry (synced 2024-10-31 03:20:40 +00:00)
fix: make k-diffusion samplers deterministic

- add a test for image hashes on mps; images look the same on CUDA but are slightly different
parent: b4a3b8c2b3
commit: bb665b9eb6
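The core of the change: torch seeds its CPU and device RNGs separately, so for the same seed, torch.randn_like(x) yields different values on cpu, cuda, and mps. Drawing the noise from the CPU generator and then moving it to the active device makes sampling reproducible across devices. A minimal sketch of the pattern (illustrative code, not part of this commit; the helper name is made up):

import torch


def cpu_seeded_randn_like(x):
    # draw from the CPU generator, then move to x's device; with a fixed
    # seed the values no longer depend on which device x lives on
    return torch.randn_like(x, device="cpu").to(x.device)


torch.manual_seed(1)
noise_a = cpu_seeded_randn_like(torch.empty(2, 2))
torch.manual_seed(1)
noise_b = cpu_seeded_randn_like(torch.empty(2, 2))
assert torch.equal(noise_a, noise_b)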
Makefile (+3 -2)

@@ -79,13 +79,14 @@ vendor_openai_clip:

 revendorize:
     make vendorize REPO=git@github.com:openai/CLIP.git PKG=clip COMMIT=d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
     make vendorize REPO=git@github.com:crowsonkb/k-diffusion.git PKG=k_diffusion COMMIT=1a0703dfb7d24d8806267c3e7ccc4caf67fd1331
-    #sed -i'' -e 's/^import\sclip/from\simaginairy.vendored\simport\sclip/g' imaginairy/vendored/k_diffusion/evaluation.py
+    #sed -i '' -e 's/^import\sclip/from\simaginairy.vendored\simport\sclip/g' imaginairy/vendored/k_diffusion/evaluation.py
     rm imaginairy/vendored/k_diffusion/evaluation.py
     touch imaginairy/vendored/k_diffusion/evaluation.py
     rm imaginairy/vendored/k_diffusion/config.py
     touch imaginairy/vendored/k_diffusion/config.py
     # without this most of the k-diffusion samplers didn't work
-    sed -i'' -e 's#return (x - denoised) / utils.append_dims(sigma, x.ndim)#return (x - denoised) / sigma#g' imaginairy/vendored/k_diffusion/sampling.py
+    sed -i '' -e 's#return (x - denoised) / utils.append_dims(sigma, x.ndim)#return (x - denoised) / sigma#g' imaginairy/vendored/k_diffusion/sampling.py
+    sed -i '' -e 's#x = x + torch.randn_like(x) \* sigma_up#x = x + torch.randn_like(x, device="cpu").to(x.device) \* sigma_up#g' imaginairy/vendored/k_diffusion/sampling.py
     make af

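For reference, after the two sed substitutions the patched lines in imaginairy/vendored/k_diffusion/sampling.py behave roughly like this (a sketch; the helper name patched_noise_step is illustrative, not from the vendored file):

import torch


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    # upstream divides by utils.append_dims(sigma, x.ndim); the patch divides
    # by sigma directly, which is what made most of the samplers work here
    return (x - denoised) / sigma


def patched_noise_step(x, sigma_up):
    # upstream: x + torch.randn_like(x) * sigma_up
    # the patch draws the noise on the CPU generator for determinism
    return x + torch.randn_like(x, device="cpu").to(x.device) * sigma_up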
@@ -142,6 +142,7 @@ imagine_image_files(prompts, outdir="./my-art")
 - prompt expansion
 - Image Generation Features
   - ✅ add k-diffusion sampling methods
+  - why is k-diffusion so slow compared to plms? 2 it/s vs 8 it/s
 - upscaling
   - ✅ realesrgan
   - ldm
@@ -17,7 +17,7 @@ from transformers import cached_path
 from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
 from imaginairy.enhancers.upscale_realesrgan import upscale_image
 from imaginairy.img_log import LatentLoggingContext, log_latent
-from imaginairy.safety import is_nsfw, safety_models
+from imaginairy.safety import is_nsfw
 from imaginairy.samplers.base import get_sampler
 from imaginairy.schema import ImaginePrompt, ImagineResult
 from imaginairy.utils import (
@@ -147,9 +147,6 @@ def imagine(
 ):
     model = load_model(tile_mode=tile_mode)

-    if not IMAGINAIRY_ALLOW_NSFW:
-        # needs to be loaded before we set default tensor type to half
-        safety_models()
     # only run half-mode on cuda. run it by default
     half_mode = half_mode is None and get_device() == "cuda"
     if half_mode:
@@ -199,17 +196,17 @@ def imagine(
             ddim_steps = int(prompt.steps / generation_strength)
             sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta)

-            t_enc = int(generation_strength * ddim_steps)
             init_image, w, h = img_path_to_torch_image(prompt.init_image)
             init_image = init_image.to(get_device())
             init_latent = model.get_first_stage_encoding(
                 model.encode_first_stage(init_image)
             )
-            log_latent(init_latent, "init_latent")

+            log_latent(init_latent, "init_latent")
             # encode (scaled latent)
             z_enc = sampler.stochastic_encode(
-                init_latent, torch.tensor([t_enc]).to(get_device())
+                init_latent,
+                torch.tensor([prompt.steps]).to(get_device()),
             )
             log_latent(z_enc, "z_enc")

@@ -217,7 +214,7 @@ def imagine(
             samples = sampler.decode(
                 z_enc,
                 c,
-                t_enc,
+                prompt.steps,
                 unconditional_guidance_scale=prompt.prompt_strength,
                 unconditional_conditioning=uc,
                 img_callback=_img_callback,
@@ -3,6 +3,7 @@ import logging.config
 import click

 from imaginairy.api import load_model
+from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS

 logger = logging.getLogger(__name__)

@@ -101,18 +102,7 @@ def configure_logging(level="INFO"):
 @click.option(
     "--sampler-type",
     default="plms",
-    type=click.Choice(
-        [
-            "plms",
-            "ddim",
-            "k_lms",
-            "k_dpm_2",
-            "k_dpm_2_a",
-            "k_euler",
-            "k_euler_a",
-            "k_heun",
-        ]
-    ),
+    type=click.Choice(SAMPLER_TYPE_OPTIONS),
     help="What sampling strategy to use",
 )
 @click.option("--ddim-eta", default=0.0, type=float)
@@ -36,7 +36,12 @@ def uniform_on_device(r1, r2, shape, device):


 class DDPM(pl.LightningModule):
-    # classic DDPM with Gaussian diffusion, in image space
+    """
+    classic DDPM with Gaussian diffusion, in image space
+
+    Denoising diffusion probabilistic models
+    """
+
     def __init__(
         self,
         unet_config,
@@ -43,7 +43,7 @@ def find_noise_for_image(model, pil_img, prompt, steps=50, cond_scale=1.0, half=
 def find_noise_for_latent(
     model, img_latent, prompt, steps=50, cond_scale=1.0, half=True
 ):
-    import k_diffusion as K
+    from imaginairy.vendored import k_diffusion as K

     x = img_latent

@@ -89,6 +89,7 @@ def find_noise_for_latent(
     del eps, denoised_uncond, denoised_cond, denoised, d, dt
     # collect_and_empty()

+    # return (x / x.std())
     return (x / x.std()) * sigmas[-1]


@@ -6,6 +6,17 @@ from imaginairy.samplers.kdiff import KDiffusionSampler
 from imaginairy.samplers.plms import PLMSSampler
 from imaginairy.utils import get_device

+SAMPLER_TYPE_OPTIONS = [
+    "plms",
+    "ddim",
+    "k_lms",
+    "k_dpm_2",
+    "k_dpm_2_a",
+    "k_euler",
+    "k_euler_a",
+    "k_heun",
+]
+
 _k_sampler_type_lookup = {
     "k_dpm_2": "dpm_2",
     "k_dpm_2_a": "dpm_2_ancestral",
@@ -18,6 +18,12 @@ logger = logging.getLogger(__name__)


 class DDIMSampler:
+    """
+    Denoising Diffusion Implicit Models
+
+    https://arxiv.org/abs/2010.02502
+    """
+
     def __init__(self, model, schedule="linear", **kwargs):
         super().__init__()
         self.model = model
@@ -314,7 +320,7 @@ class DDIMSampler:
         return x_prev, pred_x0

     @torch.no_grad()
-    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+    def stochastic_encode(self, init_latent, t, use_original_steps=False, noise=None):
         # fast, but does not allow for exact reconstruction
         # t serves as an index to gather the correct alphas
         if use_original_steps:
@@ -325,10 +331,11 @@ class DDIMSampler:
             sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

         if noise is None:
-            noise = torch.randn_like(x0)
+            noise = torch.randn_like(init_latent, device="cpu").to(get_device())
         return (
-            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
-            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
+            extract_into_tensor(sqrt_alphas_cumprod, t, init_latent.shape) * init_latent
+            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
+            * noise
         )

     @torch.no_grad()
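stochastic_encode draws the noised latent in closed form from q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise, where extract_into_tensor gathers the schedule value for each batch element's timestep. A standalone sketch of the same computation (q_sample is an illustrative name, not imaginAIry API):

import torch


def q_sample(init_latent, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    # noise drawn on the CPU generator, as in the commit, for determinism
    noise = torch.randn_like(init_latent, device="cpu").to(init_latent.device)
    # gather the per-timestep schedule values and broadcast over the latent
    shape = [init_latent.shape[0]] + [1] * (init_latent.ndim - 1)
    a = sqrt_alphas_cumprod.gather(-1, t).reshape(shape)
    b = sqrt_one_minus_alphas_cumprod.gather(-1, t).reshape(shape)
    return a * init_latent + b * noise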
@@ -15,7 +15,9 @@ from imaginairy.utils import get_device
 logger = logging.getLogger(__name__)


-class PLMSSampler(object):
+class PLMSSampler:
+    """probabilistic least-mean-squares"""
+
     def __init__(self, model, schedule="linear", **kwargs):
         super().__init__()
         self.model = model
@@ -119,7 +119,7 @@ def sample_euler_ancestral(
         # Euler method
         dt = sigma_down - sigmas[i]
         x = x + d * dt
-        x = x + torch.randn_like(x) * sigma_up
+        x = x + torch.randn_like(x, device="cpu").to(x.device) * sigma_up
     return x


@@ -253,7 +253,7 @@ def sample_dpm_2_ancestral(
         denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
         d_2 = to_d(x_2, sigma_mid, denoised_2)
         x = x + d_2 * dt_2
-        x = x + torch.randn_like(x) * sigma_up
+        x = x + torch.randn_like(x, device="cpu").to(x.device) * sigma_up
     return x


@@ -1,221 +0,0 @@
-import math
-
-import torch
-from scipy import integrate
-from torchdiffeq import odeint
-from tqdm.auto import tqdm, trange
-
-from . import utils
-
-
-def append_zero(x):
-    return torch.cat([x, x.new_zeros([1])])
-
-
-def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
-    """Constructs the noise schedule of Karras et al. (2022)."""
-    ramp = torch.linspace(0, 1, n)
-    min_inv_rho = sigma_min ** (1 / rho)
-    max_inv_rho = sigma_max ** (1 / rho)
-    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
-    return append_zero(sigmas).to(device)
-
-
-def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
-    """Constructs an exponential noise schedule."""
-    sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device).exp()
-    return append_zero(sigmas)
-
-
-def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
-    """Constructs a continuous VP noise schedule."""
-    t = torch.linspace(1, eps_s, n, device=device)
-    sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
-    return append_zero(sigmas)
-
-
-def to_d(x, sigma, denoised):
-    """Converts a denoiser output to a Karras ODE derivative."""
-    return (x - denoised) / utils.append_dims(sigma, x.ndim)
-
-
-def get_ancestral_step(sigma_from, sigma_to):
-    """Calculates the noise level (sigma_down) to step down to and the amount
-    of noise to add (sigma_up) when doing an ancestral sampling step."""
-    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
-    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
-    return sigma_down, sigma_up
-
-
-@torch.no_grad()
-def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
-    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    for i in trange(len(sigmas) - 1, disable=disable):
-        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
-        eps = torch.randn_like(x) * s_noise
-        sigma_hat = sigmas[i] * (gamma + 1)
-        if gamma > 0:
-            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
-        denoised = model(x, sigma_hat * s_in, **extra_args)
-        d = to_d(x, sigma_hat, denoised)
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
-        dt = sigmas[i + 1] - sigma_hat
-        # Euler method
-        x = x + d * dt
-    return x
-
-
-@torch.no_grad()
-def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
-    """Ancestral sampling with Euler method steps."""
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    for i in trange(len(sigmas) - 1, disable=disable):
-        denoised = model(x, sigmas[i] * s_in, **extra_args)
-        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-        d = to_d(x, sigmas[i], denoised)
-        # Euler method
-        dt = sigma_down - sigmas[i]
-        x = x + d * dt
-        x = x + torch.randn_like(x) * sigma_up
-    return x
-
-
-@torch.no_grad()
-def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
-    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    for i in trange(len(sigmas) - 1, disable=disable):
-        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
-        eps = torch.randn_like(x) * s_noise
-        sigma_hat = sigmas[i] * (gamma + 1)
-        if gamma > 0:
-            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
-        denoised = model(x, sigma_hat * s_in, **extra_args)
-        d = to_d(x, sigma_hat, denoised)
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
-        dt = sigmas[i + 1] - sigma_hat
-        if sigmas[i + 1] == 0:
-            # Euler method
-            x = x + d * dt
-        else:
-            # Heun's method
-            x_2 = x + d * dt
-            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
-            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
-            d_prime = (d + d_2) / 2
-            x = x + d_prime * dt
-    return x
-
-
-@torch.no_grad()
-def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
-    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    for i in trange(len(sigmas) - 1, disable=disable):
-        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
-        eps = torch.randn_like(x) * s_noise
-        sigma_hat = sigmas[i] * (gamma + 1)
-        if gamma > 0:
-            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
-        denoised = model(x, sigma_hat * s_in, **extra_args)
-        d = to_d(x, sigma_hat, denoised)
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
-        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
-        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
-        dt_1 = sigma_mid - sigma_hat
-        dt_2 = sigmas[i + 1] - sigma_hat
-        x_2 = x + d * dt_1
-        denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
-        d_2 = to_d(x_2, sigma_mid, denoised_2)
-        x = x + d_2 * dt_2
-    return x
-
-
-@torch.no_grad()
-def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
-    """Ancestral sampling with DPM-Solver inspired second-order steps."""
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    for i in trange(len(sigmas) - 1, disable=disable):
-        denoised = model(x, sigmas[i] * s_in, **extra_args)
-        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-        d = to_d(x, sigmas[i], denoised)
-        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
-        sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
-        dt_1 = sigma_mid - sigmas[i]
-        dt_2 = sigma_down - sigmas[i]
-        x_2 = x + d * dt_1
-        denoised_2 = model(x_2, sigma_mid * s_in, **extra_args)
-        d_2 = to_d(x_2, sigma_mid, denoised_2)
-        x = x + d_2 * dt_2
-        x = x + torch.randn_like(x) * sigma_up
-    return x
-
-
-def linear_multistep_coeff(order, t, i, j):
-    if order - 1 > i:
-        raise ValueError(f'Order {order} too high for step {i}')
-    def fn(tau):
-        prod = 1.
-        for k in range(order):
-            if j == k:
-                continue
-            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
-        return prod
-    return integrate.quad(fn, t[i], t[i + 1], epsrel=1e-4)[0]
-
-
-@torch.no_grad()
-def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    sigmas_cpu = sigmas.detach().cpu().numpy()
-    ds = []
-    for i in trange(len(sigmas) - 1, disable=disable):
-        denoised = model(x, sigmas[i] * s_in, **extra_args)
-        d = to_d(x, sigmas[i], denoised)
-        ds.append(d)
-        if len(ds) > order:
-            ds.pop(0)
-        if callback is not None:
-            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
-        cur_order = min(i + 1, order)
-        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
-        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
-    return x
-
-
-@torch.no_grad()
-def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
-    extra_args = {} if extra_args is None else extra_args
-    s_in = x.new_ones([x.shape[0]])
-    v = torch.randint_like(x, 2) * 2 - 1
-    fevals = 0
-    def ode_fn(sigma, x):
-        nonlocal fevals
-        with torch.enable_grad():
-            x = x[0].detach().requires_grad_()
-            denoised = model(x, sigma * s_in, **extra_args)
-            d = to_d(x, sigma, denoised)
-            fevals += 1
-            grad = torch.autograd.grad((d * v).sum(), x)[0]
-            d_ll = (v * grad).flatten(1).sum(1)
-        return d.detach(), d_ll
-    x_min = x, x.new_zeros([x.shape[0]])
-    t = x.new_tensor([sigma_min, sigma_max])
-    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
-    latent, delta_ll = sol[0][-1], sol[1][-1]
-    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
-    return ll_prior + delta_ll, {'fevals': fevals}
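The sigma_up noise in the ancestral samplers above is the source of the nondeterminism this commit fixes: every step injects fresh Gaussian noise. get_ancestral_step splits each transition so that the deterministic step down to sigma_down plus noise of scale sigma_up lands exactly at the target level, i.e. sigma_down ** 2 + sigma_up ** 2 == sigma_to ** 2. A quick numeric check of that identity:

def get_ancestral_step(sigma_from, sigma_to):
    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up


down, up = get_ancestral_step(10.0, 7.0)
# the two variances add back up to the target noise level sigma_to
assert abs(down ** 2 + up ** 2 - 7.0 ** 2) < 1e-9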
@@ -1,23 +1,42 @@
+import pytest
+
 from imaginairy.api import imagine, imagine_image_files
 from imaginairy.schema import ImaginePrompt
+from imaginairy.utils import get_device

 from . import TESTS_FOLDER

+mps_sampler_type_test_cases = {
+    ("plms", "3f211329796277a1870378288769fcde"),
+    ("ddim", "70dbf2acce2c052e4e7f37412ae0366e"),
+    ("k_lms", "3585c10c8f27bf091c15e761dca4d578"),
+    ("k_dpm_2", "29b07125c9879540f8efac317ae33aea"),
+    ("k_dpm_2_a", "4fd6767980444ca72e97cba2d0491eb4"),
+    ("k_euler", "50609b279cff756db42ab9d2c85328ed"),
+    ("k_euler_a", "ae7ac199c10f303e5ebd675109e59b23"),
+    ("k_heun", "3668fe66770538337ac8c0b7ac210892"),
+}
+

-def test_imagine():
+@pytest.mark.skipif(get_device() != "mps", reason="mps hashes")
+@pytest.mark.parametrize("sampler_type,expected_md5", mps_sampler_type_test_cases)
+def test_imagine(sampler_type, expected_md5):
+    prompt_text = "a scenic landscape"
     prompt = ImaginePrompt(
-        "a scenic landscape", width=512, height=256, steps=20, seed=1
+        prompt_text, width=512, height=256, steps=10, seed=1, sampler_type=sampler_type
     )
     result = next(imagine(prompt))
-    assert result.md5() == "4c5957c498881d365cfcf13014812af0"
-    result.img.save(f"{TESTS_FOLDER}/test_output/scenic_landscape.png")
+    result.img.save(
+        f"{TESTS_FOLDER}/test_output/sampler_type_{sampler_type.upper()}.jpg"
+    )
+    assert result.md5() == expected_md5


 def test_img_to_img():
     prompt = ImaginePrompt(
         "a photo of a beach",
         init_image=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg",
-        init_image_strength=0.5,
+        init_image_strength=0.8,
         width=512,
         height=512,
         steps=50,
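The hashes above are pinned pixel for pixel: result.md5() only matches when every noise draw is identical from run to run, which is what the CPU-generator change guarantees. A sketch of what such an image hash can look like (an assumption about the helper, not necessarily the actual ImagineResult implementation):

import hashlib

from PIL import Image


def image_md5(img: Image.Image) -> str:
    # hash the raw pixel bytes; any single-pixel difference changes the digest
    return hashlib.md5(img.tobytes()).hexdigest()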