"""Classes for autoencoder-based image generation"""

import logging
import math
import re
from abc import abstractmethod
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange
from packaging import version

from imaginairy.modules.ema import LitEma
from imaginairy.utils import (
    default,
    get_nested_attribute,
    get_obj_from_str,
    instantiate_from_config,
)

if TYPE_CHECKING:
    from .autoencoding.regularizers import AbstractRegularizer

# from .ema import LitEma
# from .util import (default, get_nested_attribute, get_obj_from_str,
#                    instantiate_from_config)

logpy = logging.getLogger(__name__)


class AbstractAutoencoder(nn.Module):
    """
    This is the base class for all autoencoders, including image autoencoders,
    image autoencoders with discriminators, unCLIP models, etc. Hence, it is
    fairly general, and specific features (e.g. discriminator training,
    encoding, decoding) must be implemented in subclasses.
    """

    def __init__(
        self,
        ema_decay: Union[None, float] = None,
        monitor: Union[None, str] = None,
        input_key: str = "jpg",
    ):
        super().__init__()

        self.input_key = input_key
        self.use_ema = ema_decay is not None
        if monitor is not None:
            self.monitor = monitor

        if self.use_ema:
            self.model_ema = LitEma(self, decay=ema_decay)
            logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            self.automatic_optimization = False

    def apply_ckpt(self, ckpt: Union[None, str, dict]):
        if ckpt is None:
            return
        if isinstance(ckpt, str):
            ckpt = {
                "target": "imaginairy.modules.sgm.checkpoint.CheckpointEngine",
                "params": {"ckpt_path": ckpt},
            }
        engine = instantiate_from_config(ckpt)
        engine(self)
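
    # Usage sketch (an illustrative assumption, not executed here): `apply_ckpt`
    # accepts either a plain checkpoint path or a config dict; a string is
    # wrapped into the CheckpointEngine config shown above. The path below is a
    # hypothetical example.
    #
    #     model.apply_ckpt("weights/autoencoder.ckpt")
    #     model.apply_ckpt(
    #         {
    #             "target": "imaginairy.modules.sgm.checkpoint.CheckpointEngine",
    #             "params": {"ckpt_path": "weights/autoencoder.ckpt"},
    #         }
    #     )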

    @abstractmethod
    def get_input(self, batch) -> Any:
        raise NotImplementedError()

    def on_train_batch_end(self, *args, **kwargs):
        # for EMA computation
        if self.use_ema:
            self.model_ema(self)

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                logpy.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    logpy.info(f"{context}: Restored training weights")
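
    # Usage sketch (an illustrative assumption, not part of the training code):
    # temporarily evaluate with EMA weights, then restore the training weights.
    #
    #     with model.ema_scope("validation"):
    #         z = model.encode(x)
    #         xrec = model.decode(z)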

    @abstractmethod
    def encode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("encode()-method of abstract base class called")

    @abstractmethod
    def decode(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError("decode()-method of abstract base class called")

    def instantiate_optimizer_from_config(self, params, lr, cfg):
        logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
        return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", {}))

    def configure_optimizers(self) -> Any:
        raise NotImplementedError()


class AutoencodingEngine(AbstractAutoencoder):
    """
    Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
    (we also restore them explicitly as special cases for legacy reasons).
    Regularizations such as KL or VQ are moved to the regularizer class.
    """

    def __init__(
        self,
        *args,
        encoder_config: Dict,
        decoder_config: Dict,
        loss_config: Dict,
        regularizer_config: Dict,
        optimizer_config: Union[Dict, None] = None,
        lr_g_factor: float = 1.0,
        trainable_ae_params: Optional[List[List[str]]] = None,
        ae_optimizer_args: Optional[List[dict]] = None,
        trainable_disc_params: Optional[List[List[str]]] = None,
        disc_optimizer_args: Optional[List[dict]] = None,
        disc_start_iter: int = 0,
        diff_boost_factor: float = 3.0,
        ckpt_engine: Union[None, str, dict] = None,
        ckpt_path: Optional[str] = None,
        additional_decode_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.automatic_optimization = False  # pytorch lightning

        self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
        self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
        self.loss: torch.nn.Module = instantiate_from_config(loss_config)
        self.regularization: AbstractRegularizer = instantiate_from_config(
            regularizer_config
        )
        self.optimizer_config = default(
            optimizer_config, {"target": "torch.optim.Adam"}
        )
        self.diff_boost_factor = diff_boost_factor
        self.disc_start_iter = disc_start_iter
        self.lr_g_factor = lr_g_factor
        self.trainable_ae_params = trainable_ae_params
        if self.trainable_ae_params is not None:
            self.ae_optimizer_args = default(
                ae_optimizer_args,
                [{} for _ in range(len(self.trainable_ae_params))],
            )
            assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
        else:
            self.ae_optimizer_args = [{}]  # makes type consistent

        self.trainable_disc_params = trainable_disc_params
        if self.trainable_disc_params is not None:
            self.disc_optimizer_args = default(
                disc_optimizer_args,
                [{} for _ in range(len(self.trainable_disc_params))],
            )
            assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
        else:
            self.disc_optimizer_args = [{}]  # makes type consistent

        if ckpt_path is not None:
            assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
            logpy.warning(
                "Checkpoint path is deprecated, use `ckpt_engine` instead"
            )
        self.apply_ckpt(default(ckpt_path, ckpt_engine))
        self.additional_decode_keys = set(default(additional_decode_keys, []))

    def get_input(self, batch: Dict) -> torch.Tensor:
        # assuming unified data format, dataloader returns a dict.
        # image tensors should be scaled to -1 ... 1 and in channels-first
        # format (e.g., bchw instead of bhwc)
        return batch[self.input_key]

    def get_autoencoder_params(self) -> list:
        params = []
        if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
            params += list(self.loss.get_trainable_autoencoder_parameters())
        if hasattr(self.regularization, "get_trainable_parameters"):
            params += list(self.regularization.get_trainable_parameters())
        params = params + list(self.encoder.parameters())
        params = params + list(self.decoder.parameters())
        return params

    def get_discriminator_params(self) -> list:
        if hasattr(self.loss, "get_trainable_parameters"):
            params = list(self.loss.get_trainable_parameters())  # e.g., discriminator
        else:
            params = []
        return params

    def get_last_layer(self):
        return self.decoder.get_last_layer()

    def encode(
        self,
        x: torch.Tensor,
        return_reg_log: bool = False,
        unregularized: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        z = self.encoder(x)
        if unregularized:
            return z, {}
        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        x = self.decoder(z, **kwargs)
        return x

    def forward(
        self, x: torch.Tensor, **additional_decode_kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
        z, reg_log = self.encode(x, return_reg_log=True)
        dec = self.decode(z, **additional_decode_kwargs)
        return z, dec, reg_log
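
    # Usage sketch (an illustrative assumption): a full forward pass returns
    # the regularized latent, the reconstruction, and the regularizer's log
    # dict. The tensor shape below is a hypothetical example of a bchw image
    # batch scaled to -1 ... 1.
    #
    #     x = torch.randn(4, 3, 256, 256)
    #     z, xrec, reg_log = model(x)
    #     # equivalent to:
    #     z, reg_log = model.encode(x, return_reg_log=True)
    #     xrec = model.decode(z)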

    def inner_training_step(
        self, batch: dict, batch_idx: int, optimizer_idx: int = 0
    ) -> torch.Tensor:
        x = self.get_input(batch)
        additional_decode_kwargs = {
            key: batch[key] for key in self.additional_decode_keys.intersection(batch)
        }
        z, xrec, regularization_log = self(x, **additional_decode_kwargs)
        if hasattr(self.loss, "forward_keys"):
            extra_info = {
                "z": z,
                "optimizer_idx": optimizer_idx,
                "global_step": self.global_step,
                "last_layer": self.get_last_layer(),
                "split": "train",
                "regularization_log": regularization_log,
                "autoencoder": self,
            }
            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
        else:
            extra_info = {}

        if optimizer_idx == 0:
            # autoencode
            out_loss = self.loss(x, xrec, **extra_info)
            if isinstance(out_loss, tuple):
                aeloss, log_dict_ae = out_loss
            else:
                # simple loss function
                aeloss = out_loss
                log_dict_ae = {"train/loss/rec": aeloss.detach()}

            self.log_dict(
                log_dict_ae,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=True,
                sync_dist=False,
            )
            self.log(
                "loss",
                aeloss.mean().detach(),
                prog_bar=True,
                logger=False,
                on_epoch=False,
                on_step=True,
            )
            return aeloss
        elif optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
            # -> discriminator always needs to return a tuple
            self.log_dict(
                log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
            )
            return discloss
        else:
            msg = f"Unknown optimizer {optimizer_idx}"
            raise NotImplementedError(msg)

    def training_step(self, batch: dict, batch_idx: int):
        opts = self.optimizers()
        if not isinstance(opts, list):
            # Non-adversarial case
            opts = [opts]
        optimizer_idx = batch_idx % len(opts)
        if self.global_step < self.disc_start_iter:
            optimizer_idx = 0
        opt = opts[optimizer_idx]
        opt.zero_grad()
        with opt.toggle_model():
            loss = self.inner_training_step(
                batch, batch_idx, optimizer_idx=optimizer_idx
            )
            self.manual_backward(loss)
        opt.step()

    def validation_step(self, batch: dict, batch_idx: int) -> Dict:
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
            log_dict.update(log_dict_ema)
        return log_dict

    def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
        x = self.get_input(batch)

        z, xrec, regularization_log = self(x)
        if hasattr(self.loss, "forward_keys"):
            extra_info = {
                "z": z,
                "optimizer_idx": 0,
                "global_step": self.global_step,
                "last_layer": self.get_last_layer(),
                "split": "val" + postfix,
                "regularization_log": regularization_log,
                "autoencoder": self,
            }
            extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
        else:
            extra_info = {}
        out_loss = self.loss(x, xrec, **extra_info)
        if isinstance(out_loss, tuple):
            aeloss, log_dict_ae = out_loss
        else:
            # simple loss function
            aeloss = out_loss
            log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
        full_log_dict = log_dict_ae

        if "optimizer_idx" in extra_info:
            extra_info["optimizer_idx"] = 1
            discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
            full_log_dict.update(log_dict_disc)
        self.log(
            f"val{postfix}/loss/rec",
            log_dict_ae[f"val{postfix}/loss/rec"],
            sync_dist=True,
        )
        self.log_dict(full_log_dict, sync_dist=True)
        return full_log_dict

    def get_param_groups(
        self, parameter_names: List[List[str]], optimizer_args: List[dict]
    ) -> Tuple[List[Dict[str, Any]], int]:
        groups = []
        num_params = 0
        for names, args in zip(parameter_names, optimizer_args):
            params = []
            for pattern_ in names:
                pattern_params = []
                pattern = re.compile(pattern_)
                for p_name, param in self.named_parameters():
                    if re.match(pattern, p_name):
                        pattern_params.append(param)
                        num_params += param.numel()
                if len(pattern_params) == 0:
                    logpy.warning(f"Did not find parameters for pattern {pattern_}")
                params.extend(pattern_params)
            groups.append({"params": params, **args})
        return groups, num_params
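
    # Usage sketch (an illustrative assumption): each inner list of regex
    # patterns becomes one optimizer param group, and the matching
    # optimizer_args dict is spread into that group. The patterns and per-group
    # learning rates below are hypothetical examples.
    #
    #     groups, n = model.get_param_groups(
    #         parameter_names=[["encoder\\..*"], ["decoder\\..*"]],
    #         optimizer_args=[{"lr": 1e-4}, {"lr": 4.5e-6}],
    #     )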

    def configure_optimizers(self) -> List[torch.optim.Optimizer]:
        if self.trainable_ae_params is None:
            ae_params = self.get_autoencoder_params()
        else:
            ae_params, num_ae_params = self.get_param_groups(
                self.trainable_ae_params, self.ae_optimizer_args
            )
            logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
        if self.trainable_disc_params is None:
            disc_params = self.get_discriminator_params()
        else:
            disc_params, num_disc_params = self.get_param_groups(
                self.trainable_disc_params, self.disc_optimizer_args
            )
            logpy.info(
                f"Number of trainable discriminator parameters: {num_disc_params:,}"
            )
        opt_ae = self.instantiate_optimizer_from_config(
            ae_params,
            default(self.lr_g_factor, 1.0) * self.learning_rate,
            self.optimizer_config,
        )
        opts = [opt_ae]
        if len(disc_params) > 0:
            opt_disc = self.instantiate_optimizer_from_config(
                disc_params, self.learning_rate, self.optimizer_config
            )
            opts.append(opt_disc)

        return opts

    @torch.no_grad()
    def log_images(
        self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
    ) -> dict:
        log = {}
        additional_decode_kwargs = {}
        x = self.get_input(batch)
        additional_decode_kwargs.update(
            {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
        )

        _, xrec, _ = self(x, **additional_decode_kwargs)
        log["inputs"] = x
        log["reconstructions"] = xrec
        diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
        diff.clamp_(0, 1.0)
        log["diff"] = 2.0 * diff - 1.0
        # diff_boost shows location of small errors, by boosting their
        # brightness.
        log["diff_boost"] = (
            2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
        )
        if hasattr(self.loss, "log_images"):
            log.update(self.loss.log_images(x, xrec))
        with self.ema_scope():
            _, xrec_ema, _ = self(x, **additional_decode_kwargs)
            log["reconstructions_ema"] = xrec_ema
            diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
            diff_ema.clamp_(0, 1.0)
            log["diff_ema"] = 2.0 * diff_ema - 1.0
            log["diff_boost_ema"] = (
                2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
            )
        if additional_log_kwargs:
            additional_decode_kwargs.update(additional_log_kwargs)
            _, xrec_add, _ = self(x, **additional_decode_kwargs)
            log_str = "reconstructions-" + "-".join(
                [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
            )
            log[log_str] = xrec_add
        return log


class AutoencodingEngineLegacy(AutoencodingEngine):
    def __init__(self, embed_dim: int, **kwargs):
        self.max_batch_size = kwargs.pop("max_batch_size", None)
        ddconfig = kwargs.pop("ddconfig")
        ckpt_path = kwargs.pop("ckpt_path", None)
        ckpt_engine = kwargs.pop("ckpt_engine", None)
        super().__init__(
            encoder_config={
                "target": "imaginairy.modules.sgm.diffusionmodules.model.Encoder",
                "params": ddconfig,
            },
            decoder_config={
                "target": "imaginairy.modules.sgm.diffusionmodules.model.Decoder",
                "params": ddconfig,
            },
            **kwargs,
        )
        self.quant_conv = torch.nn.Conv2d(
            (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
            (1 + ddconfig["double_z"]) * embed_dim,
            1,
        )
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim

        self.apply_ckpt(default(ckpt_path, ckpt_engine))

    def get_autoencoder_params(self) -> list:
        params = super().get_autoencoder_params()
        return params

    def encode(
        self, x: torch.Tensor, return_reg_log: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
        if self.max_batch_size is None:
            z = self.encoder(x)
            z = self.quant_conv(z)
        else:
            N = x.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            z = []
            for i_batch in range(n_batches):
                z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
                z_batch = self.quant_conv(z_batch)
                z.append(z_batch)
            z = torch.cat(z, 0)

        z, reg_log = self.regularization(z)
        if return_reg_log:
            return z, reg_log
        return z

    def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
        if self.max_batch_size is None:
            dec = self.post_quant_conv(z)
            dec = self.decoder(dec, **decoder_kwargs)
        else:
            N = z.shape[0]
            bs = self.max_batch_size
            n_batches = int(math.ceil(N / bs))
            dec = []
            for i_batch in range(n_batches):
                dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
                dec_batch = self.decoder(dec_batch, **decoder_kwargs)
                dec.append(dec_batch)
            dec = torch.cat(dec, 0)

        return dec


class AutoencoderKL(AutoencodingEngineLegacy):
    def __init__(self, **kwargs):
        if "lossconfig" in kwargs:
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "imaginairy.modules.sgm.autoencoding.regularizers"
                    ".DiagonalGaussianRegularizer"
                )
            },
            **kwargs,
        )
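
# Construction sketch (an illustrative assumption, not a verified training
# config): AutoencoderKL wires a DiagonalGaussianRegularizer onto
# AutoencodingEngineLegacy, so only embed_dim, ddconfig, and a loss config are
# needed. Only "double_z" and "z_channels" are referenced in this module; the
# remaining ddconfig keys belong to the Encoder/Decoder and are omitted here.
#
#     ae = AutoencoderKL(
#         embed_dim=4,
#         ddconfig={"double_z": True, "z_channels": 4, ...},  # plus Encoder/Decoder keys
#         loss_config={"target": "torch.nn.Identity"},
#     )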


class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
    def __init__(
        self,
        embed_dim: int,
        n_embed: int,
        sane_index_shape: bool = False,
        **kwargs,
    ):
        if "lossconfig" in kwargs:
            logpy.warning("Parameter `lossconfig` is deprecated, use `loss_config`.")
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "imaginairy.modules.sgm.autoencoding.regularizers.quantize"
                    ".VectorQuantizer"
                ),
                "params": {
                    "n_e": n_embed,
                    "e_dim": embed_dim,
                    "sane_index_shape": sane_index_shape,
                },
            },
            **kwargs,
        )


class IdentityFirstStage(AbstractAutoencoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_input(self, x: Any) -> Any:
        return x

    def encode(self, x: Any, *args, **kwargs) -> Any:
        return x

    def decode(self, x: Any, *args, **kwargs) -> Any:
        return x


class AEIntegerWrapper(nn.Module):
    def __init__(
        self,
        model: nn.Module,
        shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
        regularization_key: str = "regularization",
        encoder_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.model = model
        if not hasattr(model, "encode") or not hasattr(model, "decode"):
            raise RuntimeError("Need AE interface")
        self.regularization = get_nested_attribute(model, regularization_key)
        self.shape = shape
        self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})

    def encode(self, x) -> torch.Tensor:
        assert (
            not self.training
        ), f"{self.__class__.__name__} only supports inference currently"
        _, log = self.model.encode(x, **self.encoder_kwargs)
        assert isinstance(log, dict)
        inds = log["min_encoding_indices"]
        return rearrange(inds, "b ... -> b (...)")

    def decode(
        self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
    ) -> torch.Tensor:
        # expect inds shape (b, s) with s = h*w
        shape = default(shape, self.shape)  # Optional[(h, w)]
        if shape is not None:
            assert len(shape) == 2, f"Unhandled shape {shape}"
            inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
        h = self.regularization.get_codebook_entry(inds)  # (b, h, w, c)
        h = rearrange(h, "b h w c -> b c h w")
        return self.model.decode(h)
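
# Usage sketch (an illustrative assumption): wrapping a VQ autoencoder so that
# images round-trip through flat integer codebook indices. The wrapped model
# and the (16, 16) latent grid are hypothetical examples.
#
#     wrapper = AEIntegerWrapper(vq_model, shape=(16, 16)).eval()
#     inds = wrapper.encode(x)     # (b, 256) integer indices
#     xrec = wrapper.decode(inds)  # decoded from codebook entries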


class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
    def __init__(self, **kwargs):
        if "lossconfig" in kwargs:
            kwargs["loss_config"] = kwargs.pop("lossconfig")
        super().__init__(
            regularizer_config={
                "target": (
                    "imaginairy.modules.sgm.autoencoding.regularizers"
                    ".DiagonalGaussianRegularizer"
                ),
                "params": {"sample": False},
            },
            **kwargs,
        )