refactor: remove training code

Bryce 2023-12-20 12:28:07 -08:00 committed by Bryce Drennan
parent 616f686ed2
commit 372453e645
7 changed files with 1 addition and 1135 deletions

View File

@@ -32,6 +32,7 @@
- 3d diffusion https://huggingface.co/stabilityai/stable-zero123
- magic animate
- consistency decoder
- https://github.com/XPixelGroup/HAT
### Old Todo

View File

@@ -1,263 +0,0 @@
"""CLI commands for model training and image preparation"""
import logging
import click
from imaginairy import config
logger = logging.getLogger(__name__)
@click.option(
"--concept-label",
help=(
'The concept you are training on. Usually "a photo of [person or thing] [classname]" is what you should use.'
),
required=True,
)
@click.option(
"--concept-images-dir",
type=click.Path(),
required=True,
help="Where to find the pre-processed concept images to train on.",
)
@click.option(
"--class-label",
help=(
'What class of things does the concept belong to? For example, if you are training on "a painting of George Washington", '
'you might use "a painting of a man" as the class label. We use this to prevent the model from overfitting.'
),
default="a photo of *",
)
@click.option(
"--class-images-dir",
type=click.Path(),
required=True,
help="Where to find the pre-processed class images to train on.",
)
@click.option(
"--n-class-images",
type=int,
default=300,
help="Number of class images to generate.",
)
@click.option(
"--model-weights-path",
"--model",
"model",
help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.",
show_default=True,
default=config.DEFAULT_MODEL_WEIGHTS,
)
@click.option(
"--person",
"is_person",
is_flag=True,
help="Set if images are of a person. Will use face detection and enhancement.",
)
@click.option(
"-y",
"preconfirmed",
is_flag=True,
default=False,
help="Bypass input confirmations.",
)
@click.option(
"--skip-prep",
is_flag=True,
default=False,
help="Skip the image preparation step.",
)
@click.option(
"--skip-class-img-gen",
is_flag=True,
default=False,
help="Skip the class image generation step.",
)
@click.command("train-concept")
def train_concept_cmd(
concept_label,
concept_images_dir,
class_label,
class_images_dir,
n_class_images,
model,
is_person,
preconfirmed,
skip_prep,
skip_class_img_gen,
):
"""
Teach the model a new concept (a person, thing, style, etc.).
Provided a directory of concept images, a concept token, and a class token, this command will train the model
to generate images of that concept.
\b
This happens in a 3-step process:
1. Cropping and resizing your training images. If --person is set we crop to include the face.
2. Generating a set of class images to train on. This helps prevent overfitting.
3. Training the model on the concept and class images.
The output of this command is a new model weights file that you can use with the --model option.
\b
## Instructions
1. Gather a set of images of the concept you want to train on. They should show the subject from a variety of angles
and in a variety of situations.
2. Train the model.
- Concept label: For a person, firstnamelastname should be fine.
- If all the training images are photos you should add "a photo of" to the beginning of the concept label.
- Class label: This is the category of the thing being trained on. For people this is typically "person", "man"
or "woman".
- If all the training images are photos you should add "a photo of" to the beginning of the class label.
- Class images will be generated for you if you do not provide them.
3. Stop training before it overfits. I haven't figured this out yet.
For example, if you were training on photos of a man named bill hamilton you could run the following:
\b
aimg train-concept \\
--person \\
--concept-label "photo of billhamilton man" \\
--concept-images-dir ./images/billhamilton \\
--class-label "photo of a man" \\
--class-images-dir ./images/man
When you use the model you should prompt with `firstnamelastname classname` (e.g. `billhamilton man`).
You can find a lot of relevant instructions here: https://github.com/JoePenna/Dreambooth-Stable-Diffusion
"""
from imaginairy.utils import get_device
if "mps" in get_device():
click.secho(
"⚠️ MPS (MacOS) is not supported for training. Please use a GPU or CPU.",
fg="yellow",
)
return
import os.path
from imaginairy.training_tools.image_prep import (
create_class_images,
get_image_filenames,
prep_images,
)
from imaginairy.training_tools.train import train_diffusion_model
target_size = 512
# Step 1. Crop and enhance the training images
prepped_images_path = os.path.join(concept_images_dir, "prepped-images")
image_filenames = get_image_filenames(concept_images_dir)
click.secho(
f'\n🤖🧠 Training "{concept_label}" based on {len(image_filenames)} images.\n'
)
if not skip_prep:
msg = (
f"Creating cropped copies of the {len(image_filenames)} concept images\n"
f" Is Person: {is_person}\n"
f" Source: {concept_images_dir}\n"
f" Dest: {prepped_images_path}\n"
)
logger.info(msg)
if not is_person:
click.secho("⚠️ the --person flag was not set. ", fg="yellow")
if not preconfirmed and not click.confirm("Continue?"):
return
prep_images(
images_dir=concept_images_dir, is_person=is_person, target_size=target_size
)
concept_images_dir = prepped_images_path
if not skip_class_img_gen:
# Step 2. Generate class images
class_image_filenames = get_image_filenames(class_images_dir)
images_needed = max(n_class_images - len(class_image_filenames), 0)
logger.info(f"Generating {n_class_images} class images in {class_images_dir}")
logger.info(
f"{len(class_image_filenames)} existing class images found so only generating {images_needed}."
)
if not preconfirmed and not click.confirm("Continue?"):
return
create_class_images(
class_description=class_label,
output_folder=class_images_dir,
num_images=n_class_images,
)
logger.info("Training the model...")
if not preconfirmed and not click.confirm("Continue?"):
return
# Step 3. Train the model
train_diffusion_model(
concept_label=concept_label,
concept_images_dir=concept_images_dir,
class_label=class_label,
class_images_dir=class_images_dir,
weights_location=model,
logdir="logs",
learning_rate=1e-6,
accumulate_grad_batches=32,
)
@click.argument("ckpt_paths", nargs=-1)
@click.command("prune-ckpt")
def prune_ckpt_cmd(ckpt_paths):
"""
Prune a checkpoint file.
This will remove the optimizer state from the checkpoint file.
This is useful if you want to use the checkpoint file for inference; it saves a lot of disk space.
Example:
aimg prune-ckpt ./path/to/checkpoint.ckpt
"""
from imaginairy.training_tools.prune_model import prune_diffusion_ckpt
click.secho("Pruning checkpoint files...")
for p in ckpt_paths:
prune_diffusion_ckpt(p)
@click.argument(
"images_dir",
required=True,
)
@click.option(
"--person",
"is_person",
is_flag=True,
help="Set if images are of a person. Will use face detection and enhancement.",
)
@click.option(
"--target-size",
default=512,
type=int,
show_default=True,
)
@click.command("prep-images")
def prep_images_cmd(images_dir, is_person, target_size):
"""
Prepare a folder of images for training.
Prepped images will be written to the `prepped-images` subfolder.
All images will be cropped and resized to (default) 512x512.
Upscaling and face enhancement will be applied as needed to smaller images.
Examples:
aimg prep-images --person ./images/selfies
aimg prep-images ./images/toy-train
"""
from imaginairy.training_tools.image_prep import prep_images
prep_images(images_dir=images_dir, is_person=is_person, target_size=target_size)

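For orientation, these three commands were exposed as subcommands of the `aimg` CLI. Below is a minimal sketch of that wiring, assuming a plain click group; the group name `aimg` and registration site are assumptions for illustration, as the real registration lived elsewhere in the CLI module.

import click

@click.group()
def aimg():
    """Illustrative stand-in for the real aimg entry point."""

# after @click.command runs, the functions above are click Command objects
aimg.add_command(train_concept_cmd)
aimg.add_command(prune_ckpt_cmd)
aimg.add_command(prep_images_cmd)

if __name__ == "__main__":
    aimg()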
View File

@@ -1,158 +0,0 @@
"""Functions for image preprocessing and generation"""
import logging
import os
import os.path
import re
from PIL import Image
from tqdm import tqdm
from imaginairy.api import imagine
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.facecrop import detect_faces, generate_face_crops
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.schema import ImaginePrompt, LazyLoadingImage
from imaginairy.vendored.smart_crop import SmartCrop
logger = logging.getLogger(__name__)
def get_image_filenames(folder):
filenames = []
for filename in os.listdir(folder):
if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
continue
if filename.startswith("."):
continue
filenames.append(filename)
return filenames
def prep_images(
images_dir, is_person=False, output_folder_name="prepped-images", target_size=512
):
"""
Crops and resizes a directory of images in preparation for training.
If is_person=True, it will detect the face and produce several crops at different zoom levels. For crops that
are too small, it will use the face restoration model to enhance the faces.
For non-person images, it will use the smartcrop algorithm to crop the image to the most interesting part. If the
input image is too small it will be upscaled.
Prep will go a lot faster if all the images are big enough to not require upscaling.
"""
output_folder = os.path.join(images_dir, output_folder_name)
os.makedirs(output_folder, exist_ok=True)
logger.info(f"Prepping images in {images_dir} to {output_folder}")
image_filenames = get_image_filenames(images_dir)
pbar = tqdm(image_filenames)
for filename in pbar:
pbar.set_description(filename)
input_path = os.path.join(images_dir, filename)
img = LazyLoadingImage(filepath=input_path).convert("RGB")
if is_person:
face_rois = detect_faces(img)
if len(face_rois) == 0:
logger.info(f"No faces detected in image {filename}, skipping")
continue
if len(face_rois) > 1:
logger.info(f"Multiple faces detected in image {filename}, skipping")
continue
face_roi = face_rois[0]
face_roi_crops = generate_face_crops(
face_roi, max_width=img.width, max_height=img.height
)
for n, face_roi_crop in enumerate(face_roi_crops):
cropped_output_path = os.path.join(
output_folder, f"{filename}_[alt-{n:02d}].jpg"
)
if os.path.exists(cropped_output_path):
logger.debug(
f"Skipping {cropped_output_path} because it already exists"
)
continue
x1, y1, x2, y2 = face_roi_crop
crop_width = x2 - x1
crop_height = y2 - y1
if crop_width != crop_height:
logger.info(
f"Face ROI crop for {filename} {crop_width}x{crop_height} is not square, skipping"
)
continue
cropped_img = img.crop(face_roi_crop)
if crop_width < target_size:
logger.info(f"Upscaling {filename} {face_roi_crop}")
cropped_img = cropped_img.resize(
(target_size, target_size), resample=Image.Resampling.LANCZOS
)
cropped_img = enhance_faces(cropped_img, fidelity=1)
else:
cropped_img = cropped_img.resize(
(target_size, target_size), resample=Image.Resampling.LANCZOS
)
cropped_img.save(cropped_output_path, quality=95)
else:
# scale image so that the shortest side is target_size
n = 0
cropped_output_path = os.path.join(output_folder, f"{filename}_{n}.jpg")
if os.path.exists(cropped_output_path):
logger.debug(
f"Skipping {cropped_output_path} because it already exists"
)
continue
if img.width < target_size or img.height < target_size:
# upscale the image if it's too small
logger.info(f"Upscaling {filename}")
img = upscale_image(img)
if img.width > img.height:
scale_factor = target_size / img.height
else:
scale_factor = target_size / img.width
# downscale so shortest side is target_size
new_width = int(round(img.width * scale_factor))
new_height = int(round(img.height * scale_factor))
img = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
result = SmartCrop().crop(img, width=target_size, height=target_size)
box = (
result["top_crop"]["x"],
result["top_crop"]["y"],
result["top_crop"]["width"] + result["top_crop"]["x"],
result["top_crop"]["height"] + result["top_crop"]["y"],
)
cropped_image = img.crop(box)
cropped_image.save(cropped_output_path, quality=95)
logger.info(f"Image Prep complete. Review output at {output_folder}")
def prompt_normalized(prompt):
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "-", prompt)[:130]
def create_class_images(class_description, output_folder, num_images=200):
"""
Generate images of class_description.
"""
existing_images = get_image_filenames(output_folder)
existing_image_count = len(existing_images)
class_slug = prompt_normalized(class_description)
while existing_image_count < num_images:
prompt = ImaginePrompt(class_description, steps=20)
result = next(iter(imagine([prompt])))
if result.is_nsfw:
continue
dest = os.path.join(
output_folder, f"{existing_image_count:03d}_{class_slug}.jpg"
)
result.save(dest)
existing_image_count += 1

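A small standalone check of the non-person resize math and the prompt slug above; the example dimensions and prompt are made up for illustration.

import re

def scaled_dims(width, height, target_size=512):
    # mirrors the branch above: scale so the shortest side lands on target_size
    scale = target_size / min(width, height)
    return round(width * scale), round(height * scale)

assert scaled_dims(1024, 768) == (683, 512)  # landscape: height becomes 512
# note: the real code first runs upscale_image on anything under target_size;
# this helper only shows the scale-factor arithmetic
assert scaled_dims(400, 600) == (512, 768)   # portrait: width becomes 512

# prompt_normalized collapses runs of disallowed characters into single dashes
assert re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "-", "a photo of a man, 4k!")[:130] == "a-photo-of-a-man,-4k-"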
View File

@@ -1,41 +0,0 @@
"""Functions for pruning diffusion models"""
import logging
import os
import torch
logger = logging.getLogger(__name__)
def prune_diffusion_ckpt(ckpt_path, dst_path=None):
if dst_path is None:
dst_path = f"{os.path.splitext(ckpt_path)[0]}-pruned.ckpt"
data = torch.load(ckpt_path, map_location="cpu")
new_data = prune_model_data(data)
torch.save(new_data, dst_path)
size_initial = os.path.getsize(ckpt_path)
newsize = os.path.getsize(dst_path)
msg = (
f"New ckpt size: {newsize * 1e-9:.2f} GB. "
f"Saved {(size_initial - newsize) * 1e-9:.2f} GB by removing optimizer states"
)
logger.info(msg)
def prune_model_data(data, only_keep_ema=True):
data.pop("optimizer_states", None)
if only_keep_ema:
state_dict = data["state_dict"]
model_keys = [k for k in state_dict if k.startswith("model.")]
for model_key in model_keys:
ema_key = "model_ema." + model_key[6:].replace(".", "")
state_dict[model_key] = state_dict[ema_key]
del state_dict[ema_key]
return data

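The EMA swap above depends on how the EMA wrapper names its copies: the "model." prefix is dropped and all remaining dots are stripped before re-prefixing with "model_ema.". A worked example of the key mapping, using a typical Stable Diffusion state-dict key:

model_key = "model.diffusion_model.out.2.weight"
# drop the 6-character "model." prefix, strip dots, re-prefix with "model_ema."
ema_key = "model_ema." + model_key[6:].replace(".", "")
assert ema_key == "model_ema.diffusion_modelout2weight"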
View File

@@ -1,138 +0,0 @@
"""Classes for single-concept model finetuning"""
import os
import random
import re
from abc import abstractmethod
from collections import defaultdict
from einops import rearrange
from omegaconf import ListConfig
from PIL import Image
from torch.utils.data import Dataset, IterableDataset
from torchvision.transforms import transforms
from imaginairy.utils import instantiate_from_config
class Txt2ImgIterableBaseDataset(IterableDataset):
"""
Define an interface to make the IterableDatasets for text2img data chainable.
"""
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
self.valid_ids = valid_ids
self.sample_ids = valid_ids
self.size = size
print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass
def _rearrange(x):
return rearrange(x * 2.0 - 1.0, "c h w -> h w c")
class SingleConceptDataset(Dataset):
"""
Dataset for finetuning a model on a single concept.
Similar to "dreambooth"
"""
def __init__(
self,
concept_label,
class_label,
concept_images_dir,
class_images_dir,
image_transforms=None,
):
self.concept_label = concept_label
self.class_label = class_label
self.concept_images_dir = concept_images_dir
self.class_images_dir = class_images_dir
if isinstance(image_transforms, (ListConfig, list)):
image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
image_transforms.extend(
[
transforms.ToTensor(),
transforms.Lambda(_rearrange),
]
)
image_transforms = transforms.Compose(image_transforms)
self.image_transforms = image_transforms
self._concept_image_filename_groups = None
self._class_image_filenames = None
def __len__(self):
return len(self.concept_image_filename_groups) * 2
def __getitem__(self, idx):
if idx % 2:
img_group = self.concept_image_filename_groups[int(idx / 2)]
img_filename = random.choice(img_group)
img_path = os.path.join(self.concept_images_dir, img_filename)
txt = self.concept_label
else:
img_path = os.path.join(
self.class_images_dir, self.class_image_filenames[int(idx / 2)]
)
txt = self.class_label
try:
image = Image.open(img_path).convert("RGB")
except RuntimeError as e:
msg = f"Could not read image {img_path}"
raise RuntimeError(msg) from e
image = self.image_transforms(image)
data = {"image": image, "txt": txt}
return data
@property
def concept_image_filename_groups(self):
if self._concept_image_filename_groups is None:
self._concept_image_filename_groups = _load_image_filenames_and_alts(
self.concept_images_dir
)
return self._concept_image_filename_groups
@property
def class_image_filenames(self):
if self._class_image_filenames is None:
self._class_image_filenames = _load_image_filenames(self.class_images_dir)
return self._class_image_filenames
@property
def num_records(self):
return len(self)
def _load_image_filenames_and_alts(img_dir, image_extensions=(".jpg", ".jpeg", ".png")):
"""Loads images into groups (filenames tagged with `[alt-{n:02d}]` are grouped together)."""
image_filenames = _load_image_filenames(img_dir, image_extensions)
grouped_img_filenames = defaultdict(list)
for filename in image_filenames:
base_filename = re.sub(r"\[alt-\d*\]", "", filename)
grouped_img_filenames[base_filename].append(filename)
return list(grouped_img_filenames.values())
def _load_image_filenames(img_dir, image_extensions=(".jpg", ".jpeg", ".png")):
image_filenames = []
for filename in os.listdir(img_dir):
if filename.lower().endswith(image_extensions) and not filename.startswith("."):
image_filenames.append(filename)
random.shuffle(image_filenames)
return image_filenames

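To make the grouping concrete, here is the same [alt-NN] regex applied to filenames in the format prep_images writes; the names are invented for illustration.

import re
from collections import defaultdict

filenames = [
    "bill1.jpg_[alt-00].jpg",  # two zoom levels cropped from the same photo
    "bill1.jpg_[alt-01].jpg",
    "bill2.jpg_[alt-00].jpg",
]
groups = defaultdict(list)
for name in filenames:
    groups[re.sub(r"\[alt-\d*\]", "", name)].append(name)

# both bill1 crops share one group, so one training "record" picks between them
assert [len(g) for g in groups.values()] == [2, 1]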
View File

@@ -1,535 +0,0 @@
"""Code for training diffusion models"""
import datetime
import logging
import os
import signal
import time
from functools import partial
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
try:
from pytorch_lightning.strategies import DDPStrategy
except ImportError:
# let's not break all of imaginairy just because a training import doesn't exist in an older version of PL
# Use >= 1.6.0 to make this work
DDPStrategy = None # type: ignore
import contextlib
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.utils.data import DataLoader, Dataset
from imaginairy import config
from imaginairy.training_tools.single_concept import SingleConceptDataset
from imaginairy.utils import get_device, instantiate_from_config
from imaginairy.utils.model_manager import get_diffusion_model
mod_logger = logging.getLogger(__name__)
referenced_by_string = [LearningRateMonitor]
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset."""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
if isinstance(dataset, SingleConceptDataset):
# split_size = dataset.num_records // worker_info.num_workers
# reset num_records to the true number to retain reliable length information
# dataset.sample_ids = dataset.valid_ids[
# worker_id * split_size : (worker_id + 1) * split_size
# ]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
return np.random.seed(np.random.get_state()[1][0] + worker_id)
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=False,
num_val_workers=0,
):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = {}
self.num_workers = num_workers if num_workers is not None else batch_size * 2
if num_val_workers is None:
self.num_val_workers = self.num_workers
else:
self.num_val_workers = num_val_workers
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(
self._val_dataloader, shuffle=shuffle_val_dataloader
)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(
self._test_dataloader, shuffle=shuffle_test_loader
)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.wrap = wrap
self.datasets = None
def prepare_data(self):
for data_cfg in self.dataset_configs.values():
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = {
k: instantiate_from_config(c) for k, c in self.dataset_configs.items()
}
if self.wrap:
self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()}
def _train_dataloader(self):
# worker_init_fn is applied unconditionally for the train loader
return DataLoader(
self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
worker_init_fn=worker_init_fn,
)
def _val_dataloader(self, shuffle=False):
if (
isinstance(self.datasets["validation"], SingleConceptDataset)
or self.use_worker_init_fn
):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_val_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
# iterable-style datasets are not used for tests, so shuffle is passed through unchanged
return DataLoader(
self.datasets["test"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _predict_dataloader(self, shuffle=False):
if (
isinstance(self.datasets["predict"], SingleConceptDataset)
or self.use_worker_init_fn
):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
)
class SetupCallback(Callback):
def __init__(
self,
resume,
now,
logdir,
ckptdir,
cfgdir,
):
super().__init__()
self.resume = resume
self.now = now
self.logdir = logdir
self.ckptdir = ckptdir
self.cfgdir = cfgdir
def on_keyboard_interrupt(self, trainer, pl_module):
if trainer.global_rank == 0:
mod_logger.info("Stopping execution and saving final checkpoint.")
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def on_fit_start(self, trainer, pl_module):
if trainer.global_rank == 0:
# Create logdirs and save configs
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
else:
# ModelCheckpoint callback created log directory --- remove it
if not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
with contextlib.suppress(FileNotFoundError):
os.rename(self.logdir, dst)
class ImageLogger(Callback):
def __init__(
self,
batch_frequency,
max_images,
clamp=True,
increase_log_steps=True,
rescale=True,
disabled=False,
log_on_batch_idx=False,
log_first_step=False,
log_images_kwargs=None,
log_all_val=False,
concept_label=None,
):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {}
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
self.log_all_val = log_all_val
self.concept_label = concept_label
@rank_zero_only
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "logs", "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = (
f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png"
)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
# always generate the concept label
batch["txt"][0] = self.concept_label
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if self.log_all_val and split == "val":
should_log = True
else:
should_log = self.check_frequency(check_idx)
if (
should_log
and (batch_idx % self.batch_freq == 0)
and hasattr(pl_module, "log_images")
and callable(pl_module.log_images)
and self.max_images > 0
):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(
batch, split=split, **self.log_images_kwargs
)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1.0, 1.0)
self.log_local(
pl_module.logger.save_dir,
split,
images,
pl_module.global_step,
pl_module.current_epoch,
batch_idx,
)
logger_log_images = self.logger_log_images.get(
logger, lambda *args, **kwargs: None
)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if (check_idx % self.batch_freq) == 0 and (
check_idx > 0 or self.log_first_step
):
return True
return False
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
self.log_img(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if (
hasattr(pl_module, "calibrate_grad_norm")
and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0)
and batch_idx > 0
):
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
def on_train_epoch_start(self, trainer, pl_module):
# Reset the memory use counter
if "cuda" in get_device():
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
torch.cuda.synchronize(trainer.strategy.root_device.index)
self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module):
if "cuda" in get_device():
torch.cuda.synchronize(trainer.strategy.root_device.index)
max_memory = (
torch.cuda.max_memory_allocated(trainer.strategy.root_device.index)
/ 2**20
)
epoch_time = time.time() - self.start_time
try:
max_memory = trainer.training_type_plugin.reduce(max_memory)
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass
def train_diffusion_model(
concept_label,
concept_images_dir,
class_label,
class_images_dir,
weights_location=config.DEFAULT_MODEL_WEIGHTS,
logdir="logs",
learning_rate=1e-6,
accumulate_grad_batches=32,
resume=None,
):
"""
Train a diffusion model on a single concept.
accumulate_grad_batches is used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf
"""
if DDPStrategy is None:
msg = "Please install pytorch-lightning>=1.6.0 to train a model"
raise ImportError(msg)
batch_size = 1
seed = 23
num_workers = 1
num_val_workers = 0
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # noqa: DTZ005
logdir = os.path.join(logdir, now)
ckpt_output_dir = os.path.join(logdir, "checkpoints")
cfg_output_dir = os.path.join(logdir, "configs")
seed_everything(seed)
model = get_diffusion_model(
weights_location=weights_location, half_mode=False, for_training=True
)._model
model.learning_rate = learning_rate * accumulate_grad_batches * batch_size
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "imaginairy.train.SetupCallback",
"params": {
"resume": False,
"now": now,
"logdir": logdir,
"ckptdir": ckpt_output_dir,
"cfgdir": cfg_output_dir,
},
},
"image_logger": {
"target": "imaginairy.train.ImageLogger",
"params": {
"batch_frequency": 10,
"max_images": 1,
"clamp": True,
"increase_log_steps": False,
"log_first_step": True,
"log_all_val": True,
"concept_label": concept_label,
"log_images_kwargs": {
"use_ema_scope": True,
"inpaint": False,
"plot_progressive_rows": False,
"plot_diffusion_rows": False,
"N": 1,
"unconditional_guidance_scale:": 7.5,
"unconditional_guidance_label": [""],
"ddim_steps": 20,
},
},
},
"learning_rate_logger": {
"target": "imaginairy.train.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
},
},
"cuda_callback": {"target": "imaginairy.train.CUDACallback"},
}
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckpt_output_dir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
"every_n_train_steps": 50,
"save_top_k": -1,
"monitor": None,
},
}
modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg)
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
callbacks_cfg = OmegaConf.create(default_callbacks_cfg)
dataset_config = {
"concept_label": concept_label,
"concept_images_dir": concept_images_dir,
"class_label": class_label,
"class_images_dir": class_images_dir,
"image_transforms": [
{
"target": "torchvision.transforms.Resize",
"params": {"size": 512, "interpolation": 3},
},
{"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
],
}
data_module_config = {
"batch_size": batch_size,
"num_workers": num_workers,
"num_val_workers": num_val_workers,
"train": {
"target": "imaginairy.training_tools.single_concept.SingleConceptDataset",
"params": dataset_config,
},
}
trainer = Trainer(
benchmark=True,
num_sanity_val_steps=0,
accumulate_grad_batches=accumulate_grad_batches,
strategy=DDPStrategy(),
callbacks=[instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg],
gpus=1,
default_root_dir=".",
)
trainer.logdir = logdir
data = DataModuleFromConfig(**data_module_config)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
def melk(*args, **kwargs):
if trainer.global_rank == 0:
mod_logger.info("Summoning checkpoint.")
ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
signal.signal(signal.SIGUSR1, melk)
try:
try:
trainer.fit(model, data)
except Exception:
melk()
raise
finally:
mod_logger.info(trainer.profiler.summary())
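One detail easy to miss above: the learning rate passed to train_diffusion_model is not used as-is; it is multiplied by the simulated batch size before being set on the model. With the defaults from train_concept_cmd:

learning_rate = 1e-6             # default passed in by train_concept_cmd
accumulate_grad_batches = 32
batch_size = 1                   # hardcoded in train_diffusion_model
effective_lr = learning_rate * accumulate_grad_batches * batch_size
print(effective_lr)              # 3.2e-05 is what model.learning_rate receives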