imaginAIry/tests/test_feather_tile.py

131 lines
3.9 KiB
Python

import itertools
import pytest
from imaginairy import LazyLoadingImage
from imaginairy.feather_tile import rebuild_image, tile_image, tile_setup
from imaginairy.img_utils import pillow_img_to_torch_image, torch_img_to_pillow_img
from tests import TESTS_FOLDER
# Crop fractions applied to one axis of the fixture image in the brute-force
# test. Fractions above 1 slice past the end of the axis, which leaves that
# axis uncropped.
img_ratios = [
    0.2,
    0.242,
    0.3,
    0.33333333,
    0.5,
    0.75,
    1,
    4 / 3.0,
    16 / 9.0,
    2,
    21 / 9.0,
]

# Fractions reused both as tile-size fractions (of the image width) and as
# overlap fractions in the brute-force sweep.
pcts = [0, 0.09, 0.1, 0.2, 0.25, 0.3, 1 / 3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.8, 0.9, 1.0]

# Not referenced by the tests shown here; kept as a module-level name.
initial_sizes = [512]

# True -> crop the width axis, False -> crop the height axis.
flip = [True, False]

# (img_ratio, tile_size, overlap_pct) triples for test_feather_tile_simple.
cases = [
    # Uncropped image: each of two tile sizes against four overlap fractions.
    *[
        (1, tile, overlap)
        for tile in (256, 128)
        for overlap in (0, 0.125, 0.25, 0.5)
    ],
    (1, 512, 0),
    (0.2, 46, 0.09),
    (0.2, 46, 0.1),
    (0.242, 46, 0.2),
    (0.2, 51, 1 / 3.0),
    (0.2, 102, 0.09),  # tile size same as width of image
]
@pytest.mark.parametrize(("img_ratio", "tile_size", "overlap_pct"), cases)
def test_feather_tile_simple(img_ratio, tile_size, overlap_pct):
    """Tiling and then rebuilding a cropped image must preserve its pixel sum.

    The fixture image is cropped on its last axis to ``img_ratio`` of its
    size, split with ``tile_image`` and reassembled with ``rebuild_image``.
    The rebuilt tensor must have the same shape, and its sum must differ from
    the original's by less than 1.

    Args:
        img_ratio: Fraction of the last axis to keep when cropping.
        tile_size: Tile size passed to ``tile_image`` / ``tile_setup``.
        overlap_pct: Overlap fraction passed to the tiling helpers.
    """
    img = pillow_img_to_torch_image(
        LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")
    )
    # Keep every element of the first three axes; crop the last axis.
    img = img[:, :, :, : int(img.shape[3] * img_ratio)]
    img_sum = img.sum()

    tiles = tile_image(img, tile_size=tile_size, overlap_percent=overlap_pct)
    # NOTE(review): tile_size is rebound to the value returned by tile_setup,
    # and that value (not the raw parameter) is what rebuild_image receives.
    # This mirrors the original test; confirm it is intentional.
    tile_coords, tile_size, overlap = tile_setup(
        tile_size, overlap_pct, (img.size(2), img.size(3))
    )
    rebuilt = rebuild_image(
        tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_pct
    )

    assert rebuilt.shape == img.shape
    diff = abs(float(rebuilt.sum()) - float(img_sum))
    if diff >= 1:
        # Interactive debug aid: show original, rebuild and residual.
        torch_img_to_pillow_img(img).show()
        torch_img_to_pillow_img(rebuilt).show()
        torch_img_to_pillow_img(rebuilt - img).show()
    assert diff < 1, (
        f"rebuilt image sum differs by {diff}: img.shape={tuple(img.shape)} "
        f"tile_size={tile_size} overlap={overlap} tile_coords={tile_coords}"
    )
def test_feather_tile_brute():
    """Round-trip tile/rebuild over a sweep of crop ratios, tile sizes and overlaps.

    Every combination of ``pcts`` (tile-size fraction) x ``pcts`` (overlap
    fraction) x ``img_ratios`` x ``flip`` is tried. Zero tile sizes and
    overlaps of 0.5 or more are skipped. Each remaining case must rebuild to
    the same shape, with a pixel sum within 1 of the original.
    """
    source_img = pillow_img_to_torch_image(
        LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bowl_of_fruit.jpg")
    )

    def tile_untile(img, tile_size, overlap_percent):
        """Assert that tiling ``img`` and rebuilding it preserves shape and sum."""
        img_sum = img.sum()
        tiles = tile_image(img, tile_size=tile_size, overlap_percent=overlap_percent)
        # NOTE(review): rebuild_image receives the tile_size returned by
        # tile_setup rather than the raw argument, as in the original test.
        tile_coords, tile_size, overlap = tile_setup(
            tile_size, overlap_percent, (img.size(2), img.size(3))
        )
        rebuilt = rebuild_image(
            tiles, base_img=img, tile_size=tile_size, overlap_percent=overlap_percent
        )
        assert rebuilt.shape == img.shape
        diff = abs(float(rebuilt.sum()) - float(img_sum))
        # Same threshold as the assertion below, so the debug images appear
        # exactly when the test is about to fail (including diff == 1).
        if diff >= 1:
            torch_img_to_pillow_img(img).show()
            torch_img_to_pillow_img(rebuilt).show()
            # Residual is amplified x20 to make small errors visible.
            torch_img_to_pillow_img((rebuilt - img) * 20).show()
        assert diff < 1, (
            f"rebuilt image sum differs by {diff}: img.shape={tuple(img.shape)} "
            f"tile_size={tile_size} overlap_percent={overlap_percent} "
            f"overlap={overlap} tile_coords={tile_coords}"
        )

    for tile_size_pct, overlap_percent, img_ratio, flip_ratio in itertools.product(
        pcts, pcts, img_ratios, flip
    ):
        # Tile size is always derived from the uncropped width.
        tile_size = int(source_img.shape[3] * tile_size_pct)
        # Skip degenerate combinations before paying for a full-image clone.
        if not tile_size or overlap_percent >= 0.5:
            continue
        # Ratios above 1 slice past the end of the axis, so that axis stays
        # uncropped.
        if flip_ratio:
            img = source_img.clone()[:, :, :, : int(source_img.shape[3] * img_ratio)]
        else:
            img = source_img.clone()[:, :, : int(source_img.shape[2] * img_ratio), :]
        tile_untile(img, tile_size=tile_size, overlap_percent=overlap_percent)