fix: add workaround for bug in k_diffusion on mps

As documented here: https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/4558#issuecomment-1310387114

- make image logs more consistent
- note which step the progress images belong to in the filename
This commit is contained in:
Bryce 2022-11-12 19:24:03 -08:00 committed by Bryce Drennan
parent 7fba2972e8
commit 7af1ab66ca
12 changed files with 59 additions and 18 deletions

View File

@ -126,6 +126,8 @@ vendorize_kdiffusion:
# without this most of the k-diffusion samplers didn't work
sed -i '' -e 's#return (x - denoised) / utils.append_dims(sigma, x.ndim)#return (x - denoised) / sigma#g' imaginairy/vendored/k_diffusion/sampling.py
sed -i '' -e 's#x = x + torch.randn_like(x) \* sigma_up#x = x + torch.randn_like(x, device="cpu").to(x.device) \* sigma_up#g' imaginairy/vendored/k_diffusion/sampling.py
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/4558#issuecomment-1310387114
# NOTE: the replacement uses "cpu" in double quotes on purpose. The sed script is wrapped in
# single quotes, so writing 'cpu' would end the shell quoting and sed would emit `sigma.to(cpu)`.
sed -i '' -e 's#t_fn = lambda sigma: sigma.log().neg()#t_fn = lambda sigma: sigma.to("cpu").log().neg().to(x.device)#g' imaginairy/vendored/k_diffusion/sampling.py
make af
vendorize_noodle_soup:

View File

@ -226,6 +226,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
- feature: added `DPM++ 2S a` and `DPM++ 2M` samplers.
- fix: fix bug with `--show-work`
- fix: add workaround for a PyTorch bug affecting macOS users using the new `DPM++ 2S a` and `DPM++ 2M` samplers.
- feature: improve progress image logging
**5.0.0**
- feature: 🎉 inpainting support using new inpainting model from RunwayML. It works really well! (Unfortunately it requires a HuggingFace token).

View File

@ -61,10 +61,10 @@ def imagine_image_files(
if output_file_extension not in {"jpg", "png"}:
raise ValueError("Must output a png or jpg")
def _record_step(img, description, step_count, prompt):
def _record_step(img, description, image_count, step_count, prompt):
steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}")
os.makedirs(steps_path, exist_ok=True)
filename = f"{base_count:08}_S{prompt.seed}_step{step_count:04}_{prompt_normalized(description)[:40]}.jpg"
filename = f"{base_count:08}_S{prompt.seed}_{image_count:04}_step{step_count:03}_{prompt_normalized(description)[:40]}.jpg"
destination = os.path.join(steps_path, filename)
draw = ImageDraw.Draw(img)

View File

@ -44,6 +44,12 @@ def log_tensor(t, description=""):
_CURRENT_LOGGING_CONTEXT.log_img(t, description)
def increment_step():
    """Advance the step counter of the active logging context by one.

    Does nothing when no logging context is currently set.
    """
    context = _CURRENT_LOGGING_CONTEXT
    if context is not None:
        context.step_count += 1
class TimingContext:
def __init__(self, logging_context, description):
self.logging_context = logging_context
@ -62,6 +68,7 @@ class ImageLoggingContext:
self.prompt = prompt
self.model = model
self.step_count = 0
self.image_count = 0
self.img_callback = img_callback
self.img_outdir = img_outdir
self.start_ts = time.perf_counter()
@ -88,7 +95,9 @@ class ImageLoggingContext:
return
img = conditioning_to_img(conditioning)
self.img_callback(img, description, self.step_count, self.prompt)
self.img_callback(
img, description, self.image_count, self.step_count, self.prompt
)
def log_latents(self, latents, description):
from imaginairy.img_utils import model_latents_to_pillow_imgs # noqa
@ -98,23 +107,28 @@ class ImageLoggingContext:
if latents.shape[1] != 4:
# logger.info(f"Didn't save tensor of shape {samples.shape} for {description}")
return
self.step_count += 1
try:
shape_str = ",".join(tuple(latents.shape))
except TypeError:
shape_str = str(latents.shape)
description = f"{description}-{shape_str}"
for img in model_latents_to_pillow_imgs(latents):
self.img_callback(img, description, self.step_count, self.prompt)
self.image_count += 1
self.img_callback(
img, description, self.image_count, self.step_count, self.prompt
)
def log_img(self, img, description):
if not self.img_callback:
return
self.step_count += 1
self.image_count += 1
if isinstance(img, torch.Tensor):
img = ToPILImage()(img.squeeze().cpu().detach())
img = img.copy()
self.img_callback(img, description, self.step_count, self.prompt)
self.img_callback(
img, description, self.image_count, self.step_count, self.prompt
)
def log_tensor(self, t, description=""):
if not self.img_callback:

