fix: inpainting now matches photo at high generation strength

- 🎉 fix: inpainted areas correlate with surrounding image, even at 100% generation strength.  Previously if the generation strength was high enough the generated image
would be uncorrelated to the rest of the surrounding image.  It created terrible looking images.
 - fix: mask boundaries are more accurate
pull/30/head
Bryce 2 years ago committed by Bryce Drennan
parent d563e0c7fb
commit 6cae290038

@ -185,11 +185,14 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
**2.0.0**
- feature: interactive prompt added. access by running `aimg`
- feature: Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase.
- 🎉 fix: inpainted areas correlate with surrounding image, even at 100% generation strength. Previously if the generation strength was high enough the generated image
would be uncorrelated to the rest of the surrounding image. It created terrible looking images.
- 🎉 feature: interactive prompt added. access by running `aimg`
- 🎉 feature: Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase.
Valid symbols: `AND`, `OR`, `NOT`, `()`, and mask strength modifier `{+0.1}` where `+` can be any of `+ - * /`. Single character boolean operators also work (`|`, `&`, `!`)
- feature: apply mask edits to original files with `mask_modify_original` (on by default)
- 🎉 feature: apply mask edits to original files with `mask_modify_original` (on by default)
- feature: auto-rotate images if exif data specifies to do so
- fix: mask boundaries are more accurate
- fix: accept mask images in command line
- fix: img2img algorithm was wrong and wouldn't at values close to 0 or 1
@ -234,7 +237,6 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## Todo
- refactor how output versions are selected (upscaled, modified original, etc)
- performance optimizations
- ✅ https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx
- ✅ https://github.com/CompVis/stable-diffusion/compare/main...Doggettx:stable-diffusion:autocast-improvements#

