feature: added `DPM++ 2S a` and `DPM++ 2M` samplers

-fix: fix bug with `--show-work`
pull/86/head
Bryce 2 years ago committed by Bryce Drennan
parent 17399e7702
commit 7fba2972e8

@ -116,7 +116,7 @@ vendorize_kdiffusion:
rm -rf ./imaginairy/vendored/k_diffusion
rm -rf ./downloads/k_diffusion
# version 0.0.9
make vendorize REPO=git@github.com:crowsonkb/k-diffusion.git PKG=k_diffusion COMMIT=f4e99857772fc3a126ba886aadf795a332774878
make vendorize REPO=git@github.com:crowsonkb/k-diffusion.git PKG=k_diffusion COMMIT=60e5042ca0da89c14d1dd59d73883280f8fce991
#sed -i '' -e 's/import\sclip/from\simaginairy.vendored\simport\sclip/g' imaginairy/vendored/k_diffusion/evaluation.py
mv ./downloads/k_diffusion/LICENSE ./imaginairy/vendored/k_diffusion/
rm imaginairy/vendored/k_diffusion/evaluation.py

@ -224,6 +224,9 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
- feature: added `DPM++ 2S a` and `DPM++ 2M` samplers.
- fix: fix bug with `--show-work`
**5.0.0**
- feature: 🎉 inpainting support using new inpainting model from RunwayML. It works really well! (Unfortunately it requires a HuggingFace token).
By default, inpainting model will automatically be used for any image-masking task
@ -447,6 +450,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- https://stablecog.com/
## Further Reading
- [Prompt Engineering Handbook](https://openart.ai/promptbook)
- Differences between samplers
- https://www.reddit.com/r/StableDiffusion/comments/xbeyw3/can_anyone_offer_a_little_guidance_on_the/
- https://www.reddit.com/r/bigsleep/comments/xb5cat/wiskkeys_lists_of_texttoimage_systems_and_related/

@ -136,26 +136,26 @@ def imagine(
prompt=prompt,
model=model,
img_callback=img_callback,
):
) as lc:
seed_everything(prompt.seed)
model.tile_mode(prompt.tile_mode)
neutral_conditioning = None
if prompt.prompt_strength != 1.0:
neutral_conditioning = model.get_learned_conditioning(
batch_size * [""]
)
log_conditioning(neutral_conditioning, "neutral conditioning")
if prompt.conditioning is not None:
positive_conditioning = prompt.conditioning
else:
total_weight = sum(wp.weight for wp in prompt.prompts)
positive_conditioning = sum(
model.get_learned_conditioning(wp.text)
* (wp.weight / total_weight)
for wp in prompt.prompts
)
log_conditioning(positive_conditioning, "positive conditioning")
with lc.timing("conditioning"):
neutral_conditioning = None
if prompt.prompt_strength != 1.0:
neutral_conditioning = model.get_learned_conditioning(
batch_size * [""]
)
log_conditioning(neutral_conditioning, "neutral conditioning")
if prompt.conditioning is not None:
positive_conditioning = prompt.conditioning
else:
total_weight = sum(wp.weight for wp in prompt.prompts)
positive_conditioning = sum(
model.get_learned_conditioning(wp.text)
* (wp.weight / total_weight)
for wp in prompt.prompts
)
log_conditioning(positive_conditioning, "positive conditioning")
shape = [
batch_size,
@ -298,18 +298,19 @@ def imagine(
"c_concat": c_cat,
"c_crossattn": [neutral_conditioning],
}
samples = sampler.sample(
num_steps=prompt.steps,
initial_latent=init_latent_noised,
positive_conditioning=positive_conditioning,
neutral_conditioning=neutral_conditioning,
guidance_scale=prompt.prompt_strength,
t_start=t_enc,
mask=mask,
orig_latent=init_latent,
shape=shape,
batch_size=1,
)
with lc.timing("sampling"):
samples = sampler.sample(
num_steps=prompt.steps,
initial_latent=init_latent_noised,
positive_conditioning=positive_conditioning,
neutral_conditioning=neutral_conditioning,
guidance_scale=prompt.prompt_strength,
t_start=t_enc,
mask=mask,
orig_latent=init_latent,
shape=shape,
batch_size=1,
)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
@ -337,10 +338,11 @@ def imagine(
caption = generate_caption(img)
logger.info(f"Generated caption: {caption}")
safety_score = create_safety_score(
img,
safety_mode=IMAGINAIRY_SAFETY_MODE,
)
with lc.timing("safety-filter"):
safety_score = create_safety_score(
img,
safety_mode=IMAGINAIRY_SAFETY_MODE,
)
if not safety_score.is_filtered:
if prompt.fix_faces:
logger.info("Fixing 😊 's in 🖼 using CodeFormer...")
@ -381,7 +383,7 @@ def imagine(
)
log_img(rebuilt_orig_img, "reconstituted original")
yield ImagineResult(
result = ImagineResult(
img=img,
prompt=prompt,
upscaled_img=upscaled_img,
@ -390,7 +392,10 @@ def imagine(
modified_original=rebuilt_orig_img,
mask_binary=mask_image_orig,
mask_grayscale=mask_grayscale,
timings=lc.get_timings(),
)
logger.info(f"Image Generated. Timings: {result.timings_str()}")
yield result
def prompt_normalized(prompt):

@ -1,6 +1,7 @@
import logging
import logging.config
import re
import time
import warnings
import torch
@ -25,6 +26,9 @@ def log_latent(latents, description):
if _CURRENT_LOGGING_CONTEXT is None:
return
if latents is None:
return
_CURRENT_LOGGING_CONTEXT.log_latents(latents, description)
@ -40,6 +44,19 @@ def log_tensor(t, description=""):
_CURRENT_LOGGING_CONTEXT.log_img(t, description)
class TimingContext:
    """Context manager that records how long its ``with`` body took.

    On exit, the elapsed time in seconds is written to
    ``logging_context.timings[description]``. Re-using a description
    overwrites the previously recorded value.

    Args:
        logging_context: object exposing a mutable ``timings`` mapping.
        description: key under which the elapsed seconds are stored.
    """

    def __init__(self, logging_context, description):
        self.logging_context = logging_context
        self.description = description
        # Set on __enter__; stays None until the context is entered.
        self.start_time = None

    def __enter__(self):
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments the way time.time() can.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # The timing is recorded even when the body raised; returning None
        # lets any in-flight exception propagate unchanged.
        elapsed = time.perf_counter() - self.start_time
        self.logging_context.timings[self.description] = elapsed
class ImageLoggingContext:
def __init__(self, prompt, model, img_callback=None, img_outdir=None):
self.prompt = prompt
@ -47,6 +64,8 @@ class ImageLoggingContext:
self.step_count = 0
self.img_callback = img_callback
self.img_outdir = img_outdir
self.start_ts = time.perf_counter()
self.timings = {}
def __enter__(self):
global _CURRENT_LOGGING_CONTEXT # noqa
@ -57,6 +76,13 @@ class ImageLoggingContext:
global _CURRENT_LOGGING_CONTEXT # noqa
_CURRENT_LOGGING_CONTEXT = None
def timing(self, description):
    """Return a ``TimingContext`` bound to this context and ``description``."""
    timer = TimingContext(self, description)
    return timer
def get_timings(self):
    """Refresh the ``"total"`` entry and return the timings mapping.

    ``"total"`` is the number of seconds elapsed since ``self.start_ts``.
    The returned object is ``self.timings`` itself, not a copy.
    """
    elapsed_total = time.perf_counter() - self.start_ts
    recorded = self.timings
    recorded["total"] = elapsed_total
    return recorded
def log_conditioning(self, conditioning, description):
if not self.img_callback:
return

@ -23,6 +23,8 @@ SAMPLER_TYPE_OPTIONS = [
"k_lms",
"k_dpm_2",
"k_dpm_2_a",
"k_dpmpp_2m",
"k_dpmpp_2s_a",
"k_euler",
"k_euler_a",
"k_heun",
@ -33,6 +35,8 @@ _k_sampler_type_lookup = {
"k_dpm_adaptive": "dpm_adaptive",
"k_dpm_2": "dpm_2",
"k_dpm_2_a": "dpm_2_ancestral",
"k_dpmpp_2m": "dpmpp_2m",
"k_dpmpp_2s_a": "dpmpp_2s_ancestral",
"k_euler": "euler",
"k_euler_a": "euler_ancestral",
"k_heun": "heun",

@ -51,6 +51,8 @@ class KDiffusionSampler:
"dpm_adaptive": sample_dpm_adaptive,
"dpm_2": k_sampling.sample_dpm_2,
"dpm_2_ancestral": k_sampling.sample_dpm_2_ancestral,
"dpmpp_2m": k_sampling.sample_dpmpp_2m,
"dpmpp_2s_ancestral": k_sampling.sample_dpmpp_2s_ancestral,
"euler": k_sampling.sample_euler,
"euler_ancestral": k_sampling.sample_euler_ancestral,
"heun": k_sampling.sample_heun,

@ -201,6 +201,7 @@ class ImagineResult:
modified_original=None,
mask_binary=None,
mask_grayscale=None,
timings=None,
):
self.prompt = prompt
@ -218,6 +219,8 @@ class ImagineResult:
if mask_grayscale:
self.images["mask_grayscale"] = mask_grayscale
self.timings = timings
# for backward compat
self.img = img
self.upscaled_img = upscaled_img
@ -236,6 +239,11 @@ class ImagineResult:
"prompt": self.prompt.as_dict(),
}
def timings_str(self):
    """Render ``self.timings`` as space-separated ``name:seconds`` pairs.

    Each value is formatted with two decimals and an ``s`` suffix.
    Returns an empty string when ``self.timings`` is empty or ``None``.
    """
    recorded = self.timings
    if recorded:
        parts = [f"{label}:{seconds:.2f}s" for label, seconds in recorded.items()]
        return " ".join(parts)
    return ""
def _exif(self):
exif = Image.Exif()
exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()

@ -0,0 +1,19 @@
Copyright (c) 2022 Katherine Crowson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

@ -677,3 +677,79 @@ def sample_dpm_adaptive(
if return_info:
return x, info
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral(
    model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0
):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps.

    Args:
        model: callable invoked as ``model(x, sigma * s_in, **extra_args)``
            and expected to return the denoised prediction.
        x: starting tensor; a new tensor is returned, ``x`` is rebound each step.
        sigmas: indexable noise schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: optional keyword arguments forwarded to ``model``.
        callback: optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` once per step.
        disable: forwarded to ``trange`` as its ``disable`` argument.
        eta: forwarded to ``get_ancestral_step``.
        s_noise: multiplier applied to the noise added after each step.

    Returns:
        The tensor obtained after the final step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])

    # The solver works in t = -log(sigma); these convert between the two.
    def sigma_fn(t):
        return t.neg().exp()

    def t_fn(sigma):
        return sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[i]
        denoised = model(x, sigma_cur * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigmas[i + 1], eta=eta)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigma_cur,
                    "sigma_hat": sigma_cur,
                    "denoised": denoised,
                }
            )
        if sigma_down != 0:
            # DPM-Solver-2++(2S): evaluate the model again at the midpoint
            # (r = 1/2) of the step in t-space, then take the full step.
            t_cur = t_fn(sigma_cur)
            t_next = t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t_cur
            t_mid = t_cur + r * h
            sigma_mid = sigma_fn(t_mid)
            x_mid = (sigma_mid / sigma_fn(t_cur)) * x - (-h * r).expm1() * denoised
            denoised_mid = model(x_mid, sigma_mid * s_in, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t_cur)) * x - (-h).expm1() * denoised_mid
        else:
            # Euler method: log(sigma_down) is undefined at zero.
            x = x + to_d(x, sigma_cur, denoised) * (sigma_down - sigma_cur)
        # Noise addition
        x = x + torch.randn_like(x) * s_noise * sigma_up
    return x
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M).

    Args:
        model: callable invoked as ``model(x, sigma * s_in, **extra_args)``
            and expected to return the denoised prediction.
        x: starting tensor; a new tensor is returned, ``x`` is rebound each step.
        sigmas: indexable noise schedule; ``len(sigmas) - 1`` steps are taken.
        extra_args: optional keyword arguments forwarded to ``model``.
        callback: optional callable receiving a dict with keys ``x``, ``i``,
            ``sigma``, ``sigma_hat`` and ``denoised`` once per step.
        disable: forwarded to ``trange`` as its ``disable`` argument.

    Returns:
        The tensor obtained after the final step.
    """
    if extra_args is None:
        extra_args = {}
    s_in = x.new_ones([x.shape[0]])

    # The solver works in t = -log(sigma); these convert between the two.
    def sigma_fn(t):
        return t.neg().exp()

    def t_fn(sigma):
        return sigma.log().neg()

    prev_denoised = None
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[i]
        denoised = model(x, sigma_cur * s_in, **extra_args)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigma_cur,
                    "sigma_hat": sigma_cur,
                    "denoised": denoised,
                }
            )
        t_cur = t_fn(sigma_cur)
        t_next = t_fn(sigmas[i + 1])
        h = t_next - t_cur
        if prev_denoised is not None and sigmas[i + 1] != 0:
            # Multistep update: combine the current and previous denoised
            # predictions, weighted by the ratio of the last two step sizes.
            h_last = t_cur - t_fn(sigmas[i - 1])
            r = h_last / h
            weight = 1 / (2 * r)
            estimate = (1 + weight) * denoised - weight * prev_denoised
        else:
            # First step (no history yet) or final step to sigma == 0:
            # use the current prediction alone.
            estimate = denoised
        x = (sigma_fn(t_next) / sigma_fn(t_cur)) * x - (-h).expm1() * estimate
        prev_denoised = denoised
    return x

@ -1 +1 @@
f4e99857772fc3a126ba886aadf795a332774878
60e5042ca0da89c14d1dd59d73883280f8fce991

Binary file not shown.

After

Width:  |  Height:  |  Size: 568 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 577 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 238 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 255 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 338 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 377 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 265 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 264 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 255 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 240 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 264 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 271 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 249 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 240 KiB

Loading…
Cancel
Save