Mirror of https://github.com/brycedrennan/imaginAIry (synced 2024-10-31 03:20:40 +00:00)
208 lines · 6.9 KiB · Python
"""Functions for image tiling and reconstruction"""
|
|
|
|
# Inspired by https://github.com/ProGamerGov/neural-dream/blob/master/neural_dream/dream_tile.py,
# with all of its bugs fixed and many simplifications.
|
|
# MIT License
|
|
import math
|
|
|
|
import torch
|
|
|
|
|
|
def mask_tile(tile, overlap, std_overlap, side="bottom"):
    """Multiply a tile by linear blending ramps along the requested edges.

    Args:
        tile: 4-D tensor indexed as (batch, channel, height, width); rows are
            axis 2 and columns are axis 3.
        overlap: [top, bottom, right, left] overlap widths, in pixels, for this
            particular tile.
        std_overlap: [top, bottom, right, left] standard overlap widths. When
            this tile's top/left overlap is wider than the standard one, the
            extra leading pixels are zeroed and only the last ``std`` pixels
            ramp from 0 to 1.
        side: string searched with substring tests for "top", "bottom",
            "left" and "right", e.g. ``"bottom,top,right"``.

    Returns:
        A new tensor ``tile * mask``; ``tile`` itself is not modified.

    Top/left ramps rise 0 -> 1 going into the tile and bottom/right ramps fall
    1 -> 0 toward the tile edge, so two equal-length ramps laid over the same
    pixels sum to 1.
    """
    _, _, h, w = tile.shape
    top_overlap, bottom_overlap, right_overlap, left_overlap = overlap
    std_top_overlap, _std_bottom, _std_right, std_left_overlap = std_overlap
    device = tile.device

    base_mask = torch.ones_like(tile)
    # Each ramp is a 1-D profile. A column profile of shape (n,) broadcasts over
    # the width axis, and a row profile reshaped to (n, 1) broadcasts over the
    # height axis, so any batch/channel count works.
    if "right" in side:
        ramp = torch.linspace(1, 0, right_overlap, device=device)
        base_mask[:, :, :, w - right_overlap :] = (
            base_mask[:, :, :, w - right_overlap :] * ramp
        )
    if "left" in side:
        ramp = _rising_ramp(left_overlap, std_left_overlap, device)
        base_mask[:, :, :, :left_overlap] = base_mask[:, :, :, :left_overlap] * ramp
    if "bottom" in side:
        # Use this tile's own bottom overlap (as the right edge does) so the
        # ramp length always matches the slice it multiplies.
        ramp = torch.linspace(1, 0, bottom_overlap, device=device).view(-1, 1)
        base_mask[:, :, h - bottom_overlap :, :] = (
            base_mask[:, :, h - bottom_overlap :, :] * ramp
        )
    if "top" in side:
        ramp = _rising_ramp(top_overlap, std_top_overlap, device).view(-1, 1)
        base_mask[:, :, :top_overlap, :] = base_mask[:, :, :top_overlap, :] * ramp

    return tile * base_mask


def _rising_ramp(total, ramp_len, device):
    """Return a 1-D mask of ``total`` values: zeros, then a 0 -> 1 ramp of ``ramp_len`` values."""
    ramp = torch.linspace(0, 1, ramp_len, device=device)
    if total > ramp_len:
        # Extra overlap beyond the standard width is fully masked out.
        zeros = torch.zeros(total - ramp_len, device=device)
        ramp = torch.cat([zeros, ramp], 0)
    return ramp
|
|
|
|
|
|
def get_tile_coords(d, tile_dim, overlap=0):
    """Return start offsets of tiles covering one axis of length ``d``.

    Tiles of length ``tile_dim`` advance by ``ceil(tile_dim * (1 - overlap))``.
    The final tile is pinned flush with the end (offset ``d - tile_dim``), so it
    may overlap its neighbour by more than the nominal amount. If
    ``tile_dim >= d`` the result is ``[0]``.

    Raises:
        ValueError: if the step rounds to less than 1 pixel while more than one
            tile is needed (previously this looped forever).
    """
    # round() first so float error such as 100 * (1 - 0.7) ==
    # 30.000000000000004 does not bump ceil() to the next integer.
    move = int(math.ceil(round(tile_dim * (1 - overlap), 10)))
    if move < 1 and tile_dim < d:
        raise ValueError(
            "tile_dim * (1 - overlap) must be at least 1 to cover the image"
        )

    c, tile_start, coords = 1, 0, [0]
    while tile_start + tile_dim < d:
        tile_start = move * c
        # Clamp the last tile so it ends exactly at d.
        coords.append(min(tile_start, d - tile_dim))
        c += 1
    return coords
|
|
|
|
|
|
def get_tiles(img, tile_coords, tile_size):
    """Cut ``img`` into tiles at the given top-left offsets.

    Tiles are returned in row-major order: for each row start in
    ``tile_coords[0]``, one tile per column start in ``tile_coords[1]``. Each
    tile is a slice of the last two axes of ``img``, ``tile_size[0]`` rows tall
    and ``tile_size[1]`` columns wide (truncated at the image border).
    """
    return [
        img[:, :, top : top + tile_size[0], left : left + tile_size[1]]
        for top in tile_coords[0]
        for left in tile_coords[1]
    ]
|
|
|
|
|
|
def final_overlap(tile_coords, tile_size):
    """Return ``[row_overlap, col_overlap]`` between the last tile and the one before it.

    Per axis this is ``(second-to-last start + tile length) - last start``.
    With a single tile on an axis, the "previous" tile is that same tile, so
    the value equals the tile length.
    """

    def gap(starts, length):
        last = len(starts) - 1
        return starts[last - 1] + length - starts[last]

    return [gap(tile_coords[0], tile_size[0]), gap(tile_coords[1], tile_size[1])]
