mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-11-05 12:00:15 +00:00
fix: inpainting now matches photo at high generation strength

- 🎉 fix: inpainted areas correlate with the surrounding image, even at 100% generation strength. Previously, if the generation strength was high enough, the generated area would be uncorrelated with the rest of the surrounding image, which produced terrible-looking results.
- fix: mask boundaries are more accurate
This commit is contained in:
parent d563e0c7fb
commit 6cae290038
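The essence of the fix is in the DDIM and PLMS sampler hunks below: during masked img2img decoding, the re-noised original latent is blended with a strong "hint" of the un-noised original latent before being mixed into the kept region. A minimal illustrative sketch of that per-step update, reusing the variable names from the diff (a paraphrase for readability, not the project's API):

```python
def masked_update_with_hints(x_dec, xdec_orig, orig_latent, mask, hint_strength=0.8):
    """Sketch of the per-step blend this commit introduces.

    x_dec:       latent currently being denoised
    xdec_orig:   original latent re-noised to the current timestep (model.q_sample)
    orig_latent: un-noised latent of the init image
    mask:        1.0 where original content is kept, 0.0 where new content is generated
    """
    # mixing in the clean original latent keeps the kept region correlated with the
    # photo even when nearly every denoising step runs (high generation strength)
    xdec_orig_with_hints = xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
    return xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
```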
@@ -185,11 +185,14 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
 ## ChangeLog
 
 **2.0.0**
-- feature: interactive prompt added. access by running `aimg`
-- feature: Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase.
+- 🎉 fix: inpainted areas correlate with surrounding image, even at 100% generation strength. Previously if the generation strength was high enough the generated image
+would be uncorrelated to the rest of the surrounding image. It created terrible looking images.
+- 🎉 feature: interactive prompt added. access by running `aimg`
+- 🎉 feature: Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase.
 Valid symbols: `AND`, `OR`, `NOT`, `()`, and mask strength modifier `{+0.1}` where `+` can be any of `+ - * /`. Single character boolean operators also work (`|`, `&`, `!`)
-- feature: apply mask edits to original files with `mask_modify_original` (on by default)
+- 🎉 feature: apply mask edits to original files with `mask_modify_original` (on by default)
 - feature: auto-rotate images if exif data specifies to do so
+- fix: mask boundaries are more accurate
 - fix: accept mask images in command line
 - fix: img2img algorithm was wrong and wouldn't at values close to 0 or 1
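For readers unfamiliar with the mask syntax listed in the changelog, a hypothetical usage sketch follows. The prompt text, image path, and import path are illustrative; the keyword names mirror the `ImaginePrompt` attributes and the `next(imagine(prompt))` call that appear elsewhere in this diff.

```python
from imaginairy import ImaginePrompt, imagine  # import path assumed, not shown in this diff

prompt = ImaginePrompt(
    "professional photo of a woman with an ornate gold necklace",
    init_image="portrait.jpg",       # assumed: path or PIL image of the photo to edit
    init_image_strength=0.05,        # low value = high generation strength
    # lowercase descriptions, uppercase keywords; {*5} multiplies the mask strength
    mask_prompt="(head OR face){*5} AND NOT hat",
    mask_modify_original=True,       # write the edit back into the full-size original (on by default)
)
result = next(imagine(prompt))
```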
@@ -234,7 +237,6 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
 
 ## Todo
 
-- refactor how output versions are selected (upscaled, modified original, etc)
 - performance optimizations
 - ✅ https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx
 - ✅ https://github.com/CompVis/stable-diffusion/compare/main...Doggettx:stable-diffusion:autocast-improvements#
@@ -122,13 +122,13 @@ def imagine_image_files(
     prompt = result.prompt
     basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}"
 
-    for image_type, img in result.images.items():
+    for image_type in result.images:
         subpath = os.path.join(outdir, image_type)
         os.makedirs(subpath, exist_ok=True)
         filepath = os.path.join(
             subpath, f"{basefilename}_[{image_type}].{output_file_extension}"
         )
-        result.save(filepath)
+        result.save(filepath, image_type=image_type)
         logger.info(f" 🖼 [{image_type}] saved to: {filepath}")
         base_count += 1
     del result
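`imagine_image_files` now iterates over the keys of `result.images` and lets the result object pick the matching variant. The `ImagineResult.save` body is not part of this diff; a plausible sketch based only on the `self.images` dict shown in the schema hunks further down (the default key name is a guess):

```python
class ImagineResult:
    """Abbreviated sketch; only the save behavior implied by this diff."""

    def __init__(self):
        # holds variants such as "upscaled", "modified_original", "mask_binary"
        self.images = {}

    def save(self, save_path, image_type="generated"):  # default key name is a guess
        # look up the requested variant and write it out as an ordinary PIL image
        self.images[image_type].save(save_path)
```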
@@ -198,7 +198,12 @@ def imagine(
         sampler_type = prompt.sampler_type
 
         sampler = get_sampler(sampler_type, model)
-        mask, mask_image, mask_image_orig = None, None, None
+        mask, mask_image, mask_image_orig, mask_grayscale = (
+            None,
+            None,
+            None,
+            None,
+        )
         if prompt.init_image:
             generation_strength = 1 - prompt.init_image_strength
             t_enc = int(prompt.steps * generation_strength)
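A quick worked example of the strength-to-steps mapping in the hunk above (numbers chosen purely for illustration):

```python
steps = 50
init_image_strength = 0.2                      # how much of the original image to keep
generation_strength = 1 - init_image_strength  # 0.8
t_enc = int(steps * generation_strength)       # 40 of the 50 steps re-generate content
```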
@@ -218,7 +223,7 @@ def imagine(
             init_image_t = pillow_img_to_torch_image(init_image)
 
             if prompt.mask_prompt:
-                mask_image = get_img_mask(
+                mask_image, mask_grayscale = get_img_mask(
                     init_image, prompt.mask_prompt, threshold=0.1
                 )
             elif prompt.mask_image:
@@ -239,7 +244,7 @@ def imagine(
                         mask_image.width // downsampling_factor,
                         mask_image.height // downsampling_factor,
                     ),
-                    resample=Image.Resampling.NEAREST,
+                    resample=Image.Resampling.LANCZOS,
                 )
                 log_img(mask_image, "latent_mask")
@@ -256,9 +261,11 @@ def imagine(
 
             log_latent(init_latent, "init_latent")
             # encode (scaled latent)
+            noise = torch.randn_like(init_latent, device="cpu").to(get_device())
             z_enc = sampler.stochastic_encode(
                 init_latent,
                 torch.tensor([t_enc - 1]).to(get_device()),
+                noise=noise,
             )
             log_latent(z_enc, "z_enc")
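The new `noise` tensor is created once and passed both to `stochastic_encode` here and, via the PLMS changes below, to `model.q_sample`, so the encoded init latent and the per-step re-noised original are built from the same noise rather than fresh random draws. Both calls amount to the standard diffusion forward process; a hedged sketch, with buffer names that are assumptions rather than imaginAIry's actual attributes:

```python
def forward_noise_sketch(x_start, t, noise, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
    # q(x_t | x_0): noise a clean latent x_start to timestep t using a supplied noise tensor
    return sqrt_alphas_cumprod[t] * x_start + sqrt_one_minus_alphas_cumprod[t] * noise

# reusing one noise tensor means the kept region is re-noised consistently at every
# step instead of jumping around with a new random draw each time
```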
@@ -297,12 +304,12 @@ def imagine(
             x_sample_8_orig = x_sample.astype(np.uint8)
             img = Image.fromarray(x_sample_8_orig)
             if mask_image_orig and init_image:
-                mask_image_orig = mask_image_orig.filter(
+                mask_final = mask_image_orig.filter(
                     ImageFilter.GaussianBlur(radius=3)
                 )
-                log_img(mask_image_orig, "reconstituting mask")
-                mask_image_orig = ImageOps.invert(mask_image_orig)
-                img = Image.composite(img, init_image, mask_image_orig)
+                log_img(mask_final, "reconstituting mask")
+                mask_final = ImageOps.invert(mask_final)
+                img = Image.composite(img, init_image, mask_final)
                 log_img(img, "reconstituted image")
 
             upscaled_img = None
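The `ImageOps.invert` above exists because of how `PIL.Image.composite(image1, image2, mask)` works: pixels come from `image1` where the mask is white (255) and from `image2` where it is black, so flipping the mask flips which image fills which region. A tiny self-contained illustration:

```python
from PIL import Image, ImageOps

generated = Image.new("RGB", (64, 64), "red")   # stand-in for the freshly generated image
original = Image.new("RGB", (64, 64), "blue")   # stand-in for the init image
mask = Image.new("L", (64, 64), 0)
mask.paste(255, (16, 16, 48, 48))               # white square

inside = Image.composite(generated, original, mask)                   # generated pixels inside the square
outside = Image.composite(generated, original, ImageOps.invert(mask)) # generated pixels outside the square
```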
@@ -328,7 +335,11 @@ def imagine(
                 upscaled_img = enhance_faces(upscaled_img, fidelity=0.8)
 
             # put the newly generated patch back into the original, full size image
-            if prompt.mask_modify_original and mask_image_orig and prompt.init_image:
+            if (
+                prompt.mask_modify_original
+                and mask_image_orig
+                and prompt.init_image
+            ):
                 img_to_add_back_to_original = (
                     upscaled_img if upscaled_img else img
                 )
@@ -348,8 +359,8 @@ def imagine(
                 log_img(mask_for_orig_size, "mask for original image size")
 
                 rebuilt_orig_img = Image.composite(
-                    img_to_add_back_to_original,
                     prompt.init_image,
+                    img_to_add_back_to_original,
                     mask_for_orig_size,
                 )
                 log_img(rebuilt_orig_img, "reconstituted original")
@@ -359,7 +370,9 @@ def imagine(
                 prompt=prompt,
                 upscaled_img=upscaled_img,
                 is_nsfw=is_nsfw_img,
-                modified_original_img=rebuilt_orig_img,
+                modified_original=rebuilt_orig_img,
+                mask_binary=mask_image_orig,
+                mask_grayscale=mask_grayscale,
             )
@@ -1,8 +1,11 @@
 from functools import lru_cache
 from typing import Optional, Sequence
 
+import cv2
+import numpy as np
 import PIL.Image
 import torch
+from kornia.filters import median_blur
 from torchvision import transforms
 
 from imaginairy.img_log import log_img
@@ -41,10 +44,30 @@ def get_img_mask(
     mask_cache = get_img_masks(img, descriptions)
     mask = parsed_mask.apply_masks(mask_cache)
     log_img(mask, "combined mask")
+
+    # try to blur the square shaped artifacts somewhat
+    mask = median_blur(mask.unsqueeze(dim=0).unsqueeze(dim=0), (11, 11)).squeeze()
+    log_img(mask, "median blurred")
+
+    kernel = np.ones((5, 5), np.uint8)
+    mask_g = mask.clone()
+
+    # trial and error shows 0.5 threshold has the best "shape"
     if threshold is not None:
-        mask[mask < threshold] = 0
-        mask[mask >= threshold] = 1
-    return transforms.ToPILImage()(mask)
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+        log_img(mask, f"mask threshold {0.5}")
+
+    mask_np = mask.cpu().numpy()
+    smoother_strength = 5
+    # grow the mask area to make sure we've masked the thing we care about
+    for _ in range(smoother_strength):
+        mask_np = cv2.dilate(mask_np, kernel)
+    # todo: add an outer blur (not gaussian)
+    mask = torch.from_numpy(mask_np)
+    log_img(mask, "mask after closing (dilation then erosion)")
+
+    return transforms.ToPILImage()(mask), transforms.ToPILImage()(mask_g)
 
 
 def get_img_masks(img, mask_descriptions: Sequence[str]):
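Per the test changes at the bottom of this diff, `get_img_mask` now returns a pair of PIL images: the thresholded-and-dilated binary mask and the pre-threshold grayscale mask. Note that when a threshold is supplied, the new code cuts at a fixed 0.5 rather than the passed value. A small usage sketch (the module path and the image loading are assumptions, not shown in this diff):

```python
from PIL import Image

from imaginairy.enhancers.clip_masking import get_img_mask  # module path is an assumption

img = Image.open("portrait.jpg")  # hypothetical input photo
# binary mask: median-blurred, cut at 0.5, then dilated 5x so it fully covers the target;
# grayscale mask: the smoothed mask values before thresholding
mask_bin, mask_grayscale = get_img_mask(img, "(head OR face){*5}", threshold=0.1)
mask_bin.save("face_mask_bin.png")
mask_grayscale.save("face_mask_g.png")
```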
@@ -66,7 +89,6 @@ def get_img_masks(img, mask_descriptions: Sequence[str]):
         img.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions
     )[0]
     preds = transforms.Resize(orig_size)(preds)
-    preds = transforms.GaussianBlur(kernel_size=9)(preds)
 
     preds = [torch.sigmoid(p[0]) for p in preds]
@@ -359,8 +359,12 @@ class DDIMSampler:
                 assert orig_latent is not None
                 xdec_orig = self.model.q_sample(orig_latent, ts)
                 log_latent(xdec_orig, "xdec_orig")
-                log_latent(xdec_orig * mask, "masked_xdec_orig")
-                x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
+                # this helps prevent the weird disjointed images that can happen with masking
+                hint_strength = 0.8
+                xdec_orig_with_hints = (
+                    xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
+                )
+                x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
                 log_latent(x_dec, "x_dec")
 
             x_dec, pred_x0 = self.p_sample_ddim(
@@ -388,6 +388,7 @@ class PLMSSampler:
         temperature=1.0,
         mask=None,
         orig_latent=None,
+        noise=None,
     ):
 
         timesteps = self.ddim_timesteps[:t_start]
@@ -398,7 +399,15 @@ class PLMSSampler:
         iterator = tqdm(time_range, desc="PLMS altering image", total=total_steps)
         x_dec = x_latent
         old_eps = []
+        log_latent(x_dec, "x_dec")
 
+        # not sure what the downside of using the same noise throughout the process would be...
+        # seems to work fine. maybe it runs faster?
+        noise = (
+            torch.randn_like(x_dec, device="cpu").to(x_dec.device)
+            if noise is None
+            else noise
+        )
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
             ts = torch.full(
@@ -413,10 +422,14 @@ class PLMSSampler:
 
             if mask is not None:
                 assert orig_latent is not None
-                xdec_orig = self.model.q_sample(orig_latent, ts)
+                xdec_orig = self.model.q_sample(orig_latent, ts, noise)
                 log_latent(xdec_orig, "xdec_orig")
-                log_latent(xdec_orig * mask, "masked_xdec_orig")
-                x_dec = xdec_orig * mask + (1.0 - mask) * x_dec
+                # this helps prevent the weird disjointed images that can happen with masking
+                hint_strength = 0.8
+                xdec_orig_with_hints = (
+                    xdec_orig * (1 - hint_strength) + orig_latent * hint_strength
+                )
+                x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec
                 log_latent(x_dec, "x_dec")
 
             x_dec, pred_x0, e_t = self.p_sample_plms(
@@ -446,5 +459,6 @@ class PLMSSampler:
                 img_callback(pred_x0, "pred_x0")
 
             log_latent(x_dec, f"x_dec {i}")
+            log_latent(x_dec, f"e_t {i}")
             log_latent(pred_x0, f"pred_x0 {i}")
         return x_dec
@@ -138,7 +138,6 @@ class ImaginePrompt:
         self.mask_modify_original = mask_modify_original
         self.tile_mode = tile_mode
 
-
     @property
     def prompt_text(self):
         if len(self.prompts) == 1:
@@ -186,7 +185,7 @@ class ImagineResult:
         prompt: ImaginePrompt,
         is_nsfw,
         upscaled_img=None,
-        modified_original_img=None,
+        modified_original=None,
         mask_binary=None,
         mask_grayscale=None,
     ):
@@ -197,8 +196,8 @@ class ImagineResult:
         if upscaled_img:
             self.images["upscaled"] = upscaled_img
 
-        if modified_original_img:
-            self.images["modified_original"] = modified_original_img
+        if modified_original:
+            self.images["modified_original"] = modified_original
 
         if mask_binary:
             self.images["mask_binary"] = mask_binary
@@ -40,8 +40,15 @@ def test_clip_masking():
         "*1",
         "*10",
     ]:
-        pred = get_img_mask(img, f"(head OR face){{{mask_modifier}}}")
-        pred.save(f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}.png")
+        pred_bin, pred_grayscale = get_img_mask(
+            img, f"(head OR face){{{mask_modifier}}}", threshold=0.1
+        )
+        pred_grayscale.save(
+            f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}_g.png"
+        )
+        pred_bin.save(
+            f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}_bin.png"
+        )
 
     prompt = ImaginePrompt(
         "professional photo of a woman",
@@ -57,8 +64,9 @@ def test_clip_masking():
     )
 
     result = next(imagine(prompt))
-    result.modified_original_img.save(
-        f"{TESTS_FOLDER}/test_output/earring_mask_photo.png"
+    result.save(
+        f"{TESTS_FOLDER}/test_output/earring_mask_photo.png",
+        image_type="modified_original",
     )