import datetime
import logging
import os
import signal
import time
from functools import partial

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor

try:
    from pytorch_lightning.strategies import DDPStrategy
except ImportError:
    # Don't break the rest of imaginairy just because this training-only import
    # doesn't exist in older versions of pytorch-lightning. Use >= 1.6.0 to train.
    DDPStrategy = None

from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.utils.data import DataLoader, Dataset

from imaginairy import config
from imaginairy.model_manager import get_diffusion_model
from imaginairy.training_tools.single_concept import SingleConceptDataset
from imaginairy.utils import get_device, instantiate_from_config

mod_logger = logging.getLogger(__name__)

# LearningRateMonitor is only referenced by string in the callbacks config below;
# keep a direct reference so the import isn't flagged (or stripped) as unused.
referenced_by_string = [LearningRateMonitor]


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a PyTorch dataset."""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, SingleConceptDataset):
        # split_size = dataset.num_records // worker_info.num_workers
        # reset num_records to the true number to retain reliable length information
        # dataset.sample_ids = dataset.valid_ids[
        #     worker_id * split_size : (worker_id + 1) * split_size
        # ]
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)

    return np.random.seed(np.random.get_state()[1][0] + worker_id)


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self,
        batch_size,
        train=None,
        validation=None,
        test=None,
        predict=None,
        wrap=False,
        num_workers=None,
        shuffle_test_loader=False,
        use_worker_init_fn=False,
        shuffle_val_dataloader=False,
        num_val_workers=0,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = {}
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        if num_val_workers is None:
            self.num_val_workers = self.num_workers
        else:
            self.num_val_workers = num_val_workers
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(
                self._val_dataloader, shuffle=shuffle_val_dataloader
            )
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(
                self._test_dataloader, shuffle=shuffle_test_loader
            )
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap
        self.datasets = None

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = {
            k: instantiate_from_config(c) for k, c in self.dataset_configs.items()
        }
        if self.wrap:
            self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()}

    def _train_dataloader(self):
        is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
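        # SingleConceptDataset workers are re-seeded via worker_init_fn above so
        # each dataloader worker draws from a different numpy random state.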
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets["train"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            worker_init_fn=init_fn,
        )

    def _val_dataloader(self, shuffle=False):
        if (
            isinstance(self.datasets["validation"], SingleConceptDataset)
            or self.use_worker_init_fn
        ):
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets["validation"],
            batch_size=self.batch_size,
            num_workers=self.num_val_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
        )

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = isinstance(self.datasets["test"], SingleConceptDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        is_iterable_dataset = False

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(
            self.datasets["test"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
        )

    def _predict_dataloader(self, shuffle=False):
        if (
            isinstance(self.datasets["predict"], SingleConceptDataset)
            or self.use_worker_init_fn
        ):
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets["predict"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
        )


class SetupCallback(Callback):
    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
    ):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir

    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            mod_logger.info("Stopping execution and saving final checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)
        else:
            # The ModelCheckpoint callback created the log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    def __init__(
        self,
        batch_frequency,
        max_images,
        clamp=True,
        increase_log_steps=True,
        rescale=True,
        disabled=False,
        log_on_batch_idx=False,
        log_first_step=False,
        log_images_kwargs=None,
        log_all_val=False,
        concept_label=None,
    ):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {}
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step
        self.log_all_val = log_all_val
        self.concept_label = concept_label

    @rank_zero_only
    def log_local(
        self, save_dir, split, images, global_step, current_epoch, batch_idx
    ):
        root = os.path.join(save_dir, "logs", "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = (
f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png" ) path = os.path.join(root, filename) os.makedirs(os.path.split(path)[0], exist_ok=True) Image.fromarray(grid).save(path) def log_img(self, pl_module, batch, batch_idx, split="train"): # always generate the concept label batch["txt"][0] = self.concept_label check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step if self.log_all_val and split == "val": should_log = True else: should_log = self.check_frequency(check_idx) if ( should_log and (batch_idx % self.batch_freq == 0) and hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0 ): logger = type(pl_module.logger) is_train = pl_module.training if is_train: pl_module.eval() with torch.no_grad(): images = pl_module.log_images( batch, split=split, **self.log_images_kwargs ) for k in images: N = min(images[k].shape[0], self.max_images) images[k] = images[k][:N] if isinstance(images[k], torch.Tensor): images[k] = images[k].detach().cpu() if self.clamp: images[k] = torch.clamp(images[k], -1.0, 1.0) self.log_local( pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx, ) logger_log_images = self.logger_log_images.get( logger, lambda *args, **kwargs: None ) logger_log_images(pl_module, images, pl_module.global_step, split) if is_train: pl_module.train() def check_frequency(self, check_idx): if (check_idx % self.batch_freq) == 0 and ( check_idx > 0 or self.log_first_step ): return True return False def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): self.log_img(pl_module, batch, batch_idx, split="train") def on_validation_batch_end( self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx ): if not self.disabled and pl_module.global_step > 0: self.log_img(pl_module, batch, batch_idx, split="val") if hasattr(pl_module, "calibrate_grad_norm"): if ( pl_module.calibrate_grad_norm and batch_idx % 25 == 0 ) and batch_idx > 0: self.log_gradients(trainer, pl_module, batch_idx=batch_idx) class CUDACallback(Callback): # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py def on_train_epoch_start(self, trainer, pl_module): # Reset the memory use counter if "cuda" in get_device(): torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) torch.cuda.synchronize(trainer.strategy.root_device.index) self.start_time = time.time() # noqa def on_train_epoch_end(self, trainer, pl_module): # noqa if "cuda" in get_device(): torch.cuda.synchronize(trainer.strategy.root_device.index) max_memory = ( torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2**20 ) epoch_time = time.time() - self.start_time try: max_memory = trainer.training_type_plugin.reduce(max_memory) epoch_time = trainer.training_type_plugin.reduce(epoch_time) rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") except AttributeError: pass def train_diffusion_model( concept_label, concept_images_dir, class_label, class_images_dir, weights_location=config.DEFAULT_MODEL, logdir="logs", learning_rate=1e-6, accumulate_grad_batches=32, resume=None, ): """ Train a diffusion model on a single concept. 
    `accumulate_grad_batches` is used to simulate a larger batch size:
    https://arxiv.org/pdf/1711.00489.pdf

    See the example invocation at the bottom of this module.
    """
    if DDPStrategy is None:
        raise ImportError("Please install pytorch-lightning>=1.6.0 to train a model")

    batch_size = 1
    seed = 23
    num_workers = 1
    num_val_workers = 0

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    logdir = os.path.join(logdir, now)

    ckpt_output_dir = os.path.join(logdir, "checkpoints")
    cfg_output_dir = os.path.join(logdir, "configs")
    seed_everything(seed)
    model = get_diffusion_model(  # noqa
        weights_location=weights_location, half_mode=False, for_training=True
    )._model

    # scale the learning rate by the effective (accumulated) batch size
    model.learning_rate = learning_rate * accumulate_grad_batches * batch_size

    # add callback which sets up log directory
    default_callbacks_cfg = {
        "setup_callback": {
            "target": "imaginairy.train.SetupCallback",
            "params": {
                "resume": False,
                "now": now,
                "logdir": logdir,
                "ckptdir": ckpt_output_dir,
                "cfgdir": cfg_output_dir,
            },
        },
        "image_logger": {
            "target": "imaginairy.train.ImageLogger",
            "params": {
                "batch_frequency": 10,
                "max_images": 1,
                "clamp": True,
                "increase_log_steps": False,
                "log_first_step": True,
                "log_all_val": True,
                "concept_label": concept_label,
                "log_images_kwargs": {
                    "use_ema_scope": True,
                    "inpaint": False,
                    "plot_progressive_rows": False,
                    "plot_diffusion_rows": False,
                    "N": 1,
                    "unconditional_guidance_scale": 7.5,
                    "unconditional_guidance_label": [""],
                    "ddim_steps": 20,
                },
            },
        },
        "learning_rate_logger": {
            "target": "imaginairy.train.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
                # "log_momentum": True
            },
        },
        "cuda_callback": {"target": "imaginairy.train.CUDACallback"},
    }

    default_modelckpt_cfg = {
        "target": "pytorch_lightning.callbacks.ModelCheckpoint",
        "params": {
            "dirpath": ckpt_output_dir,
            "filename": "{epoch:06}",
            "verbose": True,
            "save_last": True,
            "every_n_train_steps": 50,
            "save_top_k": -1,
            "monitor": None,
        },
    }

    modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg)
    default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})

    callbacks_cfg = OmegaConf.create(default_callbacks_cfg)

    dataset_config = {
        "concept_label": concept_label,
        "concept_images_dir": concept_images_dir,
        "class_label": class_label,
        "class_images_dir": class_images_dir,
        "image_transforms": [
            {
                "target": "torchvision.transforms.Resize",
                "params": {"size": 512, "interpolation": 3},
            },
            {"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
        ],
    }

    data_module_config = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "num_val_workers": num_val_workers,
        "train": {
            "target": "imaginairy.training_tools.single_concept.SingleConceptDataset",
            "params": dataset_config,
        },
    }
    trainer = Trainer(
        benchmark=True,
        num_sanity_val_steps=0,
        accumulate_grad_batches=accumulate_grad_batches,
        strategy=DDPStrategy(),
        callbacks=[
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg  # noqa
        ],
        gpus=1,
        default_root_dir=".",
    )
    trainer.logdir = logdir

    data = DataModuleFromConfig(**data_module_config)
    # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
    # calling these ourselves should not be necessary, but it is.
    # lightning still takes care of proper multiprocessing though
    data.prepare_data()
    data.setup()

    def melk(*args, **kwargs):
        # save a checkpoint when the training job receives SIGUSR1
        # (e.g. from a cluster scheduler) or when training fails
        if trainer.global_rank == 0:
            mod_logger.info("Summoning checkpoint.")
            ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    signal.signal(signal.SIGUSR1, melk)
    try:
        try:
            trainer.fit(model, data)
        except Exception:
            melk()
            raise
    finally:
        mod_logger.info(trainer.profiler.summary())
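
# A minimal usage sketch of `train_diffusion_model`, assuming a folder of concept
# photos and a folder of generic class images already exist on disk. The labels and
# directory paths below are hypothetical placeholders, not values shipped with
# imaginairy; the keyword defaults mirror the function signature above.
if __name__ == "__main__":
    train_diffusion_model(
        concept_label="john smith person",  # hypothetical label for the new concept
        concept_images_dir="./images/john-smith",  # hypothetical dir of concept photos
        class_label="person",  # hypothetical generic class the concept belongs to
        class_images_dir="./images/person-class",  # hypothetical dir of class images
        weights_location=config.DEFAULT_MODEL,
        logdir="logs",
        learning_rate=1e-6,
        accumulate_grad_batches=32,
    )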