Mirror of https://github.com/brycedrennan/imaginAIry, synced 2024-10-31 03:20:40 +00:00
test: add some autoencoder tests
The fold-unfold encoding/decoding appears to be slower and to produce worse seams than the sliced feathering approach.
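For context, the "sliced feathering" approach referred to above blends overlapping decoded tiles with edge-weighted masks so that no hard seam is left where tiles meet. Here is a minimal sketch of that blending idea; blend_tiles, its arguments, and the linear edge ramps are illustrative assumptions rather than the repository's actual API.

import torch


def blend_tiles(tiles, positions, out_shape, tile_size, overlap):
    """Hypothetical helper: feather-blend square tiles back into one canvas.

    tiles: list of (b, c, tile_size, tile_size) tensors
    positions: list of (top, left) offsets; assumes tile_size >= 2 * overlap
    """
    acc = torch.zeros(out_shape)                       # weighted sum of tiles
    weight_acc = torch.zeros((1, 1, *out_shape[-2:]))  # accumulated weight per pixel

    # 1D ramp that fades between 0.1 and 1.0 across the overlap on each side
    ramp = torch.ones(tile_size)
    ramp[:overlap] = torch.linspace(0.1, 1.0, overlap)
    ramp[-overlap:] = torch.linspace(1.0, 0.1, overlap)
    mask = (ramp[:, None] * ramp[None, :]).view(1, 1, tile_size, tile_size)

    for tile, (top, left) in zip(tiles, positions):
        acc[:, :, top : top + tile_size, left : left + tile_size] += tile * mask
        weight_acc[:, :, top : top + tile_size, left : left + tile_size] += mask

    # dividing by the accumulated weights normalizes the overlapping regions
    return acc / weight_acc.clamp(min=1e-8)

The diff below does something analogous for the fold-unfold path: per-pixel weights come from delta_border/get_weighting (clipped to clip_min_weight/clip_max_weight), torch.nn.Fold sums the weighted patches, and the folded weights are divided back out as the normalization term.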
This commit is contained in:
parent 1ceb17c083 · commit 1563e0b871
@@ -56,6 +56,21 @@ class AutoencoderKL(pl.LightningModule):
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        ks = 128
        stride = 64
        vqf = 8
        self.split_input_params = {
            "ks": (ks, ks),
            "stride": (stride, stride),
            "vqf": vqf,
            "patch_distributed_vq": True,
            "tie_braker": False,
            "clip_max_weight": 0.5,
            "clip_min_weight": 0.01,
            "clip_max_tie_weight": 0.5,
            "clip_min_tie_weight": 0.01,
        }

    def init_from_ckpt(self, path, ignore_keys=None):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
@@ -89,10 +104,12 @@ class AutoencoderKL(pl.LightningModule):

    def encode(self, x):
        return self.encode_sliced(x)
        # h = self.encoder(x)
        # moments = self.quant_conv(h)
        # posterior = DiagonalGaussianDistribution(moments)
        # return posterior.sample()

    def encode_all_at_once(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior.mode()

    def encode_sliced(self, x, chunk_size=128 * 8):
        """
@@ -123,6 +140,60 @@ class AutoencoderKL(pl.LightningModule):

        return final_tensor

    def encode_with_folds(self, x):
        bs, nc, h, w = x.shape
        ks = self.split_input_params["ks"]  # eg. (128, 128)
        stride = self.split_input_params["stride"]  # eg. (64, 64)
        df = self.split_input_params["vqf"]

        if h > ks[0] * df or w > ks[1] * df:
            self.split_input_params["original_image_size"] = x.shape[-2:]
            orig_shape = x.shape

            if ks[0] > h // df or ks[1] > w // df:
                ks = (min(ks[0], h // df), min(ks[1], w // df))
                logger.debug(f"reducing Kernel to {ks}")

            if stride[0] > h // df or stride[1] > w // df:
                stride = (min(stride[0], h // df), min(stride[1], w // df))
                logger.debug("reducing stride")
            bottom_pad = math.ceil(h / (ks[0] * df)) * (ks[0] * df) - h
            right_pad = math.ceil(w / (ks[1] * df)) * (ks[1] * df) - w

            padded_x = torch.zeros(
                (bs, nc, h + bottom_pad, w + right_pad), device=x.device
            )
            padded_x[:, :, :h, :w] = x
            x = padded_x

            fold, unfold, normalization, weighting = self.get_fold_unfold(
                x, ks, stride, df=df
            )
            z = unfold(x)  # (bn, nc * prod(**ks), L)
            # Reshape to img shape
            z = z.view(
                (z.shape[0], -1, ks[0] * df, ks[1] * df, z.shape[-1])
            )  # (bn, nc, ks[0], ks[1], L )

            output_list = [
                self.encode_all_at_once(z[:, :, :, :, i]) for i in range(z.shape[-1])
            ]

            o = torch.stack(output_list, axis=-1)
            o = o * weighting

            # Reverse reshape to img shape
            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
            # stitch crops together
            encoded = fold(o)
            encoded = encoded / normalization
            # trim off padding
            encoded = encoded[:, :, : orig_shape[2] // df, : orig_shape[3] // df]

            return encoded

        return self.encode_all_at_once(x)
    def decode(self, z):
        try:
            return self.decode_all_at_once(z)
@@ -167,6 +238,56 @@ class AutoencoderKL(pl.LightningModule):

        return final_tensor

    def decode_with_folds(self, z):
        ks = self.split_input_params["ks"]  # eg. (128, 128)
        stride = self.split_input_params["stride"]  # eg. (64, 64)
        uf = self.split_input_params["vqf"]
        orig_shape = z.shape
        bs, nc, h, w = z.shape
        bottom_pad = math.ceil(h / ks[0]) * ks[0] - h
        right_pad = math.ceil(w / ks[1]) * ks[1] - w

        # pad the latent such that the unfolding will cover the whole image
        padded_z = torch.zeros((bs, nc, h + bottom_pad, w + right_pad), device=z.device)
        padded_z[:, :, :h, :w] = z
        z = padded_z

        bs, nc, h, w = z.shape
        if ks[0] > h or ks[1] > w:
            ks = (min(ks[0], h), min(ks[1], w))
            logger.debug("reducing Kernel")

        if stride[0] > h or stride[1] > w:
            stride = (min(stride[0], h), min(stride[1], w))
            logger.debug("reducing stride")

        fold, unfold, normalization, weighting = self.get_fold_unfold(
            z, ks, stride, uf=uf
        )

        z = unfold(z)  # (bn, nc * prod(**ks), L)
        # 1. Reshape to img shape
        z = z.view(
            (z.shape[0], -1, ks[0], ks[1], z.shape[-1])
        )  # (bn, nc, ks[0], ks[1], L )

        # 2. apply model loop over last dim

        output_list = [
            self.decode_all_at_once(z[:, :, :, :, i]) for i in range(z.shape[-1])
        ]

        o = torch.stack(output_list, axis=-1)  # (bn, nc, ks[0], ks[1], L)
        o = o * weighting
        # Reverse 1. reshape to img shape
        o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
        # stitch crops together
        decoded = fold(o)
        decoded = decoded / normalization  # norm is shape (1, 1, h, w)
        # trim off padding
        decoded = decoded[:, :, : orig_shape[2] * 8, : orig_shape[3] * 8]
        return decoded

    def forward(self, input, sample_posterior=True):  # noqa
        posterior = self.encode(input)
        if sample_posterior:
@@ -321,6 +442,143 @@ class AutoencoderKL(pl.LightningModule):
        log["inputs"] = x
        return log

    def delta_border(self, h, w):
        """
        :param h: height
        :param w: width
        :return: normalized distance to image border,
            with min distance = 0 at border and max dist = 0.5 at image center
        """
        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
        arr = self.meshgrid(h, w) / lower_right_corner
        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
        edge_dist = torch.min(
            torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1
        )[0]
        return edge_dist

    def get_weighting(self, h, w, Ly, Lx, device):
        weighting = self.delta_border(h, w)
        weighting = torch.clip(
            weighting,
            self.split_input_params["clip_min_weight"],
            self.split_input_params["clip_max_weight"],
        )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(
                L_weighting,
                self.split_input_params["clip_min_tie_weight"],
                self.split_input_params["clip_max_tie_weight"],
            )

            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting

    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):
        """
        :param x: img of size (bs, c, h, w)
        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
        """
        bs, nc, h, w = x.shape  # noqa

        # number of crops in image
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            fold_params = {
                "kernel_size": kernel_size,
                "dilation": 1,
                "padding": 0,
                "stride": stride,
            }
            unfold = torch.nn.Unfold(**fold_params)

            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)

            weighting = self.get_weighting(
                kernel_size[0], kernel_size[1], Ly, Lx, x.device
            ).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        elif uf > 1 and df == 1:
            fold_params = {
                "kernel_size": kernel_size,
                "dilation": 1,
                "padding": 0,
                "stride": stride,
            }
            unfold = torch.nn.Unfold(**fold_params)

            fold_params2 = {
                "kernel_size": (kernel_size[0] * uf, kernel_size[0] * uf),
                "dilation": 1,
                "padding": 0,
                "stride": (stride[0] * uf, stride[1] * uf),
            }
            fold = torch.nn.Fold(
                output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2
            )

            weighting = self.get_weighting(
                kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device
            ).to(x.dtype)
            normalization = fold(weighting).view(
                1, 1, h * uf, w * uf
            )  # normalizes the overlap
            weighting = weighting.view(
                (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)
            )

        elif df > 1 and uf == 1:
            Ly = (h - (kernel_size[0] * df)) // (stride[0] * df) + 1
            Lx = (w - (kernel_size[1] * df)) // (stride[1] * df) + 1

            unfold_params = {
                "kernel_size": (kernel_size[0] * df, kernel_size[1] * df),
                "dilation": 1,
                "padding": 0,
                "stride": (stride[0] * df, stride[1] * df),
            }

            unfold = torch.nn.Unfold(**unfold_params)

            fold_params = {
                "kernel_size": kernel_size,
                "dilation": 1,
                "padding": 0,
                "stride": stride,
            }
            fold = torch.nn.Fold(
                output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params
            )

            weighting = self.get_weighting(
                kernel_size[0], kernel_size[1], Ly, Lx, x.device
            ).to(x.dtype)
            normalization = fold(weighting).view(
                1, 1, h // df, w // df
            )  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))

        else:
            raise NotImplementedError

        return fold, unfold, normalization, weighting

    def meshgrid(self, h, w):
        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)

        arr = torch.cat([y, x], dim=-1)
        return arr

    # def to_rgb(self, x):
    #     assert self.image_key == "segmentation"
    #     if not hasattr(self, "colorize"):
@@ -892,7 +892,7 @@ class LatentDiffusion(DDPM):

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
            z = encoder_posterior.mode()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
@@ -1127,47 +1127,6 @@ class LatentDiffusion(DDPM):

    @torch.no_grad()
    def encode_first_stage(self, x):
        if (
            hasattr(self, "split_input_params")
            and self.split_input_params["patch_distributed_vq"]
        ):
            ks = self.split_input_params["ks"]  # eg. (128, 128)
            stride = self.split_input_params["stride"]  # eg. (64, 64)
            df = self.split_input_params["vqf"]
            self.split_input_params["original_image_size"] = x.shape[-2:]
            bs, nc, h, w = x.shape  # noqa
            if ks[0] > h or ks[1] > w:
                ks = (min(ks[0], h), min(ks[1], w))
                logger.info("reducing Kernel")

            if stride[0] > h or stride[1] > w:
                stride = (min(stride[0], h), min(stride[1], w))
                logger.info("reducing stride")

            fold, unfold, normalization, weighting = self.get_fold_unfold(
                x, ks, stride, df=df
            )
            z = unfold(x)  # (bn, nc * prod(**ks), L)
            # Reshape to img shape
            z = z.view(
                (z.shape[0], -1, ks[0], ks[1], z.shape[-1])
            )  # (bn, nc, ks[0], ks[1], L )

            output_list = [
                self.first_stage_model.encode(z[:, :, :, :, i])
                for i in range(z.shape[-1])
            ]

            o = torch.stack(output_list, axis=-1)
            o = o * weighting

            # Reverse reshape to img shape
            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
            # stitch crops together
            decoded = fold(o)
            decoded = decoded / normalization
            return decoded

        return self.first_stage_model.encode(x)

    def shared_step(self, batch, **kwargs):
Binary file not shown. (After: Width | Height | Size: 132 KiB)
BIN  tests/data/dog.jpg  Normal file
Binary file not shown. (After: Width | Height | Size: 23 KiB)
146  tests/modules/test_autoencoders.py  Normal file
@@ -0,0 +1,146 @@
import numpy as np
import pytest
from PIL import Image
from torch.nn.functional import interpolate

from imaginairy import LazyLoadingImage
from imaginairy.enhancers.upscale_riverwing import upscale_latent
from imaginairy.img_utils import (
    pillow_fit_image_within,
    pillow_img_to_torch_image,
    torch_img_to_pillow_img,
)
from imaginairy.model_manager import get_diffusion_model
from imaginairy.utils import get_device
from tests import TESTS_FOLDER

strat_combos = [
    ("sliced", "sliced"),
    ("sliced", "all_at_once"),
    ("folds", "folds"),
    ("folds", "all_at_once"),
    ("all_at_once", "all_at_once"),
    ("all_at_once", "sliced"),
    ("all_at_once", "folds"),
]


@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("encode_strat,decode_strat", strat_combos)
def test_encode_decode(filename_base_for_outputs, encode_strat, decode_strat):
    """Test that encoding and decoding works."""
    model = get_diffusion_model()
    img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
    img = pillow_fit_image_within(img, max_height=img.height, max_width=img.width)
    img.save(f"{filename_base_for_outputs}_orig.png")
    img_t = pillow_img_to_torch_image(img).to(get_device())
    if encode_strat == "all_at_once":
        latent = model.first_stage_model.encode_all_at_once(img_t) * model.scale_factor
    elif encode_strat == "folds":
        latent = model.first_stage_model.encode_with_folds(img_t) * model.scale_factor
    else:
        latent = model.first_stage_model.encode_sliced(img_t) * model.scale_factor

    if decode_strat == "all_at_once":
        decoded_img_t = model.first_stage_model.decode_all_at_once(
            latent / model.scale_factor
        )
    elif decode_strat == "folds":
        decoded_img_t = model.first_stage_model.decode_with_folds(
            latent / model.scale_factor
        )
    else:
        decoded_img_t = model.first_stage_model.decode_sliced(
            latent / model.scale_factor
        )
    decoded_img_t = interpolate(decoded_img_t, img_t.shape[-2:])
    decoded_img = torch_img_to_pillow_img(decoded_img_t)
    decoded_img.save(f"{filename_base_for_outputs}.png")
    diff_img = Image.fromarray(np.asarray(img) - np.asarray(decoded_img))
    diff_img.save(f"{filename_base_for_outputs}_diff.png")


@pytest.mark.skip()
def test_encode_decode_naive_scale(filename_base_for_outputs):
    model = get_diffusion_model()
    img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/dog.jpg")
    img = pillow_fit_image_within(img, max_height=img.height, max_width=img.width)
    img.save(f"{filename_base_for_outputs}_orig.png")
    img_t = pillow_img_to_torch_image(img).to(get_device())
    latent = model.first_stage_model.encode_sliced(img_t) * model.scale_factor
    latent = interpolate(latent, scale_factor=2)

    decoded_img_t = model.first_stage_model.decode_sliced(latent / model.scale_factor)
    decoded_img = torch_img_to_pillow_img(decoded_img_t)
    decoded_img.save(f"{filename_base_for_outputs}.png")


@pytest.mark.skip(reason="experimental")
def test_upscale_methods(filename_base_for_outputs, steps):
    """
    compare upscale methods.
    """
    steps = 25
    model = get_diffusion_model()
    roi_pcts = (0.7, 0.1, 0.9, 0.3)

    def crop_pct(img, roi_pcts):
        w, h = img.size
        roi = (
            int(w * roi_pcts[0]),
            int(h * roi_pcts[1]),
            int(w * roi_pcts[2]),
            int(h * roi_pcts[3]),
        )
        return img.crop(roi)

    def decode(latent):
        t = model.first_stage_model.decode_sliced(latent / model.scale_factor)
        return torch_img_to_pillow_img(t)

    img = LazyLoadingImage(
        filepath=f"{TESTS_FOLDER}/data/010853_1_kdpmpp2m30_PS7.5_portrait_photo_of_a_freckled_woman_[generated].jpg"
    )
    img = pillow_fit_image_within(img, max_height=img.height, max_width=img.width)
    img = crop_pct(img, roi_pcts)

    upscaled = []
    sampling_methods = [
        ("nearest", Image.Resampling.NEAREST),
        ("bilinear", Image.Resampling.BILINEAR),
        ("bicubic", Image.Resampling.BICUBIC),
        ("lanczos", Image.Resampling.LANCZOS),
    ]
    for method_name, sample_method in sampling_methods:
        upscaled.append(
            (
                img.resize((img.width * 4, img.height * 4), resample=sample_method),
                f"{method_name}",
            )
        )

    img_t = pillow_img_to_torch_image(img).to(get_device())
    latent = model.first_stage_model.encode_sliced(img_t) * model.scale_factor

    sharp_latent = upscale_latent(
        latent, steps=steps, upscale_prompt="high detail, sharp focus, 4k"
    )
    sharp_latent = upscale_latent(
        sharp_latent, steps=steps, upscale_prompt="high detail, sharp focus, 4k"
    )
    upscaled.append((decode(sharp_latent), "riverwing-upscaler-sharp"))

    blurry_latent = upscale_latent(
        latent, steps=steps, upscale_prompt="blurry, low detail, 360p"
    )
    blurry_latent = upscale_latent(
        blurry_latent, steps=steps, upscale_prompt="blurry, low detail, 360p"
    )
    upscaled.append((decode(blurry_latent), "riverwing-upscaler-blurry"))

    # upscaled.append((decode(latent).resize(), "original"))

    for img, name in upscaled:
        img.resize((img.width, img.height), resample=Image.NEAREST).save(
            f"{filename_base_for_outputs}_{name}.jpg"
        )