import logging
import pytorch_lightning as pl
import torch
from imaginairy.modules.diffusion.model import Decoder, Encoder
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.utils import instantiate_from_config
logger = logging.getLogger(__name__)


class AutoencoderKL(pl.LightningModule):
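    """KL-regularized variational autoencoder over images.

    The encoder maps an image batch to the parameters of a diagonal
    Gaussian posterior over a compact latent tensor; the decoder maps
    latents back to image space. In this codebase it serves as the
    first-stage (image <-> latent) model for latent diffusion.
    """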
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=None,
image_key="image",
colorize_nlabels=None,
monitor=None,
):
super().__init__()
ignore_keys = [] if ignore_keys is None else ignore_keys
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
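        # double_z makes the encoder emit 2 * z_channels feature maps, which
        # quant_conv projects to 2 * embed_dim channels: the mean and
        # log-variance of the latent posterior.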
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert isinstance(colorize_nlabels, int)
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=None):
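        """Load weights from a checkpoint, dropping any state_dict entries
        whose names start with a prefix listed in ignore_keys."""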
ignore_keys = [] if ignore_keys is None else ignore_keys
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
logger.info(f"Deleting key {k} from state_dict.")
                    del sd[k]
                    break  # k is already removed; another matching prefix would re-delete it
self.load_state_dict(sd, strict=False)
logger.info(f"Restored from {path}")
def encode(self, x):
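        """Encode an image batch into a DiagonalGaussianDistribution over
        the latent space; call .sample() or .mode() on it to get latents."""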
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
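        """Decode a latent batch back into image space."""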
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True): # noqa
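        """Encode, then decode, returning (reconstruction, posterior).

        With sample_posterior=True the latent is drawn stochastically;
        otherwise the posterior mode (its mean) is used.
        """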
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
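        """Fetch batch[k] and convert it from channels-last (B, H, W, C) to a
        contiguous channels-first float tensor; a trailing channel axis is
        added first for 3-D (B, H, W) inputs."""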
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
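

if __name__ == "__main__":
    # Minimal round-trip sketch. The ddconfig below mirrors a typical KL-f8
    # first-stage setup -- the values are illustrative assumptions, not taken
    # from this module -- and the loss is stubbed with torch.nn.Identity so
    # no training-only dependencies are needed.
    ddconfig = {
        "double_z": True,
        "z_channels": 4,
        "resolution": 256,
        "in_channels": 3,
        "out_ch": 3,
        "ch": 128,
        "ch_mult": (1, 2, 4, 4),
        "num_res_blocks": 2,
        "attn_resolutions": (),
        "dropout": 0.0,
    }
    model = AutoencoderKL(
        ddconfig=ddconfig,
        lossconfig={"target": "torch.nn.Identity"},
        embed_dim=4,
    )
    image = torch.randn(1, 3, 256, 256)  # one fake RGB image, channels-first
    with torch.no_grad():
        reconstruction, posterior = model(image)
    # Three downsampling stages (len(ch_mult) - 1) give an 8x spatial
    # reduction, so latents come out as (1, 4, 32, 32).
    print(reconstruction.shape, posterior.sample().shape)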