feature: 🎉 outpainting

pull/167/head
Bryce 1 year ago committed by Bryce Drennan
parent 993b039d7b
commit 81f294216b

@ -131,6 +131,17 @@ Use depth maps for amazing "translations" of existing images.
<img src="assets/pearl_depth_2.jpg" height="512">
<img src="assets/pearl_depth_3.jpg" height="512">
### Outpainting
Given a starting image, one can generate its "surroundings".
Example:
`imagine --init-image pearl-earring.jpg --init-image-strength 0 --outpaint all250,up0,down600 "woman standing"`
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/tests/data/girl_with_a_pearl_earring.jpg" height="256"> ➡️
<img src="https://raw.githubusercontent.com/brycedrennan/imaginAIry/master/tests/expected_output/test_outpainting_outpaint_.png" height="256">
### Prompt Expansion
You can use `{}` to randomly pull values from lists. A list of values separated by `|`
and enclosed in `{ }` will be randomly drawn from in a non-repeating fashion. Values that are surrounded by `_ _` will
@ -241,6 +252,9 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
**7.5.0**
- feature: 🎉 outpainting. Examples: `--outpaint up10,down300,left50,right50` or `--outpaint all100` or `--outpaint u100,d200,l300,r400`
**7.4.3**
- fix: handle old pytorch lightning imports with a graceful failure (fixes #161)
- fix: handle failed image generations better (fixes #83)

@ -3,7 +3,6 @@ import os
import re
import numpy as np
import PIL
import torch
import torch.nn
from einops import rearrange, repeat
@ -23,6 +22,7 @@ from imaginairy.log_utils import (
)
from imaginairy.model_manager import get_diffusion_model
from imaginairy.modules.midas.utils import AddMiDaS
from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
from imaginairy.safety import SafetyMode, create_safety_score
from imaginairy.samplers import SAMPLER_LOOKUP
from imaginairy.samplers.base import NoiseSchedule, noise_an_image
@ -136,7 +136,9 @@ def imagine(
weights_location=prompt.model,
config_path=prompt.model_config_path,
half_mode=half_mode,
for_inpainting=prompt.mask_image or prompt.mask_prompt,
for_inpainting=prompt.mask_image
or prompt.mask_prompt
or prompt.outpaint,
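# outpainting runs through the inpainting model, so any outpaint request loads the inpainting weights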
)
has_depth_channel = hasattr(model, "depth_stage_key")
with ImageLoggingContext(
@ -174,28 +176,31 @@ def imagine(
sampler = SamplerCls(model)
mask = mask_image = mask_image_orig = mask_grayscale = None
t_enc = init_latent = init_latent_noised = None
starting_image = None
if prompt.init_image:
starting_image = prompt.init_image
generation_strength = 1 - prompt.init_image_strength
t_enc = int(prompt.steps * generation_strength)
try:
init_image = pillow_fit_image_within(
prompt.init_image,
max_height=prompt.height,
max_width=prompt.width,
)
except PIL.UnidentifiedImageError:
logger.warning(f"Could not load image: {prompt.init_image}")
continue
init_image_t = pillow_img_to_torch_image(init_image)
if prompt.mask_prompt:
mask_image, mask_grayscale = get_img_mask(
init_image, prompt.mask_prompt, threshold=0.1
starting_image, prompt.mask_prompt, threshold=0.1
)
elif prompt.mask_image:
mask_image = prompt.mask_image.convert("L")
if prompt.outpaint:
outpaint_kwargs = outpaint_arg_str_parse(prompt.outpaint)
starting_image, mask_image = prepare_image_for_outpaint(
starting_image, mask_image, **outpaint_kwargs
)
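# the returned mask is white over the new border, so only the expanded area gets painted in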
init_image = pillow_fit_image_within(
starting_image,
max_height=prompt.height,
max_width=prompt.width,
)
if mask_image is not None:
mask_image = pillow_fit_image_within(
mask_image,
max_height=prompt.height,
@ -203,7 +208,6 @@ def imagine(
convert="L",
)
if mask_image is not None:
log_img(mask_image, "init mask")
if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
@ -228,7 +232,7 @@ def imagine(
mask = mask[None, None]
mask = torch.from_numpy(mask)
mask = mask.to(get_device())
init_image_t = pillow_img_to_torch_image(init_image)
init_image_t = init_image_t.to(get_device())
init_latent = model.get_first_stage_encoding(
model.encode_first_stage(init_image_t)
@ -266,9 +270,9 @@ def imagine(
}
c_cat = []
depth_image_display = None
if has_depth_channel and prompt.init_image:
if has_depth_channel and starting_image:
midas_model = AddMiDaS()
_init_image_d = np.array(prompt.init_image.convert("RGB"))
_init_image_d = np.array(starting_image.convert("RGB"))
_init_image_d = (
torch.from_numpy(_init_image_d).to(dtype=torch.float32) / 127.5
- 1.0
@ -414,20 +418,20 @@ def imagine(
if (
prompt.mask_modify_original
and mask_image_orig
and prompt.init_image
and starting_image
):
img_to_add_back_to_original = (
upscaled_img if upscaled_img else img
)
img_to_add_back_to_original = (
img_to_add_back_to_original.resize(
prompt.init_image.size,
starting_image.size,
resample=Image.Resampling.LANCZOS,
)
)
mask_for_orig_size = mask_image_orig.resize(
prompt.init_image.size,
starting_image.size,
resample=Image.Resampling.LANCZOS,
)
mask_for_orig_size = mask_for_orig_size.filter(
@ -436,7 +440,7 @@ def imagine(
log_img(mask_for_orig_size, "mask for original image size")
rebuilt_orig_img = Image.composite(
prompt.init_image,
starting_image,
img_to_add_back_to_original,
mask_for_orig_size,
)

@ -171,6 +171,17 @@ logger = logging.getLogger(__name__)
is_flag=True,
help="After the inpainting is done, apply the changes to a copy of the original image.",
)
@click.option(
"--outpaint",
help=(
"Specify in what directions to expand the image. Values will be snapped such that output image size is multiples of 64. Examples\n"
" `--outpaint up10,down300,left50,right50`\n"
" `--outpaint u10,d300,l50,r50`\n"
" `--outpaint all200`\n"
" `--outpaint a200`\n"
),
default="",
)
@click.option(
"--caption",
default=False,
@ -232,6 +243,7 @@ def imagine_cmd(
mask_prompt,
mask_mode,
mask_modify_original,
outpaint,
caption,
precision,
model_weights_path,
@ -292,6 +304,7 @@ def imagine_cmd(
mask_prompt=mask_prompt,
mask_mode=mask_mode,
mask_modify_original=mask_modify_original,
outpaint=outpaint,
upscale=upscale,
fix_faces=fix_faces,
fix_faces_fidelity=fix_faces_fidelity,
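The `--outpaint` help above promises output sizes that snap to multiples of 64. A quick sketch of that arithmetic, mirroring the rounding done in `prepare_image_for_outpaint` below (`snapped_size` is a hypothetical helper for illustration, not part of the package):

```python
def snapped_size(width, height, up, down, left, right, snap_multiple=64):
    # round the requested canvas size to the nearest multiple of snap_multiple
    new_width = round((width + left + right) / snap_multiple) * snap_multiple
    new_height = round((height + up + down) / snap_multiple) * snap_multiple
    return new_width, new_height

# `--outpaint all250,up0,down600` on a 512x512 image:
# width 512+250+250=1012 snaps to 1024; height 512+0+600=1112 snaps to 1088
print(snapped_size(512, 512, up=0, down=600, left=250, right=250))  # (1024, 1088)
```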

@ -259,6 +259,14 @@ def disable_transformers_custom_logging():
def disable_pytorch_lighting_custom_logging():
try:
from pytorch_lightning.utilities.seed import log # noqa
log.setLevel(logging.NOTSET)
log.handlers = []
log.propagate = False
except ImportError:
pass
pytorch_logger.setLevel(logging.NOTSET)

@ -0,0 +1,118 @@
import re
from PIL import Image, ImageDraw
def prepare_image_for_outpaint(
img, mask=None, up=None, down=None, left=None, right=None, _all=0, snap_multiple=64
):
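"""Expand img in the requested directions; returns the expanded image and a mask that is white over the new border."""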
up = up if up is not None else _all
down = down if down is not None else _all
left = left if left is not None else _all
right = right if right is not None else _all
# guard against zero totals (e.g. up=0 and down=0) so we don't divide by zero
lft_pct = left / (left + right) if (left + right) else 0
rgt_pct = right / (left + right) if (left + right) else 0
up_pct = up / (up + down) if (up + down) else 0
dwn_pct = down / (up + down) if (up + down) else 0
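# snap the requested canvas size to the nearest multiple of snap_multiple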
new_width = round((img.width + left + right) / snap_multiple) * snap_multiple
new_height = round((img.height + up + down) / snap_multiple) * snap_multiple
height_addition = max(new_height - img.height, 0)
width_addition = max(new_width - img.width, 0)
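# split the snapped additions between opposite sides in proportion to the requested amounts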
up = int(round(height_addition * up_pct))
down = int(round(height_addition * dwn_pct))
left = int(round(width_addition * lft_pct))
right = int(round(width_addition * rgt_pct))
expanded_image = Image.new(
"RGB", (img.width + left + right, img.height + up + down), (0, 0, 0)
)
expanded_image.paste(img, (left, up))
# extend the border pixels outward; this helps prevent visible lines at the boundary, since
# masks get downscaled to the 64x64 latent space, which can cause some inaccuracies
if up > 0:
expanded_image.paste(
img.crop((0, 0, img.width, 1)).resize((expanded_image.width, up)),
(0, 0),
)
expanded_image.paste(
img.crop((0, 0, img.width, 1)).resize((img.width, up)),
(left, 0),
)
if down > 0:
expanded_image.paste(
img.crop((0, img.height - 1, img.width, img.height)).resize(
(expanded_image.width, down)
),
(0, expanded_image.height - down),
)
expanded_image.paste(
img.crop((0, img.height - 1, img.width, img.height)).resize(
(img.width, down)
),
(left, expanded_image.height - down),
)
if left > 0:
expanded_image.paste(
img.crop((0, 0, 1, img.height)).resize((left, expanded_image.height)),
(0, 0),
)
expanded_image.paste(
img.crop((0, 0, 1, img.height)).resize((left, img.height)),
(0, up),
)
if right > 0:
expanded_image.paste(
img.crop((img.width - 1, 0, img.width, img.height)).resize(
(right, expanded_image.height)
),
(expanded_image.width - right, 0),
)
expanded_image.paste(
img.crop((img.width - 1, 0, img.width, img.height)).resize(
(right, img.height)
),
(expanded_image.width - right, up),
)
# create a mask for the new boundaries
expanded_mask = Image.new("L", (expanded_image.width, expanded_image.height), 255)
if mask is None:
draw = ImageDraw.Draw(expanded_mask)
draw.rectangle(
(left, up, left + img.width, up + img.height), fill="black", outline="black"
)
else:
expanded_mask.paste(mask, (left, up))
return expanded_image, expanded_mask
def outpaint_arg_str_parse(arg_str):
arg_pattern = re.compile(r"([A-Z]+)(\d+)")
args = arg_str.upper().split(",")
valid_directions = ["up", "down", "left", "right", "all"]
valid_direction_chars = {c[0]: c for c in valid_directions}
kwargs = {}
for arg in args:
match = arg_pattern.match(arg)
if not match:
raise ValueError(f"Invalid outpaint argument '{arg}'")
direction, amount = match.groups()
direction = direction.lower()
if len(direction) == 1:
if direction not in valid_direction_chars:
raise ValueError(f"Invalid outpaint direction '{direction}'")
direction = valid_direction_chars[direction]
elif direction not in valid_directions:
raise ValueError(f"Invalid outpaint direction '{direction}'")
kwargs[direction] = int(amount)
if "all" in kwargs:
kwargs["_all"] = kwargs.pop("all")
return kwargs
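End to end, the parser output feeds straight into the canvas expansion. A small sketch using only the two functions above (the blank PIL image is a stand-in for a real init image):

```python
from PIL import Image

from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint

kwargs = outpaint_arg_str_parse("all250,up0,down600")
print(kwargs)  # {'up': 0, 'down': 600, '_all': 250}

img = Image.new("RGB", (512, 512), "white")  # stand-in init image
expanded, mask = prepare_image_for_outpaint(img, **kwargs)
print(expanded.size, mask.size)  # (1024, 1088) (1024, 1088)
```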

@ -102,6 +102,7 @@ class ImaginePrompt:
mask_image=None,
mask_mode=MaskMode.REPLACE,
mask_modify_original=True,
outpaint=None,
seed=None,
steps=None,
height=None,
@ -155,6 +156,7 @@ class ImaginePrompt:
self.mask_image = mask_image
self.mask_mode = mask_mode
self.mask_modify_original = mask_modify_original
self.outpaint = outpaint
self.tile_mode = tile_mode
self.model = model

(Binary image files not shown: seven modified, one added.)

@ -0,0 +1,40 @@
import pytest
from imaginairy import ImaginePrompt, LazyLoadingImage, imagine
from imaginairy.outpaint import outpaint_arg_str_parse
from tests import TESTS_FOLDER
from tests.utils import assert_image_similar_to_expectation
def test_outpainting_outpaint(filename_base_for_outputs):
img = LazyLoadingImage(
filepath=f"{TESTS_FOLDER}/data/girl_with_a_pearl_earring.jpg"
)
prompt = ImaginePrompt(
prompt="woman standing",
init_image=img,
init_image_strength=0,
mask_prompt="background",
outpaint="all250,up0,down600",
mask_mode="replace",
negative_prompt="picture frame, borders, framing, text, writing, watermarks, indoors, advertisement, paper, canvas, stock photo",
steps=20,
seed=542906833,
)
result = list(imagine([prompt]))[0]
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=2800)
outpaint_test_params = [
("A132", {"_all": 132}),
("A132,U50", {"_all": 132, "up": 50}),
("A132,U50,D50", {"_all": 132, "up": 50, "down": 50}),
("a132,u50,d50", {"_all": 132, "up": 50, "down": 50}),
("all50,up20,down600", {"_all": 50, "up": 20, "down": 600}),
]
@pytest.mark.parametrize("arg_str, expected_kwargs", outpaint_test_params)
def test_outpaint_parse_kwargs(arg_str, expected_kwargs):
assert outpaint_arg_str_parse(arg_str) == expected_kwargs