feature: k-diff sampler img2img and masking

pull/63/head
Bryce 2 years ago committed by Bryce Drennan
parent 740870ad8e
commit 4ba1965db8

@@ -58,7 +58,7 @@ operators also work. When writing strength modifiers know that pixel values are
```bash
>> imagine \
--init-image fruit-bowl.jpg \
--mask-prompt "fruit OR fruit stem{*1.5}" \
--mask-prompt "fruit OR fruit stem{*6}" \
--mask-mode replace \
--mask-modify-original \
--init-image-strength .1 \
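The `{*6}` modifier multiplies the strength of the matched mask region (the command above continues past this hunk). The same masking options are exposed through the Python API; a minimal sketch, assuming the `ImaginePrompt`/`imagine` interface exercised in the tests further down:

```python
# Sketch of the equivalent edit via the Python API. Field names follow the
# ImaginePrompt usage in the test suite; treat exact signatures as assumptions.
from imaginairy import LazyLoadingImage
from imaginairy.api import imagine
from imaginairy.schema import ImaginePrompt

prompt = ImaginePrompt(
    "a bowl of gold coins",
    init_image=LazyLoadingImage(filepath="fruit-bowl.jpg"),
    init_image_strength=0.1,
    mask_prompt="fruit OR fruit stem{*6}",  # {*6} boosts the stem region's weight
    mask_mode="replace",
    steps=80,
)
result = next(iter(imagine([prompt])))
result.img.save("gold-fruit-bowl.jpg")
```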
@@ -213,6 +213,12 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
[Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing)
## ChangeLog
**3.1.0**
- feature: img2img/inpainting supported on all samplers
- refactor: consolidate img2img/txt2img code, schedules, and masking
- ci: minor logging improvements
**3.0.1**
- fix: k-samplers were broken
@@ -312,6 +318,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- Development Environment
- ✅ add tests
- ✅ set up ci (test/lint/format)
- ✅ unified pipeline (txt2img & img2img combined)
- set up parallel testing
- add docs
- remove yaml config
@@ -329,7 +336,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- https://colab.research.google.com/github/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch/blob/main/notebooks/demo.ipynb#scrollTo=wt_j3uXZGFAS
- negative prompting
- some syntax to allow it in a text string
- images as actual prompts instead of just init images
- images as actual prompts instead of just init images. is this the same as textual inversion?
- requires model fine-tuning since SD1.4 expects 77x768 text encoding input
- https://twitter.com/Buntworthy/status/1566744186153484288
- https://github.com/justinpinkney/stable-diffusion
@@ -374,7 +381,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- 🚫 CPU support. While the code does actually work on some CPUs, the generation takes so long that I don't think it's
worth the effort to support this feature
- ✅ img2img for plms
- img2img for kdiff functions
- ✅ img2img for kdiff functions
- Other
- Enhancement pipelines
- text-to-3d https://dreamfusionpaper.github.io/
@@ -400,6 +407,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- https://www.reddit.com/r/StableDiffusion/comments/xbrrgt/a_rundown_of_twenty_new_methodsoptions_added_to/
- ✅ deploy to pypi
- find similar images https://knn5.laion.ai/?back=https%3A%2F%2Fknn5.laion.ai%2F&index=laion5B&useMclip=false
- https://github.com/vicgalle/stable-diffusion-aesthetic-gradients
## Notable Stable Diffusion Implementations
- https://github.com/ahrm/UnstableFusion
@@ -410,10 +418,17 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- https://github.com/lkwq007/stablediffusion-infinity
- https://github.com/lstein/stable-diffusion
- https://github.com/parlance-zz/g-diffuser-lib
- https://github.com/hafriedlander/idea2art
## Online Stable Diffusion Services
- https://stablecog.com/
## Further Reading
- Differences between samplers
- https://www.reddit.com/r/StableDiffusion/comments/xbeyw3/can_anyone_offer_a_little_guidance_on_the/
- https://www.reddit.com/r/bigsleep/comments/xb5cat/wiskkeys_lists_of_texttoimage_systems_and_related/
- https://huggingface.co/blog/annotated-diffusion
- https://github.com/jessevig/bertviz
- https://www.youtube.com/watch?v=5pIQFQZsNe8
- https://jalammar.github.io/illustrated-transformer/
- https://huggingface.co/blog/assets/78_annotated-diffusion/unet_architecture.jpg

@@ -25,7 +25,7 @@ from imaginairy.log_utils import (
log_latent,
)
from imaginairy.safety import SafetyMode, create_safety_score
from imaginairy.samplers.base import NoiseSchedule, get_sampler
from imaginairy.samplers.base import NoiseSchedule, get_sampler, noise_an_image
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
fix_torch_group_norm,
@@ -38,7 +38,6 @@ from imaginairy.utils import (
LIB_PATH = os.path.dirname(__file__)
logger = logging.getLogger(__name__)
# leave undocumented. I'd ask that no one publicize this flag. Just want a
# slight barrier to entry. Please don't use this in any way that's gonna cause
# the media or politicians to freak out about AI...
@@ -207,19 +206,10 @@ def imagine(
prompt.height // downsampling_factor,
prompt.width // downsampling_factor,
]
if prompt.init_image and prompt.sampler_type not in ("ddim", "plms"):
sampler_type = "plms"
logger.info("Sampler type switched to plms for img2img")
else:
sampler_type = prompt.sampler_type
sampler = get_sampler(sampler_type, model)
mask, mask_image, mask_image_orig, mask_grayscale = (
None,
None,
None,
None,
)
sampler = get_sampler(prompt.sampler_type, model)
mask = mask_image = mask_image_orig = mask_grayscale = None
t_enc = init_latent = init_latent_noised = None
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
t_enc = int(prompt.steps * generation_strength)
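The two lines above are the whole strength-to-steps mapping: higher `init_image_strength` means fewer denoising steps applied on top of the init image. A quick worked example:

```python
# Worked example of the mapping above.
steps = 80
init_image_strength = 0.1
generation_strength = 1 - init_image_strength  # 0.9
t_enc = int(steps * generation_strength)       # 72: most of the schedule runs,
                                               # so the init image is heavily repainted
```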
@@ -271,6 +261,7 @@ def imagine(
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image_t)
)
shape = init_latent.shape
log_latent(init_latent, "init_latent")
# encode (scaled latent)
@@ -289,36 +280,27 @@ def imagine(
# (or setting steps=1000)
init_latent_noised = noise
else:
init_latent_noised = sampler.noise_an_image(
init_latent_noised = noise_an_image(
init_latent,
torch.tensor([t_enc - 1]).to(get_device()),
schedule=schedule,
noise=noise,
)
log_latent(init_latent_noised, "init_latent_noised")
samples = sampler.sample(
num_steps=prompt.steps,
initial_latent=init_latent_noised,
positive_conditioning=positive_conditioning,
neutral_conditioning=neutral_conditioning,
guidance_scale=prompt.prompt_strength,
t_start=t_enc,
mask=mask,
orig_latent=init_latent,
shape=shape,
batch_size=1,
)
else:
samples = sampler.sample(
num_steps=prompt.steps,
neutral_conditioning=neutral_conditioning,
positive_conditioning=positive_conditioning,
guidance_scale=prompt.prompt_strength,
batch_size=1,
shape=shape,
)
log_latent(init_latent_noised, "init_latent_noised")
samples = sampler.sample(
num_steps=prompt.steps,
initial_latent=init_latent_noised,
positive_conditioning=positive_conditioning,
neutral_conditioning=neutral_conditioning,
guidance_scale=prompt.prompt_strength,
t_start=t_enc,
mask=mask,
orig_latent=init_latent,
shape=shape,
batch_size=1,
)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
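`decode_first_stage` maps latents back to pixel space in roughly `[-1, 1]`, and the clamp rescales to `[0, 1]`. A sketch of the remaining conversion to a PIL image (standard tensor-to-image plumbing, not code from this commit):

```python
# Sketch: convert the clamped [0, 1] tensor above into a PIL image.
import numpy as np
from PIL import Image

x_sample = x_samples[0]  # (C, H, W) float tensor in [0, 1]
arr = (255.0 * x_sample.cpu().numpy()).astype(np.uint8)
img = Image.fromarray(arr.transpose(1, 2, 0))  # reorder to (H, W, C)
```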

@@ -72,6 +72,7 @@ logger = logging.getLogger(__name__)
)
@click.option(
"--sampler-type",
"--sampler",
default="plms",
type=click.Choice(SAMPLER_TYPE_OPTIONS),
help="What sampling strategy to use",

@@ -34,6 +34,12 @@ def log_img(img, description):
_CURRENT_LOGGING_CONTEXT.log_img(img, description)
def log_tensor(t, description=""):
if _CURRENT_LOGGING_CONTEXT is None:
return
_CURRENT_LOGGING_CONTEXT.log_tensor(t, description)
class ImageLoggingContext:
def __init__(self, prompt, model, img_callback=None, img_outdir=None):
self.prompt = prompt
@@ -67,7 +73,11 @@ class ImageLoggingContext:
# logger.info(f"Didn't save tensor of shape {samples.shape} for {description}")
return
self.step_count += 1
description = f"{description} - {latents.shape}"
try:
shape_str = ",".join(tuple(latents.shape))
except TypeError:
shape_str = str(latents.shape)
description = f"{description}-{shape_str}"
for img in model_latents_to_pillow_imgs(latents):
self.img_callback(img, description, self.step_count, self.prompt)
@@ -80,6 +90,16 @@ class ImageLoggingContext:
img = img.copy()
self.img_callback(img, description, self.step_count, self.prompt)
def log_tensor(self, t, description=""):
if not self.img_callback:
return
if len(t.shape) == 2:
self.log_img(t, description)
def log_indexed_graph_of_tensor(self):
pass
# def img_callback(self, img, description, step_count, prompt):
# steps_path = os.path.join(self.img_outdir, "steps", f"{self.file_num:08}_S{prompt.seed}")
# os.makedirs(steps_path, exist_ok=True)

@@ -2,7 +2,7 @@ import math
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange
from torch import einsum, nn
from imaginairy.modules.diffusion.util import checkpoint
@@ -140,6 +140,11 @@ class CrossAttention(nn.Module):
)
def forward(self, x, context=None, mask=None):
# from imaginairy.api import _global_mask_hack
#
# if mask is None and _global_mask_hack is not None:
# mask = _global_mask_hack.to(torch.bool)
if get_device() == "cuda":
return self.forward_cuda(x, context=context, mask=mask)
@@ -154,11 +159,13 @@ class CrossAttention(nn.Module):
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if mask is not None:
mask = rearrange(mask, "b ... -> b (...)")
_max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~mask, _max_neg_value)
# if mask is not None:
# if sim.shape[2] == 320 and False:
# mask = [mask] * 2
# mask = rearrange(mask, "b ... -> b (...)")
# _max_neg_value = -torch.finfo(sim.dtype).max
# mask = repeat(mask, "b j -> (b h) () j", h=h)
# sim.masked_fill_(~mask, _max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)

@@ -16,6 +16,7 @@ import torch
from einops import repeat as e_repeat
from torch import nn
from imaginairy.log_utils import log_tensor
from imaginairy.utils import instantiate_from_config
logger = logging.getLogger(__name__)
@@ -207,6 +208,7 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
)
else:
embedding = e_repeat(timesteps, "b -> b d", d=dim)
log_tensor(embedding, "timestep_embedding")
return embedding
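For reference when reading the new `log_tensor` output: a minimal sketch of what the sinusoidal branch of `timestep_embedding` computes (the standard transformer-style embedding; details such as odd-`dim` padding are omitted, so treat this as an approximation of the actual implementation):

```python
# Minimal sketch of a sinusoidal timestep embedding (assumes dim is even).
import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    ).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # (b, dim)
```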

@@ -1,15 +1,20 @@
# pylama:ignore=W0613
import logging
import numpy as np
import torch
from torch import nn
from imaginairy.log_utils import log_latent
from imaginairy.log_utils import log_img, log_latent
from imaginairy.modules.diffusion.util import (
extract_into_tensor,
make_ddim_sampling_parameters,
make_ddim_timesteps,
)
from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
SAMPLER_TYPE_OPTIONS = [
"plms",
"ddim",
@@ -55,13 +60,37 @@ class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.device = get_device()
def forward(self, x, sigma, uncond, cond, cond_scale, mask=None, orig_latent=None):
def forward(
self,
x,
sigma,
uncond,
cond,
cond_scale,
mask=None,
mask_noise=None,
orig_latent=None,
):
def _wrapper(noisy_latent_in, time_encoding_in, conditioning_in):
return self.inner_model(
noisy_latent_in, time_encoding_in, cond=conditioning_in
)
if mask is not None:
assert orig_latent is not None
t = self.inner_model.sigma_to_t(sigma, quantize=True)
big_sigma = max(sigma, 1)
x = mask_blend(
noisy_latent=x,
orig_latent=orig_latent * big_sigma,
mask=mask,
mask_noise=mask_noise * big_sigma,
ts=t,
model=self.inner_model.inner_model,
)
noise_pred = get_noise_prediction(
denoise_func=_wrapper,
noisy_latent=x,
@@ -71,11 +100,6 @@ class CFGDenoiser(nn.Module):
signal_amplification=cond_scale,
)
if mask is not None:
assert orig_latent is not None
mask_inv = 1.0 - mask
noise_pred = (orig_latent * mask) + (mask_inv * noise_pred)
return noise_pred
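With the blend moved before the model call, the removed post-hoc mask mixing is gone and `get_noise_prediction` is left to combine the conditional and unconditional passes. Its internals aren't shown in this diff; presumably it applies standard classifier-free guidance, along these lines:

```python
# Assumed shape of classifier-free guidance inside get_noise_prediction:
# push the prediction from unconditional toward conditional by cond_scale.
def cfg_combine(noise_pred_uncond, noise_pred_cond, cond_scale):
    return noise_pred_uncond + cond_scale * (noise_pred_cond - noise_pred_uncond)
```

Here `cond_scale` is the `guidance_scale`/`prompt.prompt_strength` passed down from the API layer.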
@@ -122,19 +146,27 @@ def mask_blend(noisy_latent, orig_latent, mask, mask_noise, ts, model):
ts is a decreasing value between 1000 and 1
"""
assert orig_latent is not None
log_latent(orig_latent, "orig_latent")
noised_orig_latent = model.q_sample(orig_latent, ts, mask_noise)
# this helps prevent the weird disjointed images that can happen with masking
hint_strength = 0.8
hint_strength = 1
# if we're in the first 10% of the steps then don't fully noise the parts
# of the image we're not changing so that the algorithm can learn from the context
if ts > 900:
xdec_orig_with_hints = (
hinted_orig_latent = (
noised_orig_latent * (1 - hint_strength) + orig_latent * hint_strength
)
log_latent(hinted_orig_latent, f"hinted_orig_latent {ts}")
else:
xdec_orig_with_hints = noised_orig_latent
noisy_latent = xdec_orig_with_hints * mask + (1.0 - mask) * noisy_latent
hinted_orig_latent = noised_orig_latent
log_img(mask, f"mask {ts}")
# logger.info(mask.shape)
hinted_orig_latent_masked = hinted_orig_latent * mask
log_latent(hinted_orig_latent_masked, f"hinted_orig_latent_masked {ts}")
noisy_latent_masked = (1.0 - mask) * noisy_latent
log_latent(noisy_latent_masked, f"noisy_latent_masked {ts}")
noisy_latent = hinted_orig_latent_masked + noisy_latent_masked
log_latent(noisy_latent, f"mask-blended noisy_latent {ts}")
return noisy_latent
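The final blend is plain per-element interpolation in latent space; a toy example of the arithmetic:

```python
# Toy example of the blend at the end of mask_blend: where mask == 1 the
# hinted/noised original latent is kept, elsewhere the sampler's latent.
import torch

mask = torch.tensor([[1.0, 0.0]])
hinted_orig_latent = torch.tensor([[5.0, 5.0]])
noisy_latent = torch.tensor([[9.0, 9.0]])
blended = hinted_orig_latent * mask + (1.0 - mask) * noisy_latent
assert torch.equal(blended, torch.tensor([[5.0, 9.0]]))
```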
@@ -181,3 +213,20 @@ class NoiseSchedule:
self.ddim_sqrt_one_minus_alphas = (
np.sqrt(1.0 - ddim_alphas).to(torch.float32).to(device)
)
@torch.no_grad()
def noise_an_image(init_latent, t, schedule, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
t = t.clamp(0, 1000)
sqrt_alphas_cumprod = torch.sqrt(schedule.ddim_alphas)
sqrt_one_minus_alphas_cumprod = schedule.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
return (
extract_into_tensor(sqrt_alphas_cumprod, t, init_latent.shape) * init_latent
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, init_latent.shape)
* noise
)
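`noise_an_image` is the closed-form DDPM forward process: rather than adding noise step by step, it jumps straight from the clean latent to timestep `t`:

```latex
x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)
```

with `extract_into_tensor` gathering the two schedule coefficients at index `t`. As the code comment says, this is fast but does not allow exact reconstruction of the original latent.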

@@ -43,16 +43,30 @@ class KDiffusionSampler:
initial_latent = torch.randn(shape, device="cpu").to(self.device)
log_latent(initial_latent, "initial_latent")
if t_start is not None:
t_start = num_steps - t_start + 1
sigmas = self.cv_denoiser.get_sigmas(num_steps)
sigmas = self.cv_denoiser.get_sigmas(num_steps)[t_start:]
# if no sampling steps remain, return the original latent (or the initial noise) unchanged
if sigmas.nelement() == 0:
if orig_latent is not None:
return orig_latent
return initial_latent
x = initial_latent * sigmas[0]
log_latent(x, "initial_sigma_noised_tensor")
model_wrap_cfg = CFGDenoiser(self.cv_denoiser)
mask_noise = None
if mask is not None:
mask_noise = torch.randn_like(initial_latent, device="cpu").to(
initial_latent.device
)
def callback(data):
log_latent(data["x"], "noisy_latent")
log_latent(data["denoised"], "noise_pred")
log_latent(data["denoised"], "noise_pred c")
samples = self.sampler_func(
model=model_wrap_cfg,
@@ -63,6 +77,7 @@ class KDiffusionSampler:
"uncond": neutral_conditioning,
"cond_scale": guidance_scale,
"mask": mask,
"mask_noise": mask_noise,
"orig_latent": orig_latent,
},
disable=False,
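This is how img2img reaches the k-diffusion samplers: `t_start` (the `t_enc` from the API layer) is converted into a slice of the sigma schedule so only the low-noise tail is run. A worked example, assuming k-diffusion's `get_sigmas(n)` returns `n + 1` values ending in 0:

```python
# Worked example of the t_start handling above (get_sigmas length assumed).
num_steps = 80
t_enc = 72                           # int(steps * (1 - init_image_strength))
t_start = num_steps - t_enc + 1      # 9
# sigmas = cv_denoiser.get_sigmas(80)[9:] -> 72 values: sampling starts
# partway down the noise schedule, from the partially noised init latent
```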

@@ -5,6 +5,7 @@ import pytest
from imaginairy import LazyLoadingImage
from imaginairy.api import imagine, imagine_image_files, prompt_normalized
from imaginairy.img_utils import pillow_fit_image_within
from imaginairy.samplers.base import SAMPLER_TYPE_OPTIONS
from imaginairy.schema import ImaginePrompt
from imaginairy.utils import get_device
@@ -13,7 +14,10 @@ from . import TESTS_FOLDER
device_sampler_type_test_cases = {
"mps:0": [
("plms", "78539ae3a3097dc8232da6d630551ab3"),
("ddim", ("828fc143cd40586347b2f8403c288c9b", "4c7905d4a36f6f9c456b7e074b52707e")),
(
"ddim",
("828fc143cd40586347b2f8403c288c9b", "4c7905d4a36f6f9c456b7e074b52707e"),
),
("k_lms", "53d25e59add39c8447537be30e4eff4b"),
("k_dpm_2", "5108bceb58a38d88a585f37b2ba1b072"),
("k_dpm_2_a", "20396daa6c920d1cfd6db90e73558c01"),
@@ -51,10 +55,22 @@ device_sampler_type_test_cases_img_2_img = {
"mps:0": {
("plms", "0d9c40c348cdac7bdc8d5a472f378f42"),
("ddim", "0d9c40c348cdac7bdc8d5a472f378f42"),
("k_lms", ""),
("k_dpm_2", ""),
("k_dpm_2_a", ""),
("k_euler", ""),
("k_euler_a", ""),
("k_heun", ""),
},
"cuda": {
("plms", "841723966344dd8678aee1ce5f9cbb3d"),
("ddim", "1f0d72370fabcf2ff716e4068d5b2360"),
("k_lms", ""),
("k_dpm_2", ""),
("k_dpm_2_a", ""),
("k_euler", ""),
("k_euler_a", ""),
("k_heun", ""),
},
}
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img.get(
@@ -87,12 +103,30 @@ def test_img2img_beach_to_sunset(sampler_type, expected_md5, filename_base_for_o
device_sampler_type_test_cases_img_2_img = {
"mps:0": {
("plms", ("e9bb714771f7984e61debabc4bb3cd22", "af344c404de70da5db519869f8fcd0c1")),
("ddim", ("62bacc4ae391e6775a3723c88738ec61", "5f0d2ee426e1bb6ccc1d57dfdd8c73bf")),
(
"plms",
("e9bb714771f7984e61debabc4bb3cd22", "af344c404de70da5db519869f8fcd0c1"),
),
(
"ddim",
("62bacc4ae391e6775a3723c88738ec61", "5f0d2ee426e1bb6ccc1d57dfdd8c73bf"),
),
("k_lms", tuple()),
("k_dpm_2", tuple()),
("k_dpm_2_a", tuple()),
("k_euler", tuple()),
("k_euler_a", tuple()),
("k_heun", tuple()),
},
"cuda": {
("plms", "b8c7b52da977c1531a9a61c0a082404c"),
("ddim", "d6784710dd78e4cb628aba28322b04cf"),
("plms", ("b8c7b52da977c1531a9a61c0a082404c",)),
("ddim", ("d6784710dd78e4cb628aba28322b04cf",)),
("k_lms", ("3246b588155f430a79d08a0b1c7287f5",)),
("k_dpm_2", ("724fa459adec6a7b3ebb523263dd5176",)),
("k_dpm_2_a", ("5c36fa9c051db80e3969c63d500340f4",)),
("k_euler", ("d6800b8a3e31f81fb3902d34ee786b33",)),
("k_euler_a", ("6477863f35d0c9032b959a9cc7a0b61c",)),
("k_heun", ("1ed62ad0cfd03dba8b487a36259833a3",)),
},
}
sampler_type_test_cases_img_2_img = device_sampler_type_test_cases_img_2_img.get(
@@ -137,9 +171,8 @@ def test_img_to_img_from_url_cats(
assert result.md5() in expected_md5
# @pytest.mark.parametrize("sampler_type", SAMPLER_TYPE_OPTIONS)
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("sampler_type", ["ddim", "plms"])
@pytest.mark.parametrize("sampler_type", SAMPLER_TYPE_OPTIONS)
@pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1])
def test_img_to_img_fruit_2_gold(
filename_base_for_outputs, sampler_type, init_strength
@@ -153,7 +186,7 @@ def test_img_to_img_fruit_2_gold(
prompt_strength=12,
init_image=img,
init_image_strength=init_strength,
mask_prompt="(fruit{*2} OR stem{*5} OR fruit stem{*3})",
mask_prompt="(fruit{*2} OR stem{*10} OR fruit stem{*3})",
mask_mode="replace",
steps=80,
seed=1,