|
|
|
|
|
|
def add_tiles(tiles, base_img, tile_coords, tile_size, overlap):
    """Blend a row-major list of tiles into ``base_img`` and return it.

    Args:
        tiles: tiles ordered rows-outer, columns-inner (as ``get_tiles`` emits).
        base_img: tensor the masked tiles are summed into. It is modified in
            place via slice assignment and also returned.
        tile_coords: ``(row_starts, col_starts)`` top-left offsets of the tiles.
        tile_size: ``(tile_height, tile_width)``.
        overlap: ``[top, bottom, right, left]`` standard overlaps in pixels.

    Each tile is masked by ``mask_tile`` on every edge it shares with a
    neighbour, then added. The last row/column can overlap its neighbour by
    more than the standard amount, so its top/left overlap is widened to the
    actual value reported by ``final_overlap``.
    """
    row_starts, col_starts = tile_coords[0], tile_coords[1]
    n_rows, n_cols = len(row_starts), len(col_starts)

    f_ovlp = final_overlap(tile_coords, tile_size)
    h, w = tiles[0].size(2), tiles[0].size(3)
    # With a single tile on an axis final_overlap compares the tile with
    # itself and reports its full length; that is not a real overlap.
    if f_ovlp[0] == h:
        f_ovlp[0] = 0
    if f_ovlp[1] == w:
        f_ovlp[1] = 0

    for row, y in enumerate(row_starts):
        for column, x in enumerate(col_starts):
            c_overlap = overlap.copy()

            if row == 0:
                row_sides = "bottom"
            elif row < n_rows - 1:
                row_sides = "bottom,top"
            else:
                row_sides = "top"
                if f_ovlp[0] > 0:
                    c_overlap[0] = f_ovlp[0]  # widen the last row's top overlap

            if column == 0:
                col_sides = ",right"
            elif column < n_cols - 1:
                col_sides = ",right,left"
            else:
                col_sides = ",left"
                if f_ovlp[1] > 0:
                    c_overlap[3] = f_ovlp[1]  # widen the last column's left overlap

            tile = mask_tile(
                tiles[row * n_cols + column],
                c_overlap,
                std_overlap=overlap,
                side=row_sides + col_sides,
            )
            base_img[:, :, y : y + tile_size[0], x : x + tile_size[1]] = (
                base_img[:, :, y : y + tile_size[0], x : x + tile_size[1]] + tile
            )
    return base_img
|
|
|
|
|
|
def tile_setup(tile_size, overlap_percent, base_size):
    """Compute the tile layout for an image of size ``base_size``.

    Args:
        tile_size: int, or ``(height, width)`` tile size in pixels.
        overlap_percent: float, or ``(height_frac, width_frac)``; each value
            is the overlap as a fraction of the tile and must lie in [0, 0.5].
        base_size: ``(height, width)`` of the image being tiled.

    Returns:
        ``((y_coords, x_coords), tile_size, [top, bottom, right, left])``, where
        the last item holds overlap widths in pixels. An axis covered by a
        single tile gets an overlap of 0.

    Raises:
        ValueError: if a tile dimension is < 1, or an overlap fraction is
            greater than 0.5 or negative.
    """
    if not isinstance(tile_size, (tuple, list)):
        tile_size = (tile_size, tile_size)
    if not isinstance(overlap_percent, (tuple, list)):
        overlap_percent = (overlap_percent, overlap_percent)
    if min(tile_size) < 1:
        raise ValueError("tile_size must be at least 1")

    if max(overlap_percent) > 0.5:
        raise ValueError("overlap_percent must not be greater than 0.5")
    if min(overlap_percent) < 0:
        # A negative fraction spaces tiles apart and yields negative overlap
        # widths, which mask_tile cannot turn into blending ramps.
        raise ValueError("overlap_percent must not be negative")

    x_coords = get_tile_coords(base_size[1], tile_size[1], overlap_percent[1])
    y_coords = get_tile_coords(base_size[0], tile_size[0], overlap_percent[0])
    # round() before floor() so float error just below an integer is absorbed.
    y_ovlp = int(math.floor(round(tile_size[0] * overlap_percent[0], 10)))
    x_ovlp = int(math.floor(round(tile_size[1] * overlap_percent[1], 10)))
    if len(x_coords) == 1:
        x_ovlp = 0
    if len(y_coords) == 1:
        y_ovlp = 0

    return (y_coords, x_coords), tile_size, [y_ovlp, y_ovlp, x_ovlp, x_ovlp]
|
|
|
|
|
|
def tile_image(img, tile_size, overlap_percent):
    """Split ``img`` into overlapping tiles, returned in row-major order.

    ``tile_size`` and ``overlap_percent`` accept a scalar or a
    ``(height, width)`` pair and are validated by ``tile_setup``.
    """
    image_hw = (img.size(2), img.size(3))
    coords, size, _unused_overlap = tile_setup(tile_size, overlap_percent, image_hw)
    return get_tiles(img, coords, size)
|
|
|
|
|
|
def rebuild_image(tiles, base_img, tile_size, overlap_percent):
    """Reassemble tiles (as produced by ``tile_image``) into a full image.

    ``base_img`` only supplies the output's shape, dtype and device: a zeroed
    copy of it is used as the canvas. A single tile is returned unchanged.
    """
    if len(tiles) != 1:
        canvas = torch.zeros_like(base_img)
        image_hw = (canvas.size(2), canvas.size(3))
        coords, size, overlap_px = tile_setup(tile_size, overlap_percent, image_hw)
        return add_tiles(tiles, canvas, coords, size, overlap_px)
    return tiles[0]
|