You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
imaginAIry/imaginairy/modules/sgm/diffusion.py

320 lines
11 KiB
Python

"""Classes for stochastic diffusion models"""
import logging
import math
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from imaginairy.modules.ema import LitEma
from imaginairy.modules.sgm.autoencoding.temporal_ae import VideoDecoder
from imaginairy.utils import (
default,
disabled_train,
get_obj_from_str,
instantiate_from_config,
platform_appropriate_autocast,
)
logger = logging.getLogger(__name__)
UNCONDITIONAL_CONFIG = {
"target": "imaginairy.modules.sgm.encoders.modules.GeneralConditioner",
"params": {"emb_models": []},
}
OPENAIUNETWRAPPER = "imaginairy.modules.sgm.diffusionmodules.wrappers.OpenAIWrapper"
class DiffusionEngine(nn.Module):
def __init__(
self,
network_config,
denoiser_config,
first_stage_config,
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
network_wrapper: Union[None, str] = None,
ckpt_path: Union[None, str] = None,
use_ema: bool = False,
ema_decay_rate: float = 0.9999,
scale_factor: float = 1.0,
disable_first_stage_autocast=False,
input_key: str = "jpg",
log_keys: Union[List, None] = None,
no_cond_log: bool = False,
compile_model: bool = False,
en_and_decode_n_samples_a_time: Optional[int] = None,
):
super().__init__()
self.log_keys = log_keys
self.input_key = input_key
self.optimizer_config = default(
optimizer_config, {"target": "torch.optim.AdamW"}
)
model = instantiate_from_config(network_config)
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
model, compile_model=compile_model
)
self.denoiser = instantiate_from_config(denoiser_config)
self.sampler = (
instantiate_from_config(sampler_config)
if sampler_config is not None
else None
)
self.conditioner = instantiate_from_config(
default(conditioner_config, UNCONDITIONAL_CONFIG)
)
self.scheduler_config = scheduler_config
self._init_first_stage(first_stage_config)
self.loss_fn = (
instantiate_from_config(loss_fn_config)
if loss_fn_config is not None
else None
)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.scale_factor = scale_factor
self.disable_first_stage_autocast = disable_first_stage_autocast
self.no_cond_log = no_cond_log
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path)
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
def init_from_ckpt(
self,
path: str,
) -> None:
if path.endswith("ckpt"):
sd = torch.load(path, map_location="cpu")["state_dict"]
elif path.endswith("safetensors"):
sd = load_safetensors(path)
else:
raise NotImplementedError
missing, unexpected = self.load_state_dict(sd, strict=False)
logger.info(
f"Loaded weights from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def _init_first_stage(self, config):
model = instantiate_from_config(config).eval()
model.train = disabled_train
for param in model.parameters():
param.requires_grad = False
self.first_stage_model = model
def get_input(self, batch):
# assuming unified data format, dataloader returns a dict.
# image tensors should be scaled to -1 ... 1 and in bchw format
return batch[self.input_key]
@torch.no_grad()
def decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
n_rounds = math.ceil(z.shape[0] / n_samples)
all_out = []
with platform_appropriate_autocast(
enabled=not self.disable_first_stage_autocast
):
for n in range(n_rounds):
if isinstance(self.first_stage_model.decoder, VideoDecoder):
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
else:
kwargs = {}
out = self.first_stage_model.decode(
z[n * n_samples : (n + 1) * n_samples], **kwargs
)
all_out.append(out)
out = torch.cat(all_out, dim=0)
return out
@torch.no_grad()
def encode_first_stage(self, x):
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
n_rounds = math.ceil(x.shape[0] / n_samples)
all_out = []
with platform_appropriate_autocast(
enabled=not self.disable_first_stage_autocast
):
for n in range(n_rounds):
out = self.first_stage_model.encode(
x[n * n_samples : (n + 1) * n_samples]
)
all_out.append(out)
z = torch.cat(all_out, dim=0)
z = self.scale_factor * z
return z
def forward(self, x, batch):
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
loss_mean = loss.mean()
loss_dict = {"loss": loss_mean}
return loss_mean, loss_dict
def shared_step(self, batch: Dict) -> Any:
x = self.get_input(batch)
x = self.encode_first_stage(x)
batch["global_step"] = self.global_step
loss, loss_dict = self(x, batch)
return loss, loss_dict
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
self.log_dict(
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
)
self.log(
"global_step",
self.global_step,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=False,
)
if self.scheduler_config is not None:
lr = self.optimizers().param_groups[0]["lr"]
self.log(
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
)
return loss
def on_train_start(self, *args, **kwargs):
if self.sampler is None or self.loss_fn is None:
msg = "Sampler and loss function need to be set for training."
raise ValueError(msg)
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self.model)
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def instantiate_optimizer_from_config(self, params, lr, cfg):
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", {}))
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
for embedder in self.conditioner.embedders:
if embedder.is_trainable:
params = params + list(embedder.parameters())
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
}
]
return [opt], scheduler
return opt
@torch.no_grad()
def sample(
self,
cond: Dict,
uc: Union[Dict, None] = None,
batch_size: int = 16,
shape: Union[None, Tuple, List] = None,
**kwargs,
):
randn = torch.randn(batch_size, *shape).to(self.device)
def denoiser(input_tensor, sigma, c):
return self.denoiser(self.model, input_tensor, sigma, c, **kwargs)
samples = self.sampler(denoiser, randn, cond, uc=uc)
return samples
@torch.no_grad()
def log_images(
self,
batch: Dict,
N: int = 8,
sample: bool = True,
ucg_keys: Optional[List[str]] = None,
**kwargs,
) -> Dict:
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
if ucg_keys:
assert all(x in conditioner_input_keys for x in ucg_keys), (
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
)
else:
ucg_keys = conditioner_input_keys
log = {}
x = self.get_input(batch)
c, uc = self.conditioner.get_unconditional_conditioning(
batch,
force_uc_zero_embeddings=ucg_keys
if len(self.conditioner.embedders) > 0
else [],
)
sampling_kwargs = {}
N = min(x.shape[0], N)
x = x.to(self.device)[:N]
log["inputs"] = x
z = self.encode_first_stage(x)
log["reconstructions"] = self.decode_first_stage(z)
log.update(self.log_conditionings(batch, N))
for k in c:
if isinstance(c[k], torch.Tensor):
c[k], uc[k] = (y[k][:N].to(self.device) for y in (c, uc))
if sample:
with self.ema_scope("Plotting"):
samples = self.sample(
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
)
samples = self.decode_first_stage(samples)
log["samples"] = samples
return log