mirror of https://github.com/brycedrennan/imaginAIry, synced 2024-11-17 09:25:47 +00:00
2aef6089e0: Added a composition stage so large images are more coherent
233 lines · 7.5 KiB · Python

import re

import torch
from PIL import Image, ImageDraw
from torch import nn

from imaginairy.img_utils import torch_img_to_pillow_img

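# The helpers below prepare an image (as a torch tensor or a PIL image) for
# outpainting: they snap the expanded canvas dimensions to a multiple of
# snap_multiple, pre-fill the new margins with stretched border pixels, and
# build a matching mask covering the newly added area.
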
def outpaint_calculations(
    img_width,
    img_height,
    up=None,
    down=None,
    left=None,
    right=None,
    _all=0,
    snap_multiple=8,
):
    up = up if up is not None else _all
    down = down if down is not None else _all
    left = left if left is not None else _all
    right = right if right is not None else _all

    lft_pct = left / (left + right) if left + right else 0
    rgt_pct = right / (left + right) if left + right else 0
    up_pct = up / (up + down) if up + down else 0
    dwn_pct = down / (up + down) if up + down else 0

    new_width = round((img_width + left + right) / snap_multiple) * snap_multiple
    new_height = round((img_height + up + down) / snap_multiple) * snap_multiple
    height_addition = max(new_height - img_height, 0)
    width_addition = max(new_width - img_width, 0)
    up = int(round(height_addition * up_pct))
    down = int(round(height_addition * dwn_pct))
    left = int(round(width_addition * lft_pct))
    right = int(round(width_addition * rgt_pct))

    return up, down, left, right, new_width, new_height
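
# Worked example (illustrative only, not part of the original module): requesting
# right=30 on a 500px-wide, 512px-tall image with snap_multiple=8 snaps the new
# width to a multiple of 8, so the actual padding comes out slightly smaller:
#   new_width = round((500 + 0 + 30) / 8) * 8 = 528
#   width_addition = 528 - 500 = 28, all of which goes to the right side
#   -> outpaint_calculations(500, 512, right=30) == (0, 0, 0, 28, 528, 512)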


def prepare_tensor_for_outpaint(
    img, mask=None, up=None, down=None, left=None, right=None, _all=0, snap_multiple=8
):
    up, down, left, right, new_width, new_height = outpaint_calculations(
        img_width=img.shape[2],
        img_height=img.shape[1],
        up=up,
        down=down,
        left=left,
        right=right,
        _all=_all,
        snap_multiple=snap_multiple,
    )

    def resize(img_t, h, w):
        new_size = (img_t.shape[0], h, w)
        return nn.functional.interpolate(img_t, size=new_size, mode="nearest")

    def paste(dst, src, y, x):
        dst[:, y : y + src.shape[1], x : x + src.shape[2]] = src

    expanded_img = torch.zeros(
        img.shape[0], img.shape[1] + up + down, img.shape[2] + left + right
    )
    expanded_img[:, up : up + img.shape[1], left : left + img.shape[2]] = img

    # Extend the border pixels outward. This helps prevent visible lines at the
    # boundary, since masks getting reduced to the 64x64 latent space can cause
    # some inaccuracies.

    if up > 0:
        top_row = img[:, 0, :]
        paste(expanded_img, resize(top_row, h=up, w=expanded_img.shape[2]), y=0, x=0)
        paste(expanded_img, resize(top_row, h=up, w=img.shape[2]), y=0, x=left)

    if down > 0:
        bottom_row = img[:, -1, :]
        paste(
            expanded_img,
            resize(bottom_row, h=down, w=expanded_img.shape[2]),
            y=expanded_img.shape[1] - down,
            x=0,
        )
        paste(
            expanded_img,
            resize(bottom_row, h=down, w=img.shape[2]),
            y=expanded_img.shape[1] - down,
            x=left,
        )

    if left > 0:
        left_column = img[:, :, 0]
        paste(
            expanded_img, resize(left_column, h=expanded_img.shape[1], w=left), y=0, x=0
        )
        paste(expanded_img, resize(left_column, h=img.shape[1], w=left), y=up, x=0)

    if right > 0:
        right_column = img[:, :, -1]
        paste(
            expanded_img,
            resize(right_column, h=expanded_img.shape[1], w=right),
            y=0,
            x=expanded_img.shape[2] - right,
        )
        paste(
            expanded_img,
            resize(right_column, h=img.shape[1], w=right),
            y=up,
            x=expanded_img.shape[2] - right,
        )

    # create a mask for the new boundaries
    expanded_mask = torch.zeros_like(expanded_img)

    if mask is None:
        # mark the area covered by the original image (the region drawn black in
        # the PIL-based version below)
        expanded_mask[:, up : up + img.shape[1], left : left + img.shape[2]] = 1
    else:
        expanded_mask[:, up : up + mask.shape[1], left : left + mask.shape[2]] = mask

    return expanded_img, expanded_mask
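
# Usage sketch (assumed; variable names here are illustrative and not part of the
# original file). The function expects an image tensor laid out as
# (channels, height, width), per the img.shape[1]/img.shape[2] indexing above:
#
#   img_t = some_chw_tensor  # e.g. shape (3, 512, 512)
#   expanded_t, mask_t = prepare_tensor_for_outpaint(img_t, up=64, right=64)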


def prepare_image_for_outpaint(
    img, mask=None, up=None, down=None, left=None, right=None, _all=0, snap_multiple=8
):
    up, down, left, right, new_width, new_height = outpaint_calculations(
        img_width=img.width,
        img_height=img.height,
        up=up,
        down=down,
        left=left,
        right=right,
        _all=_all,
        snap_multiple=snap_multiple,
    )
    # fill the expanded canvas with random noise rather than a flat color
    ran_img_t = torch.randn((1, 3, new_height, new_width), device="cpu")
    expanded_image = torch_img_to_pillow_img(ran_img_t)
    # expanded_image = Image.new(
    #     "RGB", (img.width + left + right, img.height + up + down), (0, 0, 0)
    # )
    expanded_image.paste(img, (left, up))

    # Extend the border pixels outward. This helps prevent visible lines at the
    # boundary, since masks getting reduced to the 64x64 latent space can cause
    # some inaccuracies.
    alpha = 20
    if up > 0:
        top_row = img.crop((0, 0, img.width, 1))
        top_row.putalpha(alpha)
        expanded_image.paste(
            top_row.resize((expanded_image.width, up)),
            (0, 0),
        )
        expanded_image.paste(
            top_row.resize((img.width, up)),
            (left, 0),
        )
    if down > 0:
        bottom_row = img.crop((0, img.height - 1, img.width, img.height))
        bottom_row.putalpha(alpha)
        expanded_image.paste(
            bottom_row.resize((expanded_image.width, down)),
            (0, expanded_image.height - down),
        )
        expanded_image.paste(
            bottom_row.resize((img.width, down)),
            (left, expanded_image.height - down),
        )
    if left > 0:
        left_column = img.crop((0, 0, 1, img.height))
        left_column.putalpha(alpha)
        expanded_image.paste(
            left_column.resize((left, expanded_image.height)),
            (0, 0),
        )
        expanded_image.paste(
            left_column.resize((left, img.height)),
            (0, up),
        )
    if right > 0:
        right_column = img.crop((img.width - 1, 0, img.width, img.height))
        right_column.putalpha(alpha)
        expanded_image.paste(
            right_column.resize((right, expanded_image.height)),
            (expanded_image.width - right, 0),
        )
        expanded_image.paste(
            img.crop((img.width - 1, 0, img.width, img.height)).resize(
                (right, img.height)
            ),
            (expanded_image.width - right, up),
        )

    # create a mask for the new boundaries
    expanded_mask = Image.new("L", (expanded_image.width, expanded_image.height), 255)
    if mask is None:
        draw = ImageDraw.Draw(expanded_mask)
        draw.rectangle(
            (left, up, left + img.width, up + img.height), fill="black", outline="black"
        )
    else:
        expanded_mask.paste(mask, (left, up))

    return expanded_image, expanded_mask
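
# Usage sketch (assumed; names are illustrative and not part of the original file):
#
#   from PIL import Image
#   img = Image.open("photo.jpg").convert("RGB")
#   expanded_img, expanded_mask = prepare_image_for_outpaint(img, up=128, right=256)
#
# expanded_img is the original pasted onto a noise canvas with its border pixels
# stretched into the new margins; expanded_mask is white (value 255) over the
# added margins and black over the area the original image occupies.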


def outpaint_arg_str_parse(arg_str):
    arg_pattern = re.compile(r"([A-Z]+)(\d+)")

    args = arg_str.upper().split(",")
    valid_directions = ["up", "down", "left", "right", "all"]
    valid_direction_chars = {c[0]: c for c in valid_directions}
    kwargs = {}
    for arg in args:
        match = arg_pattern.match(arg)
        if not match:
            raise ValueError(f"Invalid outpaint argument '{arg}'")
        direction, amount = match.groups()
        direction = direction.lower()
        if len(direction) == 1:
            if direction not in valid_direction_chars:
                raise ValueError(f"Invalid outpaint direction '{direction}'")
            direction = valid_direction_chars[direction]
        elif direction not in valid_directions:
            raise ValueError(f"Invalid outpaint direction '{direction}'")
        kwargs[direction] = int(amount)

    if "all" in kwargs:
        kwargs["_all"] = kwargs.pop("all")

    return kwargs
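
# Examples of the accepted argument grammar (derived from the regex and direction
# tables above; shown for illustration only):
#   outpaint_arg_str_parse("all250")      -> {"_all": 250}
#   outpaint_arg_str_parse("u100,r50")    -> {"up": 100, "right": 50}
#   outpaint_arg_str_parse("up64,down64") -> {"up": 64, "down": 64}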