View File

@ -16,6 +16,7 @@ from torch import nn
from torchvision.utils import make_grid
from tqdm import tqdm
from imaginairy.log_utils import log_latent
from imaginairy.modules.diffusion.util import (
extract_into_tensor,
make_beta_schedule,
@ -761,9 +762,9 @@ class LatentDiffusion(DDPM):
else:
x_recon = self.model(x_noisy, t, **cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
log_latent(x_recon, "predicted noise")
return x_recon

View File

@ -145,6 +145,8 @@ def get_noise_prediction(
else:
conditioning_in = torch.cat([neutral_conditioning, positive_conditioning])
# the k-diffusion samplers actually return the denoised predicted latents but things seem
# to work anyway
noise_pred_neutral, noise_pred_positive = denoise_func(
noisy_latent_in, time_encoding_in, conditioning_in
).chunk(2)
@ -154,10 +156,6 @@ def get_noise_prediction(
)
noise_pred = noise_pred_neutral + amplified_noise_pred
log_latent(noise_pred_neutral, "noise_pred_neutral")
log_latent(noise_pred_positive, "noise_pred_positive")
log_latent(noise_pred, "noise_pred")
return noise_pred

View File

@ -5,7 +5,7 @@ import numpy as np
import torch
from tqdm import tqdm
from imaginairy.log_utils import log_latent
from imaginairy.log_utils import increment_step, log_latent
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
from imaginairy.utils import get_device
@ -94,6 +94,7 @@ class DDIMSampler:
log_latent(noisy_latent, "noisy_latent")
log_latent(predicted_latent, "predicted_latent")
increment_step()
return noisy_latent

View File

@ -1,7 +1,7 @@
# pylama:ignore=W0613
import torch
from imaginairy.log_utils import log_latent
from imaginairy.log_utils import increment_step, log_latent
from imaginairy.samplers.base import CFGDenoiser
from imaginairy.utils import get_device
from imaginairy.vendored.k_diffusion import sampling as k_sampling
@ -111,7 +111,8 @@ class KDiffusionSampler:
def callback(data):
log_latent(data["x"], "noisy_latent")
log_latent(data["denoised"], "noise_pred c")
log_latent(data["denoised"], "predicted_latent")
increment_step()
samples = self.sampler_func(
model=model_wrap_cfg,

View File

@ -5,7 +5,7 @@ import numpy as np
import torch
from tqdm import tqdm
from imaginairy.log_utils import log_latent
from imaginairy.log_utils import increment_step, log_latent
from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like
from imaginairy.samplers.base import NoiseSchedule, get_noise_prediction, mask_blend
from imaginairy.utils import get_device
@ -115,6 +115,7 @@ class PLMSSampler:
log_latent(noisy_latent, "noisy_latent")
log_latent(predicted_latent, "predicted_latent")
increment_step()
return noisy_latent

View File

@ -687,7 +687,7 @@ def sample_dpmpp_2s_ancestral(
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
t_fn = lambda sigma: sigma.to("cpu").log().neg().to(x.device)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -727,7 +727,8 @@ def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=No
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
t_fn = lambda sigma: sigma.to("cpu").log().neg().to(x.device)
old_denoised = None
for i in trange(len(sigmas) - 1, disable=disable):

View File

View File

@ -0,0 +1,20 @@
import pytest
import torch
from imaginairy.utils import get_device
@pytest.mark.skipif("mps" not in get_device(), reason="MPS only bug")
@pytest.mark.xfail(reason="MPS only bug")
def test_sigma_bug():
    """Compare ``-log(sigma)`` computed on the device with a CPU round trip.

    Each element of a small tensor is evaluated two ways: directly on the
    device returned by ``get_device()``, and on the CPU with the result moved
    back to that device. The test asserts both results are equal; it is
    marked ``xfail`` and only runs when the device name contains "mps".
    """
    # https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/4558#issuecomment-1310387114

    def t_fn_a(sigma):
        # Straightforward path: take the log on the target device.
        on_device = sigma.to(get_device())
        return on_device.log().neg()

    def t_fn_b(sigma):
        # Workaround path: take the log on the CPU, then move the result back.
        on_cpu = sigma.to("cpu")
        negative_log = on_cpu.log().neg()
        return negative_log.to(get_device())

    sigmas = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], device=get_device())
    # Index element by element rather than iterating the tensor, as the
    # original reproduction does.
    for idx in range(sigmas.shape[0]):
        direct = t_fn_a(sigmas[idx])
        via_cpu = t_fn_b(sigmas[idx])
        assert direct == via_cpu