diff --git a/README.md b/README.md index d3b84e7..3b7b065 100644 --- a/README.md +++ b/README.md @@ -185,11 +185,14 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface - ## ChangeLog **2.0.0** - - feature: interactive prompt added. access by running `aimg` - - feature: Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase. + - 🎉 fix: inpainted areas correlate with surrounding image, even at 100% generation strength. Previously if the generation strength was high enough the generated image +would be uncorrelated to the rest of the surrounding image. It created terrible looking images. + - 🎉 feature: interactive prompt added. access by running `aimg` + - 🎉 feature: Specify advanced text based masks using boolean logic and strength modifiers. Mask descriptions must be lowercase. Keywords uppercase. Valid symbols: `AND`, `OR`, `NOT`, `()`, and mask strength modifier `{+0.1}` where `+` can be any of `+ - * /`. Single character boolean operators also work (`|`, `&`, `!`) - - feature: apply mask edits to original files with `mask_modify_original` (on by default) + - 🎉 feature: apply mask edits to original files with `mask_modify_original` (on by default) - feature: auto-rotate images if exif data specifies to do so + - fix: mask boundaries are more accurate - fix: accept mask images in command line - fix: img2img algorithm was wrong and wouldn't at values close to 0 or 1 @@ -234,7 +237,6 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface - ## Todo - - refactor how output versions are selected (upscaled, modified original, etc) - performance optimizations - ✅ https://github.com/huggingface/diffusers/blob/main/docs/source/optimization/fp16.mdx - ✅ https://github.com/CompVis/stable-diffusion/compare/main...Doggettx:stable-diffusion:autocast-improvements# diff --git a/imaginairy/api.py b/imaginairy/api.py index bc766ed..33d630a 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -122,13 +122,13 @@ def imagine_image_files( prompt = result.prompt basefilename = f"{base_count:06}_{prompt.seed}_{prompt.sampler_type}{prompt.steps}_PS{prompt.prompt_strength}_{prompt_normalized(prompt.prompt_text)}" - for image_type, img in result.images.items(): + for image_type in result.images: subpath = os.path.join(outdir, image_type) os.makedirs(subpath, exist_ok=True) filepath = os.path.join( subpath, f"{basefilename}_[{image_type}].{output_file_extension}" ) - result.save(filepath) + result.save(filepath, image_type=image_type) logger.info(f" 🖼 [{image_type}] saved to: {filepath}") base_count += 1 del result @@ -198,7 +198,12 @@ def imagine( sampler_type = prompt.sampler_type sampler = get_sampler(sampler_type, model) - mask, mask_image, mask_image_orig = None, None, None + mask, mask_image, mask_image_orig, mask_grayscale = ( + None, + None, + None, + None, + ) if prompt.init_image: generation_strength = 1 - prompt.init_image_strength t_enc = int(prompt.steps * generation_strength) @@ -218,7 +223,7 @@ def imagine( init_image_t = pillow_img_to_torch_image(init_image) if prompt.mask_prompt: - mask_image = get_img_mask( + mask_image, mask_grayscale = get_img_mask( init_image, prompt.mask_prompt, threshold=0.1 ) elif prompt.mask_image: @@ -239,7 +244,7 @@ def imagine( mask_image.width // downsampling_factor, mask_image.height // downsampling_factor, ), - resample=Image.Resampling.NEAREST, + resample=Image.Resampling.LANCZOS, ) log_img(mask_image, "latent_mask") @@ -256,9 +261,11 @@ def imagine( log_latent(init_latent, "init_latent") # encode (scaled latent) + noise = torch.randn_like(init_latent, device="cpu").to(get_device()) z_enc = sampler.stochastic_encode( init_latent, torch.tensor([t_enc - 1]).to(get_device()), + noise=noise, ) log_latent(z_enc, "z_enc") @@ -297,12 +304,12 @@ def imagine( x_sample_8_orig = x_sample.astype(np.uint8) img = Image.fromarray(x_sample_8_orig) if mask_image_orig and init_image: - mask_image_orig = mask_image_orig.filter( + mask_final = mask_image_orig.filter( ImageFilter.GaussianBlur(radius=3) ) - log_img(mask_image_orig, "reconstituting mask") - mask_image_orig = ImageOps.invert(mask_image_orig) - img = Image.composite(img, init_image, mask_image_orig) + log_img(mask_final, "reconstituting mask") + mask_final = ImageOps.invert(mask_final) + img = Image.composite(img, init_image, mask_final) log_img(img, "reconstituted image") upscaled_img = None @@ -328,7 +335,11 @@ def imagine( upscaled_img = enhance_faces(upscaled_img, fidelity=0.8) # put the newly generated patch back into the original, full size image - if prompt.mask_modify_original and mask_image_orig and prompt.init_image: + if ( + prompt.mask_modify_original + and mask_image_orig + and prompt.init_image + ): img_to_add_back_to_original = ( upscaled_img if upscaled_img else img ) @@ -348,8 +359,8 @@ def imagine( log_img(mask_for_orig_size, "mask for original image size") rebuilt_orig_img = Image.composite( - img_to_add_back_to_original, prompt.init_image, + img_to_add_back_to_original, mask_for_orig_size, ) log_img(rebuilt_orig_img, "reconstituted original") @@ -359,7 +370,9 @@ def imagine( prompt=prompt, upscaled_img=upscaled_img, is_nsfw=is_nsfw_img, - modified_original_img=rebuilt_orig_img, + modified_original=rebuilt_orig_img, + mask_binary=mask_image_orig, + mask_grayscale=mask_grayscale, ) diff --git a/imaginairy/enhancers/clip_masking.py b/imaginairy/enhancers/clip_masking.py index 06b40e6..42dc8b4 100644 --- a/imaginairy/enhancers/clip_masking.py +++ b/imaginairy/enhancers/clip_masking.py @@ -1,8 +1,11 @@ from functools import lru_cache from typing import Optional, Sequence +import cv2 +import numpy as np import PIL.Image import torch +from kornia.filters import median_blur from torchvision import transforms from imaginairy.img_log import log_img @@ -41,10 +44,30 @@ def get_img_mask( mask_cache = get_img_masks(img, descriptions) mask = parsed_mask.apply_masks(mask_cache) log_img(mask, "combined mask") + + # try to blur the square shaped artifacts somewhat + mask = median_blur(mask.unsqueeze(dim=0).unsqueeze(dim=0), (11, 11)).squeeze() + log_img(mask, "median blurred") + + kernel = np.ones((5, 5), np.uint8) + mask_g = mask.clone() + + # trial and error shows 0.5 threshold has the best "shape" if threshold is not None: - mask[mask < threshold] = 0 - mask[mask >= threshold] = 1 - return transforms.ToPILImage()(mask) + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + log_img(mask, f"mask threshold {0.5}") + + mask_np = mask.cpu().numpy() + smoother_strength = 5 + # grow the mask area to make sure we've masked the thing we care about + for _ in range(smoother_strength): + mask_np = cv2.dilate(mask_np, kernel) + # todo: add an outer blur (not gaussian) + mask = torch.from_numpy(mask_np) + log_img(mask, "mask after closing (dilation then erosion)") + + return transforms.ToPILImage()(mask), transforms.ToPILImage()(mask_g) def get_img_masks(img, mask_descriptions: Sequence[str]): @@ -66,7 +89,6 @@ def get_img_masks(img, mask_descriptions: Sequence[str]): img.repeat(len(mask_descriptions), 1, 1, 1), mask_descriptions )[0] preds = transforms.Resize(orig_size)(preds) - preds = transforms.GaussianBlur(kernel_size=9)(preds) preds = [torch.sigmoid(p[0]) for p in preds] diff --git a/imaginairy/samplers/ddim.py b/imaginairy/samplers/ddim.py index 084ddf1..9617a65 100644 --- a/imaginairy/samplers/ddim.py +++ b/imaginairy/samplers/ddim.py @@ -359,8 +359,12 @@ class DDIMSampler: assert orig_latent is not None xdec_orig = self.model.q_sample(orig_latent, ts) log_latent(xdec_orig, "xdec_orig") - log_latent(xdec_orig * mask, "masked_xdec_orig") - x_dec = xdec_orig * mask + (1.0 - mask) * x_dec + # this helps prevent the weird disjointed images that can happen with masking + hint_strength = 0.8 + xdec_orig_with_hints = ( + xdec_orig * (1 - hint_strength) + orig_latent * hint_strength + ) + x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec log_latent(x_dec, "x_dec") x_dec, pred_x0 = self.p_sample_ddim( diff --git a/imaginairy/samplers/plms.py b/imaginairy/samplers/plms.py index e177b5c..493c9b1 100644 --- a/imaginairy/samplers/plms.py +++ b/imaginairy/samplers/plms.py @@ -388,6 +388,7 @@ class PLMSSampler: temperature=1.0, mask=None, orig_latent=None, + noise=None, ): timesteps = self.ddim_timesteps[:t_start] @@ -398,7 +399,15 @@ class PLMSSampler: iterator = tqdm(time_range, desc="PLMS altering image", total=total_steps) x_dec = x_latent old_eps = [] - + log_latent(x_dec, "x_dec") + + # not sure what the downside of using the same noise throughout the process would be... + # seems to work fine. maybe it runs faster? + noise = ( + torch.randn_like(x_dec, device="cpu").to(x_dec.device) + if noise is None + else noise + ) for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full( @@ -413,10 +422,14 @@ class PLMSSampler: if mask is not None: assert orig_latent is not None - xdec_orig = self.model.q_sample(orig_latent, ts) + xdec_orig = self.model.q_sample(orig_latent, ts, noise) log_latent(xdec_orig, "xdec_orig") - log_latent(xdec_orig * mask, "masked_xdec_orig") - x_dec = xdec_orig * mask + (1.0 - mask) * x_dec + # this helps prevent the weird disjointed images that can happen with masking + hint_strength = 0.8 + xdec_orig_with_hints = ( + xdec_orig * (1 - hint_strength) + orig_latent * hint_strength + ) + x_dec = xdec_orig_with_hints * mask + (1.0 - mask) * x_dec log_latent(x_dec, "x_dec") x_dec, pred_x0, e_t = self.p_sample_plms( @@ -446,5 +459,6 @@ class PLMSSampler: img_callback(pred_x0, "pred_x0") log_latent(x_dec, f"x_dec {i}") + log_latent(x_dec, f"e_t {i}") log_latent(pred_x0, f"pred_x0 {i}") return x_dec diff --git a/imaginairy/schema.py b/imaginairy/schema.py index 15be65a..1ed797b 100644 --- a/imaginairy/schema.py +++ b/imaginairy/schema.py @@ -138,7 +138,6 @@ class ImaginePrompt: self.mask_modify_original = mask_modify_original self.tile_mode = tile_mode - @property def prompt_text(self): if len(self.prompts) == 1: @@ -186,7 +185,7 @@ class ImagineResult: prompt: ImaginePrompt, is_nsfw, upscaled_img=None, - modified_original_img=None, + modified_original=None, mask_binary=None, mask_grayscale=None, ): @@ -197,8 +196,8 @@ class ImagineResult: if upscaled_img: self.images["upscaled"] = upscaled_img - if modified_original_img: - self.images["modified_original"] = modified_original_img + if modified_original: + self.images["modified_original"] = modified_original if mask_binary: self.images["mask_binary"] = mask_binary diff --git a/tests/test_enhancers.py b/tests/test_enhancers.py index a26b645..c108272 100644 --- a/tests/test_enhancers.py +++ b/tests/test_enhancers.py @@ -40,8 +40,15 @@ def test_clip_masking(): "*1", "*10", ]: - pred = get_img_mask(img, f"(head OR face){{{mask_modifier}}}") - pred.save(f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}.png") + pred_bin, pred_grayscale = get_img_mask( + img, f"(head OR face){{{mask_modifier}}}", threshold=0.1 + ) + pred_grayscale.save( + f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}_g.png" + ) + pred_bin.save( + f"{TESTS_FOLDER}/test_output/earring_mask_{mask_modifier}_bin.png" + ) prompt = ImaginePrompt( "professional photo of a woman", @@ -57,8 +64,9 @@ def test_clip_masking(): ) result = next(imagine(prompt)) - result.modified_original_img.save( - f"{TESTS_FOLDER}/test_output/earring_mask_photo.png" + result.save( + f"{TESTS_FOLDER}/test_output/earring_mask_photo.png", + image_type="modified_original", )