# imaginAIry/imaginairy/modules/autoencoder.py
"""Classes for image autoencoding and manipulation"""
# pylama:ignore=W0613
import logging
import math
from contextlib import contextmanager
import torch
from torch import nn
from torch.cuda import OutOfMemoryError
from imaginairy.modules.diffusion.model import Decoder, Encoder
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.modules.ema import LitEma
from imaginairy.utils import instantiate_from_config
from imaginairy.utils.feather_tile import rebuild_image, tile_image
logger = logging.getLogger(__name__)


class AutoencoderKL(nn.Module):
    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=None,
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        ema_decay=None,
        learn_logvar=False,
    ):
        super().__init__()
        self.learn_logvar = learn_logvar
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

        self.use_ema = ema_decay is not None
        if self.use_ema:
            self.ema_decay = ema_decay
            assert 0.0 < ema_decay < 1.0
            self.model_ema = LitEma(self, decay=ema_decay)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        ks = 128
        stride = 64
        vqf = 8
        self.split_input_params = {
            "ks": (ks, ks),
            "stride": (stride, stride),
            "vqf": vqf,
            "patch_distributed_vq": True,
            "tie_braker": False,
            "clip_max_weight": 0.5,
            "clip_min_weight": 0.01,
            "clip_max_tie_weight": 0.5,
            "clip_min_tie_weight": 0.01,
        }
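        # NOTE: `ks` and `stride` above are expressed in latent-space pixels
        # and `vqf` is the autoencoder's spatial downsampling factor, so the
        # (128, 128) kernel with a (64, 64) stride used by the *_with_folds
        # methods corresponds roughly to 1024x1024-pixel tiles with 50%
        # overlap in image space.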

    def init_from_ckpt(self, path, ignore_keys=None):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        ignore_keys = [] if ignore_keys is None else ignore_keys
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print(f"Deleting key {k} from state_dict.")
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        return self.encode_sliced(x)

    def encode_all_at_once(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior.mode()

    def encode_sliced(self, x, chunk_size=128 * 8):
        """Encode the image in overlapping slices and feather them back together."""
        b, c, h, w = x.size()
        final_tensor = torch.zeros(
            [1, 4, math.ceil(h / 8), math.ceil(w / 8)], device=x.device
        )
        for x_img in x.split(1):
            encoded_chunks = []
            overlap_pct = 0.5
            chunks = tile_image(
                x_img, tile_size=chunk_size, overlap_percent=overlap_pct
            )
            for img_chunk in chunks:
                h = self.encoder(img_chunk)
                moments = self.quant_conv(h)
                posterior = DiagonalGaussianDistribution(moments)
                encoded_chunks.append(posterior.sample())
            final_tensor = rebuild_image(
                encoded_chunks,
                base_img=final_tensor,
                tile_size=chunk_size // 8,
                overlap_percent=overlap_pct,
            )
        return final_tensor
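    # NOTE: a rough worked example for encode_sliced: with the default
    # chunk_size of 1024 pixels and 50% overlap, a 1536x1536 image is cut
    # into overlapping 1024x1024 tiles, each tile is encoded to a 128x128
    # latent, and rebuild_image feathers the overlapping latent tiles back
    # into the 192x192 (h/8 x w/8) output tensor. Exact tile counts depend
    # on tile_image's layout.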

    def encode_with_folds(self, x):
        bs, nc, h, w = x.shape
        ks = self.split_input_params["ks"]  # e.g. (128, 128)
        stride = self.split_input_params["stride"]  # e.g. (64, 64)
        df = self.split_input_params["vqf"]

        if h > ks[0] * df or w > ks[1] * df:
            self.split_input_params["original_image_size"] = x.shape[-2:]
            orig_shape = x.shape

            if ks[0] > h // df or ks[1] > w // df:
                ks = (min(ks[0], h // df), min(ks[1], w // df))
                logger.debug(f"reducing kernel to {ks}")

            if stride[0] > h // df or stride[1] > w // df:
                stride = (min(stride[0], h // df), min(stride[1], w // df))
                logger.debug("reducing stride")

            # pad the image such that the unfolding will cover the whole image
            bottom_pad = math.ceil(h / (ks[0] * df)) * (ks[0] * df) - h
            right_pad = math.ceil(w / (ks[1] * df)) * (ks[1] * df) - w
            padded_x = torch.zeros(
                (bs, nc, h + bottom_pad, w + right_pad), device=x.device
            )
            padded_x[:, :, :h, :w] = x
            x = padded_x

            fold, unfold, normalization, weighting = self.get_fold_unfold(
                x, ks, stride, df=df
            )
            z = unfold(x)  # (bn, nc * prod(**ks), L)
            # reshape to img shape
            z = z.view(
                (z.shape[0], -1, ks[0] * df, ks[1] * df, z.shape[-1])
            )  # (bn, nc, ks[0], ks[1], L)

            output_list = [
                self.encode_all_at_once(z[:, :, :, :, i]) for i in range(z.shape[-1])
            ]

            o = torch.stack(output_list, axis=-1)
            o = o * weighting

            # reverse reshape to img shape
            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
            # stitch crops together
            encoded = fold(o)
            encoded = encoded / normalization
            # trim off padding
            encoded = encoded[:, :, : orig_shape[2] // df, : orig_shape[3] // df]
            return encoded

        return self.encode_all_at_once(x)
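    # NOTE: in encode_with_folds, the number of overlapping patches is
    # Ly * Lx with Ly = (H_padded - ks*df) // (stride*df) + 1 (see
    # get_fold_unfold). For example, a 1536x1536 input with the default
    # ks=128, stride=64, df=8 is zero-padded to 2048x2048 and unfolded into
    # 3x3 = 9 patches of 1024x1024 pixels, which are encoded separately and
    # folded back together with the border-distance weighting.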

    def decode(self, z):
        try:
            return self.decode_all_at_once(z)
        except OutOfMemoryError:
            # out of memory, try sliced decoding
            try:
                return self.decode_sliced(z, chunk_size=128)
            except OutOfMemoryError:
                return self.decode_sliced(z, chunk_size=64)

    def decode_all_at_once(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def decode_sliced(self, z, chunk_size=128):
        """
        Decode the latent tensor in slices.

        The separately decoded slices don't exactly match at their edges, so we
        overlap, feather, and merge them to reduce (but not completely
        eliminate) the visible seams.
        """
        b, c, h, w = z.size()
        final_tensor = torch.zeros([1, 3, h * 8, w * 8], device=z.device)
        for z_latent in z.split(1):
            decoded_chunks = []
            overlap_pct = 0.5
            chunks = tile_image(
                z_latent, tile_size=chunk_size, overlap_percent=overlap_pct
            )
            for latent_chunk in chunks:
                latent_chunk = self.post_quant_conv(latent_chunk)
                dec = self.decoder(latent_chunk)
                decoded_chunks.append(dec)
            final_tensor = rebuild_image(
                decoded_chunks,
                base_img=final_tensor,
                tile_size=chunk_size * 8,
                overlap_percent=overlap_pct,
            )
        return final_tensor
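    # NOTE: decode() falls back to this sliced path on CUDA OOM. chunk_size is
    # measured in latent pixels, so chunk_size=128 decodes 1024x1024-pixel
    # tiles and chunk_size=64 decodes 512x512-pixel tiles; smaller tiles use
    # less memory but produce more seams for the feathering to hide.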

    def decode_with_folds(self, z):
        ks = self.split_input_params["ks"]  # e.g. (128, 128)
        stride = self.split_input_params["stride"]  # e.g. (64, 64)
        uf = self.split_input_params["vqf"]
        orig_shape = z.shape
        bs, nc, h, w = z.shape

        # pad the latent such that the unfolding will cover the whole image
        bottom_pad = math.ceil(h / ks[0]) * ks[0] - h
        right_pad = math.ceil(w / ks[1]) * ks[1] - w
        padded_z = torch.zeros(
            (bs, nc, h + bottom_pad, w + right_pad), device=z.device
        )
        padded_z[:, :, :h, :w] = z
        z = padded_z

        bs, nc, h, w = z.shape
        if ks[0] > h or ks[1] > w:
            ks = (min(ks[0], h), min(ks[1], w))
            logger.debug("reducing kernel")

        if stride[0] > h or stride[1] > w:
            stride = (min(stride[0], h), min(stride[1], w))
            logger.debug("reducing stride")

        fold, unfold, normalization, weighting = self.get_fold_unfold(
            z, ks, stride, uf=uf
        )

        z = unfold(z)  # (bn, nc * prod(**ks), L)
        # 1. reshape to img shape
        z = z.view(
            (z.shape[0], -1, ks[0], ks[1], z.shape[-1])
        )  # (bn, nc, ks[0], ks[1], L)

        # 2. apply the decoder to each crop, looping over the last dim
        output_list = [
            self.decode_all_at_once(z[:, :, :, :, i]) for i in range(z.shape[-1])
        ]

        o = torch.stack(output_list, axis=-1)  # (bn, nc, ks[0], ks[1], L)
        o = o * weighting
        # reverse of 1. - reshape back to img shape
        o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
        # stitch crops together
        decoded = fold(o)
        decoded = decoded / normalization  # norm is shape (1, 1, h, w)
        # trim off padding
        decoded = decoded[:, :, : orig_shape[2] * 8, : orig_shape[3] * 8]
        return decoded

    def forward(self, input, sample_posterior=True):  # noqa
        # NOTE: encode() currently returns a latent tensor from encode_sliced
        # rather than a DiagonalGaussianDistribution, so this training-style
        # forward pass only works with an encode variant that returns the
        # posterior itself.
        posterior = self.encode(input)
        z = posterior.sample() if sample_posterior else posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x
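    # NOTE: the methods below (training_step, validation_step,
    # configure_optimizers, log_images) follow the pytorch-lightning
    # LightningModule interface and rely on attributes such as self.log,
    # self.global_step, self.device, and self.learning_rate that a plain
    # nn.Module does not provide, so they are only usable when the class is
    # driven by a Lightning-style trainer.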

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log(
                "aeloss",
                aeloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )
            self.log_dict(
                log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train",
            )
            self.log(
                "discloss",
                discloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True,
            )
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
            )
            return discloss

        return None

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(  # noqa
                batch, batch_idx, postfix="_ema"
            )
        return log_dict

    def _validation_step(self, batch, batch_idx, postfix=""):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )
        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="val" + postfix,
        )
        self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        ae_params_list = (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters())
        )
        if self.learn_logvar:
            print(f"{self.__class__.__name__}: Learning logvar")
            ae_params_list.append(self.loss.logvar)
        opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
        )
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
        log = {}
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
            if log_ema or self.use_ema:
                with self.ema_scope():
                    xrec_ema, posterior_ema = self(x)
                    if x.shape[1] > 3:
                        # colorize with random projection
                        assert xrec_ema.shape[1] > 3
                        xrec_ema = self.to_rgb(xrec_ema)
                    log["samples_ema"] = self.decode(
                        torch.randn_like(posterior_ema.sample())
                    )
                    log["reconstructions_ema"] = xrec_ema
        log["inputs"] = x
        return log

    def delta_border(self, h, w):
        """
        :param h: height
        :param w: width
        :return: normalized distance to the image border,
            with min distance = 0 at the border and max distance = 0.5 at the image center
        """
        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
        arr = self.meshgrid(h, w) / lower_right_corner
        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
        edge_dist = torch.min(
            torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1
        )[0]
        return edge_dist
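    # NOTE: a small worked example for delta_border: on a 3x3 grid the
    # normalized coordinates are {0, 0.5, 1}, so every border cell gets a
    # distance of 0.0 and the single center cell gets 0.5, matching the
    # min/max range described in the docstring.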

    def get_weighting(self, h, w, Ly, Lx, device):
        weighting = self.delta_border(h, w)
        weighting = torch.clip(
            weighting,
            self.split_input_params["clip_min_weight"],
            self.split_input_params["clip_max_weight"],
        )
        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

        if self.split_input_params["tie_braker"]:
            L_weighting = self.delta_border(Ly, Lx)
            L_weighting = torch.clip(
                L_weighting,
                self.split_input_params["clip_min_tie_weight"],
                self.split_input_params["clip_max_tie_weight"],
            )
            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
            weighting = weighting * L_weighting
        return weighting
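    # NOTE: clipping the border-distance weights to
    # [clip_min_weight, clip_max_weight] (0.01 to 0.5 by default) keeps tile
    # borders from being weighted to zero while still favoring tile centers
    # before fold() sums the overlaps and `normalization` averages them out.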

    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):
        """
        :param x: img of size (bs, c, h, w)
        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
        """
        bs, nc, h, w = x.shape
        # number of crops in image
        Ly = (h - kernel_size[0]) // stride[0] + 1
        Lx = (w - kernel_size[1]) // stride[1] + 1

        if uf == 1 and df == 1:
            fold_params = {
                "kernel_size": kernel_size,
                "dilation": 1,
                "padding": 0,
                "stride": stride,
            }
            unfold = torch.nn.Unfold(**fold_params)
            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
            weighting = self.get_weighting(
                kernel_size[0], kernel_size[1], Ly, Lx, x.device
            ).to(x.dtype)
            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
        elif uf > 1 and df == 1:
            fold_params = {
                "kernel_size": kernel_size,
                "dilation": 1,
                "padding": 0,
                "stride": stride,
            }
            unfold = torch.nn.Unfold(**fold_params)
            fold_params2 = {
                "kernel_size": (kernel_size[0] * uf, kernel_size[1] * uf),
                "dilation": 1,
                "padding": 0,
                "stride": (stride[0] * uf, stride[1] * uf),
            }
            fold = torch.nn.Fold(
                output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2
            )
            weighting = self.get_weighting(
                kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device
            ).to(x.dtype)
            normalization = fold(weighting).view(
                1, 1, h * uf, w * uf
            )  # normalizes the overlap
            weighting = weighting.view(
                (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)
            )
        elif df > 1 and uf == 1:
            Ly = (h - (kernel_size[0] * df)) // (stride[0] * df) + 1
            Lx = (w - (kernel_size[1] * df)) // (stride[1] * df) + 1
            unfold_params = {
                "kernel_size": (kernel_size[0] * df, kernel_size[1] * df),
                "dilation": 1,
                "padding": 0,
                "stride": (stride[0] * df, stride[1] * df),
            }
            unfold = torch.nn.Unfold(**unfold_params)
            fold_params = {
                "kernel_size": kernel_size,
                "dilation": 1,
                "padding": 0,
                "stride": stride,
            }
            fold = torch.nn.Fold(
                output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params
            )
            weighting = self.get_weighting(
                kernel_size[0], kernel_size[1], Ly, Lx, x.device
            ).to(x.dtype)
            normalization = fold(weighting).view(
                1, 1, h // df, w // df
            )  # normalizes the overlap
            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
        else:
            raise NotImplementedError
        return fold, unfold, normalization, weighting
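    # NOTE: a worked example for the uf == df == 1 case: a 512x512 input with
    # kernel_size=(128, 128) and stride=(64, 64) gives
    # Ly = Lx = (512 - 128) // 64 + 1 = 7, i.e. 49 overlapping crops;
    # `normalization` is the fold of the weights (the total weight covering
    # each output pixel), which the callers divide by after folding.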

    def meshgrid(self, h, w):
        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
        arr = torch.cat([y, x], dim=-1)
        return arr

    # to_rgb is currently disabled, even though log_images still references it
    # for inputs with more than 3 channels (segmentation-style data).
    # def to_rgb(self, x):
    #     assert self.image_key == "segmentation"
    #     if not hasattr(self, "colorize"):
    #         self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
    #     x = F.conv2d(x, weight=self.colorize)
    #     x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
    #     return x


class IdentityFirstStage(torch.nn.Module):
    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x


def chunk_latent(tensor, chunk_size=64, overlap_size=8):
    # Get the shape of the tensor
    batch_size, num_channels, height, width = tensor.shape
    # Calculate the number of chunks along each dimension
    num_rows = int(math.ceil(height / chunk_size))
    num_cols = int(math.ceil(width / chunk_size))
    # Initialize a list to store the chunks
    chunks = []
    # Loop over the rows and columns
    for row in range(num_rows):
        for col in range(num_cols):
            # Calculate the start and end indices for the chunk along each dimension
            row_start = max(row * chunk_size - overlap_size, 0)
            row_end = min(row_start + chunk_size + overlap_size, height)
            col_start = max(col * chunk_size - overlap_size, 0)
            col_end = min(col_start + chunk_size + overlap_size, width)
            # Extract the chunk from the tensor and append it to the list of chunks
            chunk = tensor[:, :, row_start:row_end, col_start:col_end]
            chunks.append((chunk, row_start, col_start))
    return chunks, num_rows, num_cols


def merge_tensors(tensor_list, num_rows, num_cols):
    print(f"num_rows: {num_rows}")
    print(f"num_cols: {num_cols}")
    n, channel, h, w = tensor_list[0].size()
    assert n == 1
    final_width = 0
    final_height = 0
    for col_idx in range(num_cols):
        final_width += tensor_list[col_idx].size()[3]
    for row_idx in range(num_rows):
        final_height += tensor_list[row_idx * num_cols].size()[2]

    final_tensor = torch.zeros([1, channel, final_height, final_width])
    print(f"final size {final_tensor.size()}")
    for row_idx in range(num_rows):
        for col_idx in range(num_cols):
            list_idx = row_idx * num_cols + col_idx
            chunk = tensor_list[list_idx]
            print(f"chunk size: {chunk.size()}")
            _, _, chunk_h, chunk_w = chunk.size()
            final_tensor[
                :,
                :,
                row_idx * h : row_idx * h + chunk_h,
                col_idx * w : col_idx * w + chunk_w,
            ] = chunk
    return final_tensor
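

# The demo below is a minimal sketch (not part of the original module) showing
# what chunk_latent produces for a typical SD-sized latent. Note that its
# chunks are (tensor, row_start, col_start) tuples, while merge_tensors
# expects plain, non-overlapping tensors, so the two helpers do not compose
# directly without unpacking the tuples and cropping the overlap first.
if __name__ == "__main__":
    latent = torch.randn(1, 4, 96, 96)  # e.g. a 768x768 image encoded at vqf=8
    chunks, num_rows, num_cols = chunk_latent(latent, chunk_size=64, overlap_size=8)
    # ceil(96 / 64) = 2 rows x 2 cols; border chunks are clamped to the tensor
    # edges, so each side is either 72 (0:72) or 40 (56:96) latent pixels.
    for chunk, row_start, col_start in chunks:
        print(chunk.shape, row_start, col_start)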