@ -122,13 +122,13 @@ def imagine_image_files(
prompt = result.prompt
basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
for image_type, img in result.images.items():
for image_type in result.images:
subpath = os.path.join(outdir, image_type)
os.makedirs(subpath, exist_ok=True)
filepath = os.path.join(
subpath, f"{basefilename}_[{image_type}].{output_file_extension}"
)
result.save(filepath)
result.save(filepath, image_type=image_type)
logger.info(f" 🖼 [{image_type}] saved to: {filepath}")
base_count += 1
del result
@ -198,7 +198,12 @@ def imagine(
sampler_type = prompt.sampler_type
sampler = get_sampler(sampler_type, model)
mask, mask_image, mask_image_orig = None, None, None
mask, mask_image, mask_image_orig, mask_grayscale = (
None,
None,
None,
None,
)
if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength
t_enc = int(prompt.steps * generation_strength)
@ -218,7 +223,7 @@ def imagine(
init_image_t = pillow_img_to_torch_image(init_image)
if prompt.mask_prompt:
mask_image = get_img_mask(
mask_image, mask_grayscale = get_img_mask(
init_image, prompt.mask_prompt, threshold=0.1
)
elif prompt.mask_image:
@ -239,7 +244,7 @@ def imagine(
mask_image.width // downsampling_factor,
mask_image.height // downsampling_factor,
),
resample=Image.Resampling.NEAREST,
resample=Image.Resampling.LANCZOS,
)
log_img(mask_image, "latent_mask")
@ -256,9 +261,11 @@ def imagine(
log_latent(init_latent, "init_latent")
# encode (scaled latent)
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
z_enc = sampler.stochastic_encode(
init_latent,
torch.tensor([t_enc - 1]).to(get_device()),
noise=noise,
)
log_latent(z_enc, "z_enc")
@ -297,12 +304,12 @@ def imagine(
x_sample_8_orig = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample_8_orig)
if mask_image_orig and init_image:
mask_image_orig = mask_image_orig.filter(
mask_final = mask_image_orig.filter(
ImageFilter.GaussianBlur(radius=3)
)
log_img(mask_image_orig, "reconstituting mask")
mask_image_orig = ImageOps.invert(mask_image_orig)
img = Image.composite(img, init_image, mask_image_orig)
log_img(mask_final, "reconstituting mask")
mask_final = ImageOps.invert(mask_final)
img = Image.composite(img, init_image, mask_final)
log_img(img, "reconstituted image")
upscaled_img = None
@ -328,7 +335,11 @@ def imagine(
upscaled_img = enhance_faces(upscaled_img, fidelity=0.8)
# put the newly generated patch back into the original, full size image
if prompt.mask_modify_original and mask_image_orig and prompt.init_image:
if (
prompt.mask_modify_original
and mask_image_orig
and prompt.init_image
):
img_to_add_back_to_original = (
upscaled_img if upscaled_img else img
)
@ -348,8 +359,8 @@ def imagine(
log_img(mask_for_orig_size, "mask for original image size")
rebuilt_orig_img = Image.composite(
img_to_add_back_to_original,
prompt.init_image,
img_to_add_back_to_original,
mask_for_orig_size,
)
log_img(rebuilt_orig_img, "reconstituted original")
@ -359,7 +370,9 @@ def imagine(
prompt=prompt,
upscaled_img=upscaled_img,
is_nsfw=is_nsfw_img,
modified_original_img=rebuilt_orig_img,
modified_original=rebuilt_orig_img,
mask_binary=mask_image_orig,
mask_grayscale=mask_grayscale,
)

@ -1,8 +1,11 @@
from functools import lru_cache
from typing import Optional, Sequence
import cv2
import numpy as np
import PIL.Image
import torch
from kornia.filters import median_blur
from torchvision import transforms
from imaginairy.img_log import log_img
@ -41,10 +44,30 @@ def get_img_mask(
mask_cache = get_img_masks(img, descriptions)
mask = parsed_mask.apply_masks(mask_cache)
log_img(mask, "combined mask")
# try to blur the square shaped artifacts somewhat
mask = median_blur(mask.unsqueeze(dim=0).unsqueeze(dim=0), (11, 11)).squeeze()
log_img(mask, "median blurred")
kernel = np.ones((5, 5), np.uint8)
mask_g = mask.clone()
# trial and error shows 0.5 threshold has the best "shape"
if threshold is not None:
mask[mask < threshold] = 0
mask[mask >= threshold] = 1
return transforms.ToPILImage()(mask)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
log_img(mask, f"mask threshold {0.5}")
mask_np = mask.cpu().numpy()
smoother_strength = 5
# grow the mask area to make sure we've masked the thing we care about
for _ in range(smoother_strength):
mask_np = cv2.dilate(mask_np, kernel)
# todo: add an outer blur (not gaussian)
mask = torch.from_numpy(mask_np)
log_img(mask, "mask after closing (dilation then erosion)")
return transforms.ToPILImage()(mask), transforms.ToPILImage()(mask_g)
def get_img_masks(img, mask_descriptions: Sequence[str]):
@ -66,7 +89,6 @@ def get_img_masks(img, mask_descriptions: Sequence[str]):
img.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions
)[0]
preds = transforms.Resize(orig_size)(preds)
preds = transforms.GaussianBlur(kernel_size=9)(preds)
preds = [torch.sigmoid(p[0]) for p in preds]

@ -359,8 +359,12 @@ class DDIMSampler:
assert orig_latent is not None
xdec_orig = self.model.q_sample(orig_latent, ts)
log_latent(xdec_orig, "xdec_orig")
log_latent(xdec_orig * mask, "masked_xdec_orig")
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
# this helps prevent the weird disjointed images that can happen with masking
hint_strength = 0.8
xdec_orig_with_hints = (
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
)
x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
log_latent(x_dec, "x_dec")
x_dec, pred_x0 = self.p_sample_ddim(

@ -388,6 +388,7 @@ class PLMSSampler:
temperature=1.0,
mask=None,
orig_latent=None,
noise=None,
):
timesteps = self.ddim_timesteps[:t_start]
@ -398,7 +399,15 @@ class PLMSSampler:
iterator = tqdm(time_range, desc="PLMS altering image", total=total_steps)
x_dec = x_latent
old_eps = []
log_latent(x_dec, "x_dec")
# not sure what the downside of using the same noise throughout the process would be...
# seems to work fine. maybe it runs faster?
noise = (
torch.randn_like(x_dec, device="cpu").to(x_dec.device)
if noise is None
else noise
)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
@ -413,10 +422,14 @@ class PLMSSampler:
if mask is not None:
assert orig_latent is not None
xdec_orig = self.model.q_sample(orig_latent, ts)
xdec_orig = self.model.q_sample(orig_latent, ts, noise)
log_latent(xdec_orig, "xdec_orig")
log_latent(xdec_orig * mask, "masked_xdec_orig")
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
# this helps prevent the weird disjointed images that can happen with masking
hint_strength = 0.8
xdec_orig_with_hints = (
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
)
x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
log_latent(x_dec, "x_dec")
x_dec, pred_x0, e_t = self.p_sample_plms(
@ -446,5 +459,6 @@ class PLMSSampler:
img_callback(pred_x0, "pred_x0")
log_latent(x_dec, f"x_dec {i}")
log_latent(x_dec, f"e_t {i}")
log_latent(pred_x0, f"pred_x0 {i}")
return x_dec

@ -138,7 +138,6 @@ class ImaginePrompt:
self.mask_modify_original = mask_modify_original
self.tile_mode = tile_mode
@property
def prompt_text(self):
if len(self.prompts) == 1:
@ -186,7 +185,7 @@ class ImagineResult:
prompt: ImaginePrompt,
is_nsfw,
upscaled_img=None,
modified_original_img=None,
modified_original=None,
mask_binary=None,
mask_grayscale=None,
):
@ -197,8 +196,8 @@ class ImagineResult:
if upscaled_img:
self.images["upscaled"] = upscaled_img
if modified_original_img:
self.images["modified_original"] = modified_original_img
if modified_original:
self.images["modified_original"] = modified_original
if mask_binary:
self.images["mask_binary"] = mask_binary

@ -40,8 +40,15 @@ def test_clip_masking():
"*1",
"*10",
]:
pred = get_img_mask(img, f"(head OR face){{{mask_modifier}}}")
pred.save(f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}.png")
pred_bin, pred_grayscale = get_img_mask(
img, f"(head OR face){{{mask_modifier}}}", threshold=0.1
)
pred_grayscale.save(
f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}_g.png"
)
pred_bin.save(
f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}_bin.png"
)
prompt = ImaginePrompt(
"professional photo of a woman",
@ -57,8 +64,9 @@ def test_clip_masking():
)
result = next(imagine(prompt))
result.modified_original_img.save(
f"{TESTS_FOLDER}/test_output/earring_mask_photo.png"
result.save(
f"{TESTS_FOLDER}/test_output/earring_mask_photo.png",
image_type="modified_original",
)

Loading…
Cancel
Save