fix: inpainting now matches photo at high generation strength

- 🎉 fix: inpainted areas correlate with surrounding image, even at 100% generation strength.  Previously if the generation strength was high enough the generated image
would be uncorrelated to the rest of the surrounding image.  It created terrible looking images.
 - fix: mask boundaries are more accurate
This commit is contained in:
Bryce 2022-09-25 21:55:25 -07:00 committed by Bryce Drennan
parent d563e0c7fb
commit 6cae290038
7 changed files with 95 additions and 33 deletions

View File

@ -185,11 +185,14 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog ## ChangeLog
**2.0.0** **2.0.0**
- feature: interactive prompt added. access by running `aimg` - 🎉 fix: inpainted areas correlate with surrounding image, even at 100% generation strength. Previously if the generation strength was high enough the generated image
- feature: Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase. would be uncorrelated to the rest of the surrounding image. It created terrible looking images.
- 🎉 feature: interactive prompt added. access by running `aimg`
- 🎉 feature: Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase.
Valid symbols: `AND`, `OR`, `NOT`, `()`, and mask strength modifier `{+0.1}` where `+` can be any of `+ - * /`. Single character boolean operators also work (`|`, `&`, `!`) Valid symbols: `AND`, `OR`, `NOT`, `()`, and mask strength modifier `{+0.1}` where `+` can be any of `+ - * /`. Single character boolean operators also work (`|`, `&`, `!`)
- feature: apply mask edits to original files with `mask_modify_original` (on by default) - 🎉 feature: apply mask edits to original files with `mask_modify_original` (on by default)
- feature: auto-rotate images if exif data specifies to do so - feature: auto-rotate images if exif data specifies to do so
- fix: mask boundaries are more accurate
- fix: accept mask images in command line - fix: accept mask images in command line
- fix: img2img algorithm was wrong and wouldn't at values close to 0 or 1 - fix: img2img algorithm was wrong and wouldn't at values close to 0 or 1
@ -234,7 +237,6 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## Todo ## Todo
- refactor how output versions are selected (upscaled, modified original, etc)
- performance optimizations - performance optimizations
- ✅ https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx - ✅ https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx
- ✅ https://github.com/CompVis/stable-diffusion/compare/main...Doggettx:stable-diffusion:autocast-improvements# - ✅ https://github.com/CompVis/stable-diffusion/compare/main...Doggettx:stable-diffusion:autocast-improvements#

View File

