test: add some autoencoder tests

the fold-unfold encoding/decoding looks like it's slower and has worse seams than the sliced feathering approach
This commit is contained in:
Bryce 2023-02-17 09:48:50 -08:00 committed by Bryce Drennan
parent 1ceb17c083
commit 1563e0b871
5 changed files with 409 additions and 46 deletions

View File

@ -56,6 +56,21 @@ class AutoencoderKL(pl.LightningModule):
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
ks = 128
stride = 64
vqf = 8
self.split_input_params = {
"ks": (ks, ks),
"stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01,
}
def init_from_ckpt(self, path, ignore_keys=None):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
@ -89,10 +104,12 @@ class AutoencoderKL(pl.LightningModule):
def encode(self, x):
return self.encode_sliced(x)
# h = self.encoder(x)
# moments = self.quant_conv(h)
# posterior = DiagonalGaussianDistribution(moments)
# return posterior.sample()
def encode_all_at_once(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior.mode()
def encode_sliced(self, x, chunk_size=128 * 8):
"""
@ -123,6 +140,60 @@ class AutoencoderKL(pl.LightningModule):
return final_tensor
def encode_with_folds(self, x):
bs, nc, h, w = x.shape
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
df = self.split_input_params["vqf"]
if h > ks[0] * df or w > ks[1] * df:
self.split_input_params["original_image_size"] = x.shape[-2:]
orig_shape = x.shape
if ks[0] > h // df or ks[1] > w // df:
ks = (min(ks[0], h // df), min(ks[1], w // df))
logger.debug(f"reducing Kernel to {ks}")
if stride[0] > h // df or stride[1] > w // df:
stride = (min(stride[0], h // df), min(stride[1], w // df))
logger.debug("reducing stride")
bottom_pad = math.ceil(h / (ks[0] * df)) * (ks[0] * df) - h
right_pad = math.ceil(w / (ks[1] * df)) * (ks[1] * df) - w
padded_x = torch.zeros(
(bs, nc, h + bottom_pad, w + right_pad), device=x.device
)
padded_x[:, :, :h, :w] = x
x = padded_x
fold, unfold, normalization, weighting = self.get_fold_unfold(
x, ks, stride, df=df
)
z = unfold(x) # (bn, nc * prod(**ks), L)
# Reshape to img shape
z = z.view(
(z.shape[0], -1, ks[0] * df, ks[1] * df, z.shape[-1])
) # (bn, nc, ks[0], ks[1], L )
output_list = [
self.encode_all_at_once(z[:, :, :, :, i]) for i in range(z.shape[-1])
]
o = torch.stack(output_list, axis=-1)
o = o * weighting
# Reverse reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
encoded = fold(o)
encoded = encoded / normalization
# trim off padding
encoded = encoded[:, :, : orig_shape[2] // df, : orig_shape[3] // df]
return encoded
return self.encode_all_at_once(x)
def decode(self, z):
try:
return self.decode_all_at_once(z)
@ -167,6 +238,56 @@ class AutoencoderKL(pl.LightningModule):
return final_tensor
def decode_with_folds(self, z):
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
uf = self.split_input_params["vqf"]
orig_shape = z.shape
bs, nc, h, w = z.shape
bottom_pad = math.ceil(h / ks[0]) * ks[0] - h
right_pad = math.ceil(w / ks[1]) * ks[1] - w
# pad the latent such that the unfolding will cover the whole image
padded_z = torch.zeros((bs, nc, h + bottom_pad, w + right_pad), device=z.device)
padded_z[:, :, :h, :w] = z
z = padded_z
bs, nc, h, w = z.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
logger.debug("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
logger.debug("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(
z, ks, stride, uf=uf
)
z = unfold(z) # (bn, nc * prod(**ks), L)
# 1. Reshape to img shape
z = z.view(
(z.shape[0], -1, ks[0], ks[1], z.shape[-1])
) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
output_list = [
self.decode_all_at_once(z[:, :, :, :, i]) for i in range(z.shape[-1])
]
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
o = o * weighting
# Reverse 1. reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization # norm is shape (1, 1, h, w)
# trim off padding
decoded = decoded[:, :, : orig_shape[2] * 8, : orig_shape[3] * 8]
return decoded
def forward(self, input, sample_posterior=True): # noqa
posterior = self.encode(input)
if sample_posterior:
@ -321,6 +442,143 @@ class AutoencoderKL(pl.LightningModule):
log["inputs"] = x
return log
def delta_border(self, h, w):
"""
:param h: height
:param w: width
:return: normalized distance to image border,
wtith min distance = 0 at border and max dist = 0.5 at image center
"""
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
arr = self.meshgrid(h, w) / lower_right_corner
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
edge_dist = torch.min(
torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1
)[0]
return edge_dist
def get_weighting(self, h, w, Ly, Lx, device):
weighting = self.delta_border(h, w)
weighting = torch.clip(
weighting,
self.split_input_params["clip_min_weight"],
self.split_input_params["clip_max_weight"],
)
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
if self.split_input_params["tie_braker"]:
L_weighting = self.delta_border(Ly, Lx)
L_weighting = torch.clip(
L_weighting,
self.split_input_params["clip_min_tie_weight"],
self.split_input_params["clip_max_tie_weight"],
)
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
weighting = weighting * L_weighting
return weighting
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):
"""
:param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
"""
bs, nc, h, w = x.shape # noqa
# number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1
Lx = (w - kernel_size[1]) // stride[1] + 1
if uf == 1 and df == 1:
fold_params = {
"kernel_size": kernel_size,
"dilation": 1,
"padding": 0,
"stride": stride,
}
unfold = torch.nn.Unfold(**fold_params)
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
weighting = self.get_weighting(
kernel_size[0], kernel_size[1], Ly, Lx, x.device
).to(x.dtype)
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
elif uf > 1 and df == 1:
fold_params = {
"kernel_size": kernel_size,
"dilation": 1,
"padding": 0,
"stride": stride,
}
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = {
"kernel_size": (kernel_size[0] * uf, kernel_size[0] * uf),
"dilation": 1,
"padding": 0,
"stride": (stride[0] * uf, stride[1] * uf),
}
fold = torch.nn.Fold(
output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2
)
weighting = self.get_weighting(
kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device
).to(x.dtype)
normalization = fold(weighting).view(
1, 1, h * uf, w * uf
) # normalizes the overlap
weighting = weighting.view(
(1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)
)
elif df > 1 and uf == 1:
Ly = (h - (kernel_size[0] * df)) // (stride[0] * df) + 1
Lx = (w - (kernel_size[1] * df)) // (stride[1] * df) + 1
unfold_params = {
"kernel_size": (kernel_size[0] * df, kernel_size[1] * df),
"dilation": 1,
"padding": 0,
"stride": (stride[0] * df, stride[1] * df),
}
unfold = torch.nn.Unfold(**unfold_params)
fold_params = {
"kernel_size": kernel_size,
"dilation": 1,
"padding": 0,
"stride": stride,
}
fold = torch.nn.Fold(
output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params
)
weighting = self.get_weighting(
kernel_size[0], kernel_size[1], Ly, Lx, x.device
).to(x.dtype)
normalization = fold(weighting).view(
1, 1, h // df, w // df
) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
else:
raise NotImplementedError
return fold, unfold, normalization, weighting
def meshgrid(self, h, w):
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
arr = torch.cat([y, x], dim=-1)
return arr
# def to_rgb(self, x):
# assert self.image_key == "segmentation"
# if not hasattr(self, "colorize"):

View File

@ -892,7 +892,7 @@ class LatentDiffusion(DDPM):
def get_first_stage_encoding(self, encoder_posterior):
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample()
z = encoder_posterior.mode()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
@ -1127,47 +1127,6 @@ class LatentDiffusion(DDPM):
@torch.no_grad()
def encode_first_stage(self, x):
if (
hasattr(self, "split_input_params")
and self.split_input_params["patch_distributed_vq"]
):
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
df = self.split_input_params["vqf"]
self.split_input_params["original_image_size"] = x.shape[-2:]
bs, nc, h, w = x.shape # noqa
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
logger.info("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
logger.info("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(
x, ks, stride, df=df
)
z = unfold(x) # (bn, nc * prod(**ks), L)
# Reshape to img shape
z = z.view(
(z.shape[0], -1, ks[0], ks[1], z.shape[-1])
) # (bn, nc, ks[0], ks[1], L )
output_list = [
self.first_stage_model.encode(z[:, :, :, :, i])
for i in range(z.shape[-1])
]
o = torch.stack(output_list, axis=-1)
o = o * weighting
# Reverse reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization
return decoded
return self.first_stage_model.encode(x)
def shared_step(self, batch, **kwargs):

Binary file not shown.

After

Width:  |  Height:  |  Size: 132 KiB

BIN
tests/data/dog.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

View File

@ -0,0 +1,146 @@
import numpy as np
import pytest
from PIL import Image
from torch.nn.functional import interpolate
from imaginairy import LazyLoadingImage
from imaginairy.enhancers.upscale_riverwing import upscale_latent
from imaginairy.img_utils import (
pillow_fit_image_within,
pillow_img_to_torch_image,
torch_img_to_pillow_img,
)
from imaginairy.model_manager import get_diffusion_model
from imaginairy.utils import get_device
from tests import TESTS_FOLDER
strat_combos = [
("sliced", "sliced"),
("sliced", "all_at_once"),
("folds", "folds"),
("folds", "all_at_once"),
("all_at_once", "all_at_once"),
("all_at_once", "sliced"),
("all_at_once", "folds"),
]
@pytest.mark.skipif(get_device() == "cpu", reason="Too slow to run on CPU")
@pytest.mark.parametrize("encode_strat,decode_strat", strat_combos)
def test_encode_decode(filename_base_for_outputs, encode_strat, decode_strat):
"""Test that encoding and decoding works."""
model = get_diffusion_model()
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg")
img = pillow_fit_image_within(img, max_height=img.height, max_width=img.width)
img.save(f"{filename_base_for_outputs}_orig.png")
img_t = pillow_img_to_torch_image(img).to(get_device())
if encode_strat == "all_at_once":
latent = model.first_stage_model.encode_all_at_once(img_t) * model.scale_factor
elif encode_strat == "folds":
latent = model.first_stage_model.encode_with_folds(img_t) * model.scale_factor
else:
latent = model.first_stage_model.encode_sliced(img_t) * model.scale_factor
if decode_strat == "all_at_once":
decoded_img_t = model.first_stage_model.decode_all_at_once(
latent / model.scale_factor
)
elif decode_strat == "folds":
decoded_img_t = model.first_stage_model.decode_with_folds(
latent / model.scale_factor
)
else:
decoded_img_t = model.first_stage_model.decode_sliced(
latent / model.scale_factor
)
decoded_img_t = interpolate(decoded_img_t, img_t.shape[-2:])
decoded_img = torch_img_to_pillow_img(decoded_img_t)
decoded_img.save(f"{filename_base_for_outputs}.png")
diff_img = Image.fromarray(np.asarray(img) - np.asarray(decoded_img))
diff_img.save(f"{filename_base_for_outputs}_diff.png")
@pytest.mark.skip()
def test_encode_decode_naive_scale(filename_base_for_outputs):
model = get_diffusion_model()
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/dog.jpg")
img = pillow_fit_image_within(img, max_height=img.height, max_width=img.width)
img.save(f"{filename_base_for_outputs}_orig.png")
img_t = pillow_img_to_torch_image(img).to(get_device())
latent = model.first_stage_model.encode_sliced(img_t) * model.scale_factor
latent = interpolate(latent, scale_factor=2)
decoded_img_t = model.first_stage_model.decode_sliced(latent / model.scale_factor)
decoded_img = torch_img_to_pillow_img(decoded_img_t)
decoded_img.save(f"{filename_base_for_outputs}.png")
@pytest.mark.skip(reason="experimental")
def test_upscale_methods(filename_base_for_outputs, steps):
"""
compare upscale methods.
"""
steps = 25
model = get_diffusion_model()
roi_pcts = (0.7, 0.1, 0.9, 0.3)
def crop_pct(img, roi_pcts):
w, h = img.size
roi = (
int(w * roi_pcts[0]),
int(h * roi_pcts[1]),
int(w * roi_pcts[2]),
int(h * roi_pcts[3]),
)
return img.crop(roi)
def decode(latent):
t = model.first_stage_model.decode_sliced(latent / model.scale_factor)
return torch_img_to_pillow_img(t)
img = LazyLoadingImage(
filepath=f"{TESTS_FOLDER}/data/010853_1_kdpmpp2m30_PS7.5_portrait_photo_of_a_freckled_woman_[generated].jpg"
)
img = pillow_fit_image_within(img, max_height=img.height, max_width=img.width)
img = crop_pct(img, roi_pcts)
upscaled = []
sampling_methods = [
("nearest", Image.Resampling.NEAREST),
("bilinear", Image.Resampling.BILINEAR),
("bicubic", Image.Resampling.BICUBIC),
("lanczos", Image.Resampling.LANCZOS),
]
for method_name, sample_method in sampling_methods:
upscaled.append(
(
img.resize((img.width * 4, img.height * 4), resample=sample_method),
f"{method_name}",
)
)
img_t = pillow_img_to_torch_image(img).to(get_device())
latent = model.first_stage_model.encode_sliced(img_t) * model.scale_factor
sharp_latent = upscale_latent(
latent, steps=steps, upscale_prompt="high detail, sharp focus, 4k"
)
sharp_latent = upscale_latent(
sharp_latent, steps=steps, upscale_prompt="high detail, sharp focus, 4k"
)
upscaled.append((decode(sharp_latent), "riverwing-upscaler-sharp"))
blurry_latent = upscale_latent(
latent, steps=steps, upscale_prompt="blurry, low detail, 360p"
)
blurry_latent = upscale_latent(
blurry_latent, steps=steps, upscale_prompt="blurry, low detail, 360p"
)
upscaled.append((decode(blurry_latent), "riverwing-upscaler-blurry"))
# upscaled.append((decode(latent).resize(), "original"))
for img, name in upscaled:
img.resize((img.width, img.height), resample=Image.NEAREST).save(
f"{filename_base_for_outputs}_{name}.jpg"
)