mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
318 lines
11 KiB
Python
318 lines
11 KiB
Python
import logging
|
|
import math
|
|
from contextlib import contextmanager
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import pytorch_lightning as pl
|
|
import torch
|
|
from omegaconf import ListConfig, OmegaConf
|
|
from safetensors.torch import load_file as load_safetensors
|
|
from torch.optim.lr_scheduler import LambdaLR
|
|
|
|
from imaginairy.modules.ema import LitEma
|
|
from imaginairy.modules.sgm.autoencoding.temporal_ae import VideoDecoder
|
|
from imaginairy.utils import (
|
|
default,
|
|
disabled_train,
|
|
get_obj_from_str,
|
|
instantiate_from_config,
|
|
platform_appropriate_autocast,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
UNCONDITIONAL_CONFIG = {
|
|
"target": "imaginairy.modules.sgm.encoders.modules.GeneralConditioner",
|
|
"params": {"emb_models": []},
|
|
}
|
|
OPENAIUNETWRAPPER = "imaginairy.modules.sgm.diffusionmodules.wrappers.OpenAIWrapper"
|
|
|
|
|
|
class DiffusionEngine(pl.LightningModule):
|
|
def __init__(
|
|
self,
|
|
network_config,
|
|
denoiser_config,
|
|
first_stage_config,
|
|
conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
|
sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
|
optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
|
scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
|
loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
|
|
network_wrapper: Union[None, str] = None,
|
|
ckpt_path: Union[None, str] = None,
|
|
use_ema: bool = False,
|
|
ema_decay_rate: float = 0.9999,
|
|
scale_factor: float = 1.0,
|
|
disable_first_stage_autocast=False,
|
|
input_key: str = "jpg",
|
|
log_keys: Union[List, None] = None,
|
|
no_cond_log: bool = False,
|
|
compile_model: bool = False,
|
|
en_and_decode_n_samples_a_time: Optional[int] = None,
|
|
):
|
|
super().__init__()
|
|
self.log_keys = log_keys
|
|
self.input_key = input_key
|
|
self.optimizer_config = default(
|
|
optimizer_config, {"target": "torch.optim.AdamW"}
|
|
)
|
|
model = instantiate_from_config(network_config)
|
|
self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
|
|
model, compile_model=compile_model
|
|
)
|
|
|
|
self.denoiser = instantiate_from_config(denoiser_config)
|
|
self.sampler = (
|
|
instantiate_from_config(sampler_config)
|
|
if sampler_config is not None
|
|
else None
|
|
)
|
|
self.conditioner = instantiate_from_config(
|
|
default(conditioner_config, UNCONDITIONAL_CONFIG)
|
|
)
|
|
self.scheduler_config = scheduler_config
|
|
self._init_first_stage(first_stage_config)
|
|
|
|
self.loss_fn = (
|
|
instantiate_from_config(loss_fn_config)
|
|
if loss_fn_config is not None
|
|
else None
|
|
)
|
|
|
|
self.use_ema = use_ema
|
|
if self.use_ema:
|
|
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
|
|
|
self.scale_factor = scale_factor
|
|
self.disable_first_stage_autocast = disable_first_stage_autocast
|
|
self.no_cond_log = no_cond_log
|
|
|
|
if ckpt_path is not None:
|
|
self.init_from_ckpt(ckpt_path)
|
|
|
|
self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
|
|
|
|
def init_from_ckpt(
|
|
self,
|
|
path: str,
|
|
) -> None:
|
|
if path.endswith("ckpt"):
|
|
sd = torch.load(path, map_location="cpu")["state_dict"]
|
|
elif path.endswith("safetensors"):
|
|
sd = load_safetensors(path)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
missing, unexpected = self.load_state_dict(sd, strict=False)
|
|
logger.info(
|
|
f"Loaded weights from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
|
)
|
|
if len(missing) > 0:
|
|
print(f"Missing Keys: {missing}")
|
|
if len(unexpected) > 0:
|
|
print(f"Unexpected Keys: {unexpected}")
|
|
|
|
def _init_first_stage(self, config):
|
|
model = instantiate_from_config(config).eval()
|
|
model.train = disabled_train
|
|
for param in model.parameters():
|
|
param.requires_grad = False
|
|
self.first_stage_model = model
|
|
|
|
def get_input(self, batch):
|
|
# assuming unified data format, dataloader returns a dict.
|
|
# image tensors should be scaled to -1 ... 1 and in bchw format
|
|
return batch[self.input_key]
|
|
|
|
@torch.no_grad()
|
|
def decode_first_stage(self, z):
|
|
z = 1.0 / self.scale_factor * z
|
|
n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
|
|
|
|
n_rounds = math.ceil(z.shape[0] / n_samples)
|
|
all_out = []
|
|
with platform_appropriate_autocast(
|
|
enabled=not self.disable_first_stage_autocast
|
|
):
|
|
for n in range(n_rounds):
|
|
if isinstance(self.first_stage_model.decoder, VideoDecoder):
|
|
kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
|
|
else:
|
|
kwargs = {}
|
|
out = self.first_stage_model.decode(
|
|
z[n * n_samples : (n + 1) * n_samples], **kwargs
|
|
)
|
|
all_out.append(out)
|
|
out = torch.cat(all_out, dim=0)
|
|
return out
|
|
|
|
@torch.no_grad()
|
|
def encode_first_stage(self, x):
|
|
n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
|
|
n_rounds = math.ceil(x.shape[0] / n_samples)
|
|
all_out = []
|
|
with platform_appropriate_autocast(
|
|
enabled=not self.disable_first_stage_autocast
|
|
):
|
|
for n in range(n_rounds):
|
|
out = self.first_stage_model.encode(
|
|
x[n * n_samples : (n + 1) * n_samples]
|
|
)
|
|
all_out.append(out)
|
|
z = torch.cat(all_out, dim=0)
|
|
z = self.scale_factor * z
|
|
return z
|
|
|
|
def forward(self, x, batch):
|
|
loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
|
|
loss_mean = loss.mean()
|
|
loss_dict = {"loss": loss_mean}
|
|
return loss_mean, loss_dict
|
|
|
|
def shared_step(self, batch: Dict) -> Any:
|
|
x = self.get_input(batch)
|
|
x = self.encode_first_stage(x)
|
|
batch["global_step"] = self.global_step
|
|
loss, loss_dict = self(x, batch)
|
|
return loss, loss_dict
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
loss, loss_dict = self.shared_step(batch)
|
|
|
|
self.log_dict(
|
|
loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
|
)
|
|
|
|
self.log(
|
|
"global_step",
|
|
self.global_step,
|
|
prog_bar=True,
|
|
logger=True,
|
|
on_step=True,
|
|
on_epoch=False,
|
|
)
|
|
|
|
if self.scheduler_config is not None:
|
|
lr = self.optimizers().param_groups[0]["lr"]
|
|
self.log(
|
|
"lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
|
|
)
|
|
|
|
return loss
|
|
|
|
def on_train_start(self, *args, **kwargs):
|
|
if self.sampler is None or self.loss_fn is None:
|
|
msg = "Sampler and loss function need to be set for training."
|
|
raise ValueError(msg)
|
|
|
|
def on_train_batch_end(self, *args, **kwargs):
|
|
if self.use_ema:
|
|
self.model_ema(self.model)
|
|
|
|
@contextmanager
|
|
def ema_scope(self, context=None):
|
|
if self.use_ema:
|
|
self.model_ema.store(self.model.parameters())
|
|
self.model_ema.copy_to(self.model)
|
|
if context is not None:
|
|
print(f"{context}: Switched to EMA weights")
|
|
try:
|
|
yield None
|
|
finally:
|
|
if self.use_ema:
|
|
self.model_ema.restore(self.model.parameters())
|
|
if context is not None:
|
|
print(f"{context}: Restored training weights")
|
|
|
|
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
|
return get_obj_from_str(cfg["target"])(params, lr=lr, **cfg.get("params", {}))
|
|
|
|
def configure_optimizers(self):
|
|
lr = self.learning_rate
|
|
params = list(self.model.parameters())
|
|
for embedder in self.conditioner.embedders:
|
|
if embedder.is_trainable:
|
|
params = params + list(embedder.parameters())
|
|
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
|
if self.scheduler_config is not None:
|
|
scheduler = instantiate_from_config(self.scheduler_config)
|
|
print("Setting up LambdaLR scheduler...")
|
|
scheduler = [
|
|
{
|
|
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
|
"interval": "step",
|
|
"frequency": 1,
|
|
}
|
|
]
|
|
return [opt], scheduler
|
|
return opt
|
|
|
|
@torch.no_grad()
|
|
def sample(
|
|
self,
|
|
cond: Dict,
|
|
uc: Union[Dict, None] = None,
|
|
batch_size: int = 16,
|
|
shape: Union[None, Tuple, List] = None,
|
|
**kwargs,
|
|
):
|
|
randn = torch.randn(batch_size, *shape).to(self.device)
|
|
|
|
def denoiser(input_tensor, sigma, c):
|
|
return self.denoiser(self.model, input_tensor, sigma, c, **kwargs)
|
|
|
|
samples = self.sampler(denoiser, randn, cond, uc=uc)
|
|
return samples
|
|
|
|
@torch.no_grad()
|
|
def log_images(
|
|
self,
|
|
batch: Dict,
|
|
N: int = 8,
|
|
sample: bool = True,
|
|
ucg_keys: Optional[List[str]] = None,
|
|
**kwargs,
|
|
) -> Dict:
|
|
conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
|
|
if ucg_keys:
|
|
assert all(x in conditioner_input_keys for x in ucg_keys), (
|
|
"Each defined ucg key for sampling must be in the provided conditioner input keys,"
|
|
f"but we have {ucg_keys} vs. {conditioner_input_keys}"
|
|
)
|
|
else:
|
|
ucg_keys = conditioner_input_keys
|
|
log = {}
|
|
|
|
x = self.get_input(batch)
|
|
|
|
c, uc = self.conditioner.get_unconditional_conditioning(
|
|
batch,
|
|
force_uc_zero_embeddings=ucg_keys
|
|
if len(self.conditioner.embedders) > 0
|
|
else [],
|
|
)
|
|
|
|
sampling_kwargs = {}
|
|
|
|
N = min(x.shape[0], N)
|
|
x = x.to(self.device)[:N]
|
|
log["inputs"] = x
|
|
z = self.encode_first_stage(x)
|
|
log["reconstructions"] = self.decode_first_stage(z)
|
|
log.update(self.log_conditionings(batch, N))
|
|
|
|
for k in c:
|
|
if isinstance(c[k], torch.Tensor):
|
|
c[k], uc[k] = (y[k][:N].to(self.device) for y in (c, uc))
|
|
|
|
if sample:
|
|
with self.ema_scope("Plotting"):
|
|
samples = self.sample(
|
|
c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
|
|
)
|
|
samples = self.decode_first_stage(samples)
|
|
log["samples"] = samples
|
|
return log
|