@ -122,13 +122,13 @@ def imagine_image_files(
prompt = result.prompt prompt = result.prompt
basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}" basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
for image_type, img in result.images.items(): for image_type in result.images:
subpath = os.path.join(outdir, image_type) subpath = os.path.join(outdir, image_type)
os.makedirs(subpath, exist_ok=True) os.makedirs(subpath, exist_ok=True)
filepath = os.path.join( filepath = os.path.join(
subpath, f"{basefilename}_[{image_type}].{output_file_extension}" subpath, f"{basefilename}_[{image_type}].{output_file_extension}"
) )
result.save(filepath) result.save(filepath, image_type=image_type)
logger.info(f" 🖼 [{image_type}] saved to: {filepath}") logger.info(f" 🖼 [{image_type}] saved to: {filepath}")
base_count += 1 base_count += 1
del result del result
@ -198,7 +198,12 @@ def imagine(
sampler_type = prompt.sampler_type sampler_type = prompt.sampler_type
sampler = get_sampler(sampler_type, model) sampler = get_sampler(sampler_type, model)
mask, mask_image, mask_image_orig = None, None, None mask, mask_image, mask_image_orig, mask_grayscale = (
None,
None,
None,
None,
)
if prompt.init_image: if prompt.init_image:
generation_strength = 1 - prompt.init_image_strength generation_strength = 1 - prompt.init_image_strength
t_enc = int(prompt.steps * generation_strength) t_enc = int(prompt.steps * generation_strength)
@ -218,7 +223,7 @@ def imagine(
init_image_t = pillow_img_to_torch_image(init_image) init_image_t = pillow_img_to_torch_image(init_image)
if prompt.mask_prompt: if prompt.mask_prompt:
mask_image = get_img_mask( mask_image, mask_grayscale = get_img_mask(
init_image, prompt.mask_prompt, threshold=0.1 init_image, prompt.mask_prompt, threshold=0.1
) )
elif prompt.mask_image: elif prompt.mask_image:
@ -239,7 +244,7 @@ def imagine(
mask_image.width // downsampling_factor, mask_image.width // downsampling_factor,
mask_image.height // downsampling_factor, mask_image.height // downsampling_factor,
), ),
resample=Image.Resampling.NEAREST, resample=Image.Resampling.LANCZOS,
) )
log_img(mask_image, "latent_mask") log_img(mask_image, "latent_mask")
@ -256,9 +261,11 @@ def imagine(
log_latent(init_latent, "init_latent") log_latent(init_latent, "init_latent")
# encode (scaled latent) # encode (scaled latent)
noise = torch.randn_like(init_latent, device="cpu").to(get_device())
z_enc = sampler.stochastic_encode( z_enc = sampler.stochastic_encode(
init_latent, init_latent,
torch.tensor([t_enc - 1]).to(get_device()), torch.tensor([t_enc - 1]).to(get_device()),
noise=noise,
) )
log_latent(z_enc, "z_enc") log_latent(z_enc, "z_enc")
@ -297,12 +304,12 @@ def imagine(
x_sample_8_orig = x_sample.astype(np.uint8) x_sample_8_orig = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample_8_orig) img = Image.fromarray(x_sample_8_orig)
if mask_image_orig and init_image: if mask_image_orig and init_image:
mask_image_orig = mask_image_orig.filter( mask_final = mask_image_orig.filter(
ImageFilter.GaussianBlur(radius=3) ImageFilter.GaussianBlur(radius=3)
) )
log_img(mask_image_orig, "reconstituting mask") log_img(mask_final, "reconstituting mask")
mask_image_orig = ImageOps.invert(mask_image_orig) mask_final = ImageOps.invert(mask_final)
img = Image.composite(img, init_image, mask_image_orig) img = Image.composite(img, init_image, mask_final)
log_img(img, "reconstituted image") log_img(img, "reconstituted image")
upscaled_img = None upscaled_img = None
@ -328,7 +335,11 @@ def imagine(
upscaled_img = enhance_faces(upscaled_img, fidelity=0.8) upscaled_img = enhance_faces(upscaled_img, fidelity=0.8)
# put the newly generated patch back into the original, full size image # put the newly generated patch back into the original, full size image
if prompt.mask_modify_original and mask_image_orig and prompt.init_image: if (
prompt.mask_modify_original
and mask_image_orig
and prompt.init_image
):
img_to_add_back_to_original = ( img_to_add_back_to_original = (
upscaled_img if upscaled_img else img upscaled_img if upscaled_img else img
) )
@ -348,8 +359,8 @@ def imagine(
log_img(mask_for_orig_size, "mask for original image size") log_img(mask_for_orig_size, "mask for original image size")
rebuilt_orig_img = Image.composite( rebuilt_orig_img = Image.composite(
img_to_add_back_to_original,
prompt.init_image, prompt.init_image,
img_to_add_back_to_original,
mask_for_orig_size, mask_for_orig_size,
) )
log_img(rebuilt_orig_img, "reconstituted original") log_img(rebuilt_orig_img, "reconstituted original")
@ -359,7 +370,9 @@ def imagine(
prompt=prompt, prompt=prompt,
upscaled_img=upscaled_img, upscaled_img=upscaled_img,
is_nsfw=is_nsfw_img, is_nsfw=is_nsfw_img,
modified_original_img=rebuilt_orig_img, modified_original=rebuilt_orig_img,
mask_binary=mask_image_orig,
mask_grayscale=mask_grayscale,
) )

View File

@ -1,8 +1,11 @@
from functools import lru_cache from functools import lru_cache
from typing import Optional, Sequence from typing import Optional, Sequence
import cv2
import numpy as np
import PIL.Image import PIL.Image
import torch import torch
from kornia.filters import median_blur
from torchvision import transforms from torchvision import transforms
from imaginairy.img_log import log_img from imaginairy.img_log import log_img
@ -41,10 +44,30 @@ def get_img_mask(
mask_cache = get_img_masks(img, descriptions) mask_cache = get_img_masks(img, descriptions)
mask = parsed_mask.apply_masks(mask_cache) mask = parsed_mask.apply_masks(mask_cache)
log_img(mask, "combined mask") log_img(mask, "combined mask")
# try to blur the square shaped artifacts somewhat
mask = median_blur(mask.unsqueeze(dim=0).unsqueeze(dim=0), (11, 11)).squeeze()
log_img(mask, "median blurred")
kernel = np.ones((5, 5), np.uint8)
mask_g = mask.clone()
# trial and error shows 0.5 threshold has the best "shape"
if threshold is not None: if threshold is not None:
mask[mask < threshold] = 0 mask[mask < 0.5] = 0
mask[mask >= threshold] = 1 mask[mask >= 0.5] = 1
return transforms.ToPILImage()(mask) log_img(mask, f"mask threshold {0.5}")
mask_np = mask.cpu().numpy()
smoother_strength = 5
# grow the mask area to make sure we've masked the thing we care about
for _ in range(smoother_strength):
mask_np = cv2.dilate(mask_np, kernel)
# todo: add an outer blur (not gaussian)
mask = torch.from_numpy(mask_np)
log_img(mask, "mask after closing (dilation then erosion)")
return transforms.ToPILImage()(mask), transforms.ToPILImage()(mask_g)
def get_img_masks(img, mask_descriptions: Sequence[str]): def get_img_masks(img, mask_descriptions: Sequence[str]):
@ -66,7 +89,6 @@ def get_img_masks(img, mask_descriptions: Sequence[str]):
img.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions img.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions
)[0] )[0]
preds = transforms.Resize(orig_size)(preds) preds = transforms.Resize(orig_size)(preds)
preds = transforms.GaussianBlur(kernel_size=9)(preds)
preds = [torch.sigmoid(p[0]) for p in preds] preds = [torch.sigmoid(p[0]) for p in preds]

View File

@ -359,8 +359,12 @@ class DDIMSampler:
assert orig_latent is not None assert orig_latent is not None
xdec_orig = self.model.q_sample(orig_latent, ts) xdec_orig = self.model.q_sample(orig_latent, ts)
log_latent(xdec_orig, "xdec_orig") log_latent(xdec_orig, "xdec_orig")
log_latent(xdec_orig * mask, "masked_xdec_orig") # this helps prevent the weird disjointed images that can happen with masking
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec hint_strength = 0.8
xdec_orig_with_hints = (
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
)
x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
log_latent(x_dec, "x_dec") log_latent(x_dec, "x_dec")
x_dec, pred_x0 = self.p_sample_ddim( x_dec, pred_x0 = self.p_sample_ddim(

View File

@ -388,6 +388,7 @@ class PLMSSampler:
temperature=1.0, temperature=1.0,
mask=None, mask=None,
orig_latent=None, orig_latent=None,
noise=None,
): ):
timesteps = self.ddim_timesteps[:t_start] timesteps = self.ddim_timesteps[:t_start]
@ -398,7 +399,15 @@ class PLMSSampler:
iterator = tqdm(time_range, desc="PLMS altering image", total=total_steps) iterator = tqdm(time_range, desc="PLMS altering image", total=total_steps)
x_dec = x_latent x_dec = x_latent
old_eps = [] old_eps = []
log_latent(x_dec, "x_dec")
# not sure what the downside of using the same noise throughout the process would be...
# seems to work fine. maybe it runs faster?
noise = (
torch.randn_like(x_dec, device="cpu").to(x_dec.device)
if noise is None
else noise
)
for i, step in enumerate(iterator): for i, step in enumerate(iterator):
index = total_steps - i - 1 index = total_steps - i - 1
ts = torch.full( ts = torch.full(
@ -413,10 +422,14 @@ class PLMSSampler:
if mask is not None: if mask is not None:
assert orig_latent is not None assert orig_latent is not None
xdec_orig = self.model.q_sample(orig_latent, ts) xdec_orig = self.model.q_sample(orig_latent, ts, noise)
log_latent(xdec_orig, "xdec_orig") log_latent(xdec_orig, "xdec_orig")
log_latent(xdec_orig * mask, "masked_xdec_orig") # this helps prevent the weird disjointed images that can happen with masking
x_dec = xdec_orig * mask + (1.0 - mask) * x_dec hint_strength = 0.8
xdec_orig_with_hints = (
xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
)
x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
log_latent(x_dec, "x_dec") log_latent(x_dec, "x_dec")
x_dec, pred_x0, e_t = self.p_sample_plms( x_dec, pred_x0, e_t = self.p_sample_plms(
@ -446,5 +459,6 @@ class PLMSSampler:
img_callback(pred_x0, "pred_x0") img_callback(pred_x0, "pred_x0")
log_latent(x_dec, f"x_dec {i}") log_latent(x_dec, f"x_dec {i}")
log_latent(x_dec, f"e_t {i}")
log_latent(pred_x0, f"pred_x0 {i}") log_latent(pred_x0, f"pred_x0 {i}")
return x_dec return x_dec

View File

@ -138,7 +138,6 @@ class ImaginePrompt:
self.mask_modify_original = mask_modify_original self.mask_modify_original = mask_modify_original
self.tile_mode = tile_mode self.tile_mode = tile_mode
@property @property
def prompt_text(self): def prompt_text(self):
if len(self.prompts) == 1: if len(self.prompts) == 1:
@ -186,7 +185,7 @@ class ImagineResult:
prompt: ImaginePrompt, prompt: ImaginePrompt,
is_nsfw, is_nsfw,
upscaled_img=None, upscaled_img=None,
modified_original_img=None, modified_original=None,
mask_binary=None, mask_binary=None,
mask_grayscale=None, mask_grayscale=None,
): ):
@ -197,8 +196,8 @@ class ImagineResult:
if upscaled_img: if upscaled_img:
self.images["upscaled"] = upscaled_img self.images["upscaled"] = upscaled_img
if modified_original_img: if modified_original:
self.images["modified_original"] = modified_original_img self.images["modified_original"] = modified_original
if mask_binary: if mask_binary:
self.images["mask_binary"] = mask_binary self.images["mask_binary"] = mask_binary

View File

@ -40,8 +40,15 @@ def test_clip_masking():
"*1", "*1",
"*10", "*10",
]: ]:
pred = get_img_mask(img, f"(head OR face){{{mask_modifier}}}") pred_bin, pred_grayscale = get_img_mask(
pred.save(f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}.png") img, f"(head OR face){{{mask_modifier}}}", threshold=0.1
)
pred_grayscale.save(
f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}_g.png"
)
pred_bin.save(
f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}_bin.png"
)
prompt = ImaginePrompt( prompt = ImaginePrompt(
"professional photo of a woman", "professional photo of a woman",
@ -57,8 +64,9 @@ def test_clip_masking():
) )
result = next(imagine(prompt)) result = next(imagine(prompt))
result.modified_original_img.save( result.save(
f"{TESTS_FOLDER}/test_output/earring_mask_photo.png" f"{TESTS_FOLDER}/test_output/earring_mask_photo.png",
image_type="modified_original",
) )