mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
refactor: remove training code
This commit is contained in:
parent
616f686ed2
commit
372453e645
@ -32,6 +32,7 @@
|
||||
- 3d diffusion https://huggingface.co/stabilityai/stable-zero123
|
||||
- magic animate
|
||||
- consistency decoder
|
||||
- https://github.com/XPixelGroup/HAT
|
||||
|
||||
### Old Todo
|
||||
|
||||
|
@ -1,263 +0,0 @@
|
||||
"""CLI commands for model training and image preparation"""
|
||||
|
||||
import logging
|
||||
|
||||
import click
|
||||
|
||||
from imaginairy import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@click.option(
|
||||
"--concept-label",
|
||||
help=(
|
||||
'The concept you are training on. Usually "a photo of [person or thing] [classname]" is what you should use.'
|
||||
),
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--concept-images-dir",
|
||||
type=click.Path(),
|
||||
required=True,
|
||||
help="Where to find the pre-processed concept images to train on.",
|
||||
)
|
||||
@click.option(
|
||||
"--class-label",
|
||||
help=(
|
||||
'What class of things does the concept belong to. For example, if you are training on "a painting of a George Washington", '
|
||||
'you might use "a painting of a man" as the class label. We use this to prevent the model from overfitting.'
|
||||
),
|
||||
default="a photo of *",
|
||||
)
|
||||
@click.option(
|
||||
"--class-images-dir",
|
||||
type=click.Path(),
|
||||
required=True,
|
||||
help="Where to find the pre-processed class images to train on.",
|
||||
)
|
||||
@click.option(
|
||||
"--n-class-images",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Number of class images to generate.",
|
||||
)
|
||||
@click.option(
|
||||
"--model-weights-path",
|
||||
"--model",
|
||||
"model",
|
||||
help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.",
|
||||
show_default=True,
|
||||
default=config.DEFAULT_MODEL_WEIGHTS,
|
||||
)
|
||||
@click.option(
|
||||
"--person",
|
||||
"is_person",
|
||||
is_flag=True,
|
||||
help="Set if images are of a person. Will use face detection and enhancement.",
|
||||
)
|
||||
@click.option(
|
||||
"-y",
|
||||
"preconfirmed",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Bypass input confirmations.",
|
||||
)
|
||||
@click.option(
|
||||
"--skip-prep",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Skip the image preparation step.",
|
||||
)
|
||||
@click.option(
|
||||
"--skip-class-img-gen",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Skip the class image generation step.",
|
||||
)
|
||||
@click.command("train-concept")
|
||||
def train_concept_cmd(
|
||||
concept_label,
|
||||
concept_images_dir,
|
||||
class_label,
|
||||
class_images_dir,
|
||||
n_class_images,
|
||||
model,
|
||||
is_person,
|
||||
preconfirmed,
|
||||
skip_prep,
|
||||
skip_class_img_gen,
|
||||
):
|
||||
"""
|
||||
Teach the model a new concept (a person, thing, style, etc).
|
||||
|
||||
Provided a directory of concept images, a concept token, and a class token, this command will train the model
|
||||
to generate images of that concept.
|
||||
|
||||
\b
|
||||
This happens in a 3-step process:
|
||||
1. Cropping and resizing your training images. If --person is set we crop to include the face.
|
||||
2. Generating a set of class images to train on. This helps prevent overfitting.
|
||||
3. Training the model on the concept and class images.
|
||||
|
||||
The output of this command is a new model weights file that you can use with the --model option.
|
||||
|
||||
\b
|
||||
## Instructions
|
||||
1. Gather a set of images of the concept you want to train on. They should show the subject from a variety of angles
|
||||
and in a variety of situations.
|
||||
2. Train the model.
|
||||
- Concept label: For a person, firstnamelastname should be fine.
|
||||
- If all the training images are photos you should add "a photo of" to the beginning of the concept label.
|
||||
- Class label: This is the category of the things beings trained on. For people this is typically "person", "man"
|
||||
or "woman".
|
||||
- If all the training images are photos you should add "a photo of" to the beginning of the class label.
|
||||
- CLass images will be generated for you if you do not provide them.
|
||||
3. Stop training before it overfits. I haven't figured this out yet.
|
||||
|
||||
|
||||
For example, if you were training on photos of a man named bill hamilton you could run the following:
|
||||
|
||||
\b
|
||||
aimg train-concept \\
|
||||
--person \\
|
||||
--concept-label "photo of billhamilton man" \\
|
||||
--concept-images-dir ./images/billhamilton \\
|
||||
--class-label "photo of a man" \\
|
||||
--class-images-dir ./images/man
|
||||
|
||||
When you use the model you should prompt with `firstnamelastname classname` (e.g. `billhamilton man`).
|
||||
|
||||
You can find a lot of relevant instructions here: https://github.com/JoePenna/Dreambooth-Stable-Diffusion
|
||||
"""
|
||||
from imaginairy.utils import get_device
|
||||
|
||||
if "mps" in get_device():
|
||||
click.secho(
|
||||
"⚠️ MPS (MacOS) is not supported for training. Please use a GPU or CPU.",
|
||||
fg="yellow",
|
||||
)
|
||||
return
|
||||
|
||||
import os.path
|
||||
|
||||
from imaginairy.training_tools.image_prep import (
|
||||
create_class_images,
|
||||
get_image_filenames,
|
||||
prep_images,
|
||||
)
|
||||
from imaginairy.training_tools.train import train_diffusion_model
|
||||
|
||||
target_size = 512
|
||||
# Step 1. Crop and enhance the training images
|
||||
prepped_images_path = os.path.join(concept_images_dir, "prepped-images")
|
||||
image_filenames = get_image_filenames(concept_images_dir)
|
||||
click.secho(
|
||||
f'\n🤖🧠 Training "{concept_label}" based on {len(image_filenames)} images.\n'
|
||||
)
|
||||
|
||||
if not skip_prep:
|
||||
msg = (
|
||||
f"Creating cropped copies of the {len(image_filenames)} concept images\n"
|
||||
f" Is Person: {is_person}\n"
|
||||
f" Source: {concept_images_dir}\n"
|
||||
f" Dest: {prepped_images_path}\n"
|
||||
)
|
||||
logger.info(msg)
|
||||
if not is_person:
|
||||
click.secho("⚠️ the --person flag was not set. ", fg="yellow")
|
||||
|
||||
if not preconfirmed and not click.confirm("Continue?"):
|
||||
return
|
||||
|
||||
prep_images(
|
||||
images_dir=concept_images_dir, is_person=is_person, target_size=target_size
|
||||
)
|
||||
concept_images_dir = prepped_images_path
|
||||
|
||||
if not skip_class_img_gen:
|
||||
# Step 2. Generate class images
|
||||
class_image_filenames = get_image_filenames(class_images_dir)
|
||||
images_needed = max(n_class_images - len(class_image_filenames), 0)
|
||||
logger.info(f"Generating {n_class_images} class images in {class_images_dir}")
|
||||
logger.info(
|
||||
f"{len(class_image_filenames)} existing class images found so only generating {images_needed}."
|
||||
)
|
||||
if not preconfirmed and not click.confirm("Continue?"):
|
||||
return
|
||||
create_class_images(
|
||||
class_description=class_label,
|
||||
output_folder=class_images_dir,
|
||||
num_images=n_class_images,
|
||||
)
|
||||
|
||||
logger.info("Training the model...")
|
||||
if not preconfirmed and not click.confirm("Continue?"):
|
||||
return
|
||||
|
||||
# Step 3. Train the model
|
||||
train_diffusion_model(
|
||||
concept_label=concept_label,
|
||||
concept_images_dir=concept_images_dir,
|
||||
class_label=class_label,
|
||||
class_images_dir=class_images_dir,
|
||||
weights_location=model,
|
||||
logdir="logs",
|
||||
learning_rate=1e-6,
|
||||
accumulate_grad_batches=32,
|
||||
)
|
||||
|
||||
|
||||
@click.argument("ckpt_paths", nargs=-1)
|
||||
@click.command("prune-ckpt")
|
||||
def prune_ckpt_cmd(ckpt_paths):
|
||||
"""
|
||||
Prune a checkpoint file.
|
||||
|
||||
This will remove the optimizer state from the checkpoint file.
|
||||
This is useful if you want to use the checkpoint file for inference and save a lot of disk space
|
||||
|
||||
Example:
|
||||
aimg prune-ckpt ./path/to/checkpoint.ckpt
|
||||
"""
|
||||
from imaginairy.training_tools.prune_model import prune_diffusion_ckpt
|
||||
|
||||
click.secho("Pruning checkpoint files...")
|
||||
for p in ckpt_paths:
|
||||
prune_diffusion_ckpt(p)
|
||||
|
||||
|
||||
@click.argument(
|
||||
"images_dir",
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--person",
|
||||
"is_person",
|
||||
is_flag=True,
|
||||
help="Set if images are of a person. Will use face detection and enhancement.",
|
||||
)
|
||||
@click.option(
|
||||
"--target-size",
|
||||
default=512,
|
||||
type=int,
|
||||
show_default=True,
|
||||
)
|
||||
@click.command("prep-images")
|
||||
def prep_images_cmd(images_dir, is_person, target_size):
|
||||
"""
|
||||
Prepare a folder of images for training.
|
||||
|
||||
Prepped images will be written to the `prepped-images` subfolder.
|
||||
|
||||
All images will be cropped and resized to (default) 512x512.
|
||||
Upscaling and face enhancement will be applied as needed to smaller images.
|
||||
|
||||
Examples:
|
||||
aimg prep-images --person ./images/selfies
|
||||
aimg prep-images ./images/toy-train
|
||||
"""
|
||||
|
||||
from imaginairy.training_tools.image_prep import prep_images
|
||||
|
||||
prep_images(images_dir=images_dir, is_person=is_person, target_size=target_size)
|
@ -1,158 +0,0 @@
|
||||
"""Functions for image preprocessing and generation"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import os.path
|
||||
import re
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from imaginairy.api import imagine
|
||||
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
||||
from imaginairy.enhancers.facecrop import detect_faces, generate_face_crops
|
||||
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
||||
from imaginairy.schema import ImaginePrompt, LazyLoadingImage
|
||||
from imaginairy.vendored.smart_crop import SmartCrop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_image_filenames(folder):
|
||||
filenames = []
|
||||
for filename in os.listdir(folder):
|
||||
if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
|
||||
continue
|
||||
if filename.startswith("."):
|
||||
continue
|
||||
filenames.append(filename)
|
||||
return filenames
|
||||
|
||||
|
||||
def prep_images(
|
||||
images_dir, is_person=False, output_folder_name="prepped-images", target_size=512
|
||||
):
|
||||
"""
|
||||
Crops and resizes a directory of images in preparation for training.
|
||||
|
||||
If is_person=True, it will detect the face and produces several crops at different zoom levels. For crops that
|
||||
are too small, it will use the face restoration model to enhance the faces.
|
||||
|
||||
For non-person images, it will use the smartcrop algorithm to crop the image to the most interesting part. If the
|
||||
input image is too small it will be upscaled.
|
||||
|
||||
Prep will go a lot faster if all the images are big enough to not require upscaling.
|
||||
|
||||
"""
|
||||
output_folder = os.path.join(images_dir, output_folder_name)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
logger.info(f"Prepping images in {images_dir} to {output_folder}")
|
||||
image_filenames = get_image_filenames(images_dir)
|
||||
pbar = tqdm(image_filenames)
|
||||
for filename in pbar:
|
||||
pbar.set_description(filename)
|
||||
|
||||
input_path = os.path.join(images_dir, filename)
|
||||
img = LazyLoadingImage(filepath=input_path).convert("RGB")
|
||||
if is_person:
|
||||
face_rois = detect_faces(img)
|
||||
if len(face_rois) == 0:
|
||||
logger.info(f"No faces detected in image {filename}, skipping")
|
||||
continue
|
||||
if len(face_rois) > 1:
|
||||
logger.info(f"Multiple faces detected in image {filename}, skipping")
|
||||
continue
|
||||
face_roi = face_rois[0]
|
||||
face_roi_crops = generate_face_crops(
|
||||
face_roi, max_width=img.width, max_height=img.height
|
||||
)
|
||||
for n, face_roi_crop in enumerate(face_roi_crops):
|
||||
cropped_output_path = os.path.join(
|
||||
output_folder, f"{filename}_[alt-{n:02d}].jpg"
|
||||
)
|
||||
if os.path.exists(cropped_output_path):
|
||||
logger.debug(
|
||||
f"Skipping {cropped_output_path} because it already exists"
|
||||
)
|
||||
continue
|
||||
x1, y1, x2, y2 = face_roi_crop
|
||||
crop_width = x2 - x1
|
||||
crop_height = y2 - y1
|
||||
if crop_width != crop_height:
|
||||
logger.info(
|
||||
f"Face ROI crop for {filename} {crop_width}x{crop_height} is not square, skipping"
|
||||
)
|
||||
continue
|
||||
cropped_img = img.crop(face_roi_crop)
|
||||
|
||||
if crop_width < target_size:
|
||||
logger.info(f"Upscaling {filename} {face_roi_crop}")
|
||||
cropped_img = cropped_img.resize(
|
||||
(target_size, target_size), resample=Image.Resampling.LANCZOS
|
||||
)
|
||||
cropped_img = enhance_faces(cropped_img, fidelity=1)
|
||||
else:
|
||||
cropped_img = cropped_img.resize(
|
||||
(target_size, target_size), resample=Image.Resampling.LANCZOS
|
||||
)
|
||||
cropped_img.save(cropped_output_path, quality=95)
|
||||
else:
|
||||
# scale image so that largest dimension is target_size
|
||||
n = 0
|
||||
cropped_output_path = os.path.join(output_folder, f"{filename}_{n}.jpg")
|
||||
if os.path.exists(cropped_output_path):
|
||||
logger.debug(
|
||||
f"Skipping {cropped_output_path} because it already exists"
|
||||
)
|
||||
continue
|
||||
if img.width < target_size or img.height < target_size:
|
||||
# upscale the image if it's too small
|
||||
logger.info(f"Upscaling {filename}")
|
||||
img = upscale_image(img)
|
||||
|
||||
if img.width > img.height:
|
||||
scale_factor = target_size / img.height
|
||||
else:
|
||||
scale_factor = target_size / img.width
|
||||
|
||||
# downscale so shortest side is target_size
|
||||
new_width = int(round(img.width * scale_factor))
|
||||
new_height = int(round(img.height * scale_factor))
|
||||
img = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
result = SmartCrop().crop(img, width=target_size, height=target_size)
|
||||
|
||||
box = (
|
||||
result["top_crop"]["x"],
|
||||
result["top_crop"]["y"],
|
||||
result["top_crop"]["width"] + result["top_crop"]["x"],
|
||||
result["top_crop"]["height"] + result["top_crop"]["y"],
|
||||
)
|
||||
|
||||
cropped_image = img.crop(box)
|
||||
cropped_image.save(cropped_output_path, quality=95)
|
||||
logger.info(f"Image Prep complete. Review output at {output_folder}")
|
||||
|
||||
|
||||
def prompt_normalized(prompt):
|
||||
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "-", prompt)[:130]
|
||||
|
||||
|
||||
def create_class_images(class_description, output_folder, num_images=200):
|
||||
"""
|
||||
Generate images of class_description.
|
||||
"""
|
||||
existing_images = get_image_filenames(output_folder)
|
||||
existing_image_count = len(existing_images)
|
||||
class_slug = prompt_normalized(class_description)
|
||||
|
||||
while existing_image_count < num_images:
|
||||
prompt = ImaginePrompt(class_description, steps=20)
|
||||
result = next(iter(imagine([prompt])))
|
||||
if result.is_nsfw:
|
||||
continue
|
||||
dest = os.path.join(
|
||||
output_folder, f"{existing_image_count:03d}_{class_slug}.jpg"
|
||||
)
|
||||
result.save(dest)
|
||||
existing_image_count += 1
|
@ -1,41 +0,0 @@
|
||||
"""Functions for pruning diffusion models"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prune_diffusion_ckpt(ckpt_path, dst_path=None):
|
||||
if dst_path is None:
|
||||
dst_path = f"{os.path.splitext(ckpt_path)[0]}-pruned.ckpt"
|
||||
|
||||
data = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
new_data = prune_model_data(data)
|
||||
|
||||
torch.save(new_data, dst_path)
|
||||
|
||||
size_initial = os.path.getsize(ckpt_path)
|
||||
newsize = os.path.getsize(dst_path)
|
||||
msg = (
|
||||
f"New ckpt size: {newsize * 1e-9:.2f} GB. "
|
||||
f"Saved {(size_initial - newsize) * 1e-9:.2f} GB by removing optimizer states"
|
||||
)
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
def prune_model_data(data, only_keep_ema=True):
|
||||
data.pop("optimizer_states", None)
|
||||
if only_keep_ema:
|
||||
state_dict = data["state_dict"]
|
||||
model_keys = [k for k in state_dict if k.startswith("model.")]
|
||||
|
||||
for model_key in model_keys:
|
||||
ema_key = "model_ema." + model_key[6:].replace(".", "")
|
||||
state_dict[model_key] = state_dict[ema_key]
|
||||
del state_dict[ema_key]
|
||||
|
||||
return data
|
@ -1,138 +0,0 @@
|
||||
"""Classes for single-concept model finetuning"""
|
||||
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
from einops import rearrange
|
||||
from omegaconf import ListConfig
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset, IterableDataset
|
||||
from torchvision.transforms import transforms
|
||||
|
||||
from imaginairy.utils import instantiate_from_config
|
||||
|
||||
|
||||
class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
"""
|
||||
Define an interface to make the IterableDatasets for text2img data chainable.
|
||||
"""
|
||||
|
||||
def __init__(self, num_records=0, valid_ids=None, size=256):
|
||||
super().__init__()
|
||||
self.num_records = num_records
|
||||
self.valid_ids = valid_ids
|
||||
self.sample_ids = valid_ids
|
||||
self.size = size
|
||||
|
||||
print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")
|
||||
|
||||
def __len__(self):
|
||||
return self.num_records
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self):
|
||||
pass
|
||||
|
||||
|
||||
def _rearrange(x):
|
||||
return rearrange(x * 2.0 - 1.0, "c h w -> h w c")
|
||||
|
||||
|
||||
class SingleConceptDataset(Dataset):
|
||||
"""
|
||||
Dataset for finetuning a model on a single concept.
|
||||
|
||||
Similar to "dreambooth"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
concept_label,
|
||||
class_label,
|
||||
concept_images_dir,
|
||||
class_images_dir,
|
||||
image_transforms=None,
|
||||
):
|
||||
self.concept_label = concept_label
|
||||
self.class_label = class_label
|
||||
self.concept_images_dir = concept_images_dir
|
||||
self.class_images_dir = class_images_dir
|
||||
|
||||
if isinstance(image_transforms, (ListConfig, list)):
|
||||
image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
|
||||
image_transforms.extend(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(_rearrange),
|
||||
]
|
||||
)
|
||||
image_transforms = transforms.Compose(image_transforms)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
|
||||
self._concept_image_filename_groups = None
|
||||
self._class_image_filenames = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.concept_image_filename_groups) * 2
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx % 2:
|
||||
img_group = self._concept_image_filename_groups[int(idx / 2)]
|
||||
img_filename = random.choice(img_group)
|
||||
img_path = os.path.join(self.concept_images_dir, img_filename)
|
||||
|
||||
txt = self.concept_label
|
||||
else:
|
||||
img_path = os.path.join(
|
||||
self.class_images_dir, self.class_image_filenames[int(idx / 2)]
|
||||
)
|
||||
txt = self.class_label
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
except RuntimeError as e:
|
||||
msg = f"Could not read image {img_path}"
|
||||
raise RuntimeError(msg) from e
|
||||
image = self.image_transforms(image)
|
||||
data = {"image": image, "txt": txt}
|
||||
return data
|
||||
|
||||
@property
|
||||
def concept_image_filename_groups(self):
|
||||
if self._concept_image_filename_groups is None:
|
||||
self._concept_image_filename_groups = _load_image_filenames_and_alts(
|
||||
self.concept_images_dir
|
||||
)
|
||||
return self._concept_image_filename_groups
|
||||
|
||||
@property
|
||||
def class_image_filenames(self):
|
||||
if self._class_image_filenames is None:
|
||||
self._class_image_filenames = _load_image_filenames(self.class_images_dir)
|
||||
return self._class_image_filenames
|
||||
|
||||
@property
|
||||
def num_records(self):
|
||||
return len(self)
|
||||
|
||||
|
||||
def _load_image_filenames_and_alts(img_dir, image_extensions=(".jpg", ".jpeg", ".png")):
|
||||
"""Loads images into groups (filenames tagged with `[alt-{n:02d}]` are grouped together)."""
|
||||
image_filenames = _load_image_filenames(img_dir, image_extensions)
|
||||
grouped_img_filenames = defaultdict(list)
|
||||
for filename in image_filenames:
|
||||
base_filename = re.sub(r"\[alt-\d*\]", "", filename)
|
||||
grouped_img_filenames[base_filename].append(filename)
|
||||
return list(grouped_img_filenames.values())
|
||||
|
||||
|
||||
def _load_image_filenames(img_dir, image_extensions=(".jpg", ".jpeg", ".png")):
|
||||
image_filenames = []
|
||||
for filename in os.listdir(img_dir):
|
||||
if filename.lower().endswith(image_extensions) and not filename.startswith("."):
|
||||
image_filenames.append(filename)
|
||||
random.shuffle(image_filenames)
|
||||
return image_filenames
|
@ -1,535 +0,0 @@
|
||||
"""Code for training diffusion models"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
|
||||
|
||||
try:
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
except ImportError:
|
||||
# let's not break all of imaginairy just because a training import doesn't exist in an older version of PL
|
||||
# Use >= 1.6.0 to make this work
|
||||
DDPStrategy = None # type: ignore
|
||||
import contextlib
|
||||
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from imaginairy import config
|
||||
from imaginairy.training_tools.single_concept import SingleConceptDataset
|
||||
from imaginairy.utils import get_device, instantiate_from_config
|
||||
from imaginairy.utils.model_manager import get_diffusion_model
|
||||
|
||||
mod_logger = logging.getLogger(__name__)
|
||||
|
||||
referenced_by_string = [LearningRateMonitor]
|
||||
|
||||
|
||||
class WrappedDataset(Dataset):
|
||||
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset."""
|
||||
|
||||
def __init__(self, dataset):
|
||||
self.data = dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx]
|
||||
|
||||
|
||||
def worker_init_fn(_):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
|
||||
dataset = worker_info.dataset
|
||||
worker_id = worker_info.id
|
||||
|
||||
if isinstance(dataset, SingleConceptDataset):
|
||||
# split_size = dataset.num_records // worker_info.num_workers
|
||||
# reset num_records to the true number to retain reliable length information
|
||||
# dataset.sample_ids = dataset.valid_ids[
|
||||
# worker_id * split_size : (worker_id + 1) * split_size
|
||||
# ]
|
||||
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
|
||||
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
|
||||
|
||||
return np.random.seed(np.random.get_state()[1][0] + worker_id)
|
||||
|
||||
|
||||
class DataModuleFromConfig(pl.LightningDataModule):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size,
|
||||
train=None,
|
||||
validation=None,
|
||||
test=None,
|
||||
predict=None,
|
||||
wrap=False,
|
||||
num_workers=None,
|
||||
shuffle_test_loader=False,
|
||||
use_worker_init_fn=False,
|
||||
shuffle_val_dataloader=False,
|
||||
num_val_workers=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.batch_size = batch_size
|
||||
self.dataset_configs = {}
|
||||
self.num_workers = num_workers if num_workers is not None else batch_size * 2
|
||||
if num_val_workers is None:
|
||||
self.num_val_workers = self.num_workers
|
||||
else:
|
||||
self.num_val_workers = num_val_workers
|
||||
self.use_worker_init_fn = use_worker_init_fn
|
||||
if train is not None:
|
||||
self.dataset_configs["train"] = train
|
||||
self.train_dataloader = self._train_dataloader
|
||||
if validation is not None:
|
||||
self.dataset_configs["validation"] = validation
|
||||
self.val_dataloader = partial(
|
||||
self._val_dataloader, shuffle=shuffle_val_dataloader
|
||||
)
|
||||
if test is not None:
|
||||
self.dataset_configs["test"] = test
|
||||
self.test_dataloader = partial(
|
||||
self._test_dataloader, shuffle=shuffle_test_loader
|
||||
)
|
||||
if predict is not None:
|
||||
self.dataset_configs["predict"] = predict
|
||||
self.predict_dataloader = self._predict_dataloader
|
||||
self.wrap = wrap
|
||||
self.datasets = None
|
||||
|
||||
def prepare_data(self):
|
||||
for data_cfg in self.dataset_configs.values():
|
||||
instantiate_from_config(data_cfg)
|
||||
|
||||
def setup(self, stage=None):
|
||||
self.datasets = {
|
||||
k: instantiate_from_config(c) for k, c in self.dataset_configs.items()
|
||||
}
|
||||
if self.wrap:
|
||||
self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()}
|
||||
|
||||
def _train_dataloader(self):
|
||||
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
return DataLoader(
|
||||
self.datasets["train"],
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=True,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
def _val_dataloader(self, shuffle=False):
|
||||
if (
|
||||
isinstance(self.datasets["validation"], SingleConceptDataset)
|
||||
or self.use_worker_init_fn
|
||||
):
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
return DataLoader(
|
||||
self.datasets["validation"],
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_val_workers,
|
||||
worker_init_fn=init_fn,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
def _test_dataloader(self, shuffle=False):
|
||||
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
is_iterable_dataset = False
|
||||
|
||||
# do not shuffle dataloader for iterable dataset
|
||||
shuffle = shuffle and (not is_iterable_dataset)
|
||||
|
||||
return DataLoader(
|
||||
self.datasets["test"],
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
worker_init_fn=init_fn,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
def _predict_dataloader(self, shuffle=False):
|
||||
if (
|
||||
isinstance(self.datasets["predict"], SingleConceptDataset)
|
||||
or self.use_worker_init_fn
|
||||
):
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
return DataLoader(
|
||||
self.datasets["predict"],
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
worker_init_fn=init_fn,
|
||||
)
|
||||
|
||||
|
||||
class SetupCallback(Callback):
|
||||
def __init__(
|
||||
self,
|
||||
resume,
|
||||
now,
|
||||
logdir,
|
||||
ckptdir,
|
||||
cfgdir,
|
||||
):
|
||||
super().__init__()
|
||||
self.resume = resume
|
||||
self.now = now
|
||||
self.logdir = logdir
|
||||
self.ckptdir = ckptdir
|
||||
self.cfgdir = cfgdir
|
||||
|
||||
def on_keyboard_interrupt(self, trainer, pl_module):
|
||||
if trainer.global_rank == 0:
|
||||
mod_logger.info("Stopping execution and saving final checkpoint.")
|
||||
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
def on_fit_start(self, trainer, pl_module):
|
||||
if trainer.global_rank == 0:
|
||||
# Create logdirs and save configs
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
os.makedirs(self.ckptdir, exist_ok=True)
|
||||
os.makedirs(self.cfgdir, exist_ok=True)
|
||||
|
||||
else:
|
||||
# ModelCheckpoint callback created log directory --- remove it
|
||||
if not self.resume and os.path.exists(self.logdir):
|
||||
dst, name = os.path.split(self.logdir)
|
||||
dst = os.path.join(dst, "child_runs", name)
|
||||
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
os.rename(self.logdir, dst)
|
||||
|
||||
|
||||
class ImageLogger(Callback):
|
||||
def __init__(
|
||||
self,
|
||||
batch_frequency,
|
||||
max_images,
|
||||
clamp=True,
|
||||
increase_log_steps=True,
|
||||
rescale=True,
|
||||
disabled=False,
|
||||
log_on_batch_idx=False,
|
||||
log_first_step=False,
|
||||
log_images_kwargs=None,
|
||||
log_all_val=False,
|
||||
concept_label=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.rescale = rescale
|
||||
self.batch_freq = batch_frequency
|
||||
self.max_images = max_images
|
||||
self.logger_log_images = {}
|
||||
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
||||
if not increase_log_steps:
|
||||
self.log_steps = [self.batch_freq]
|
||||
self.clamp = clamp
|
||||
self.disabled = disabled
|
||||
self.log_on_batch_idx = log_on_batch_idx
|
||||
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
||||
self.log_first_step = log_first_step
|
||||
self.log_all_val = log_all_val
|
||||
self.concept_label = concept_label
|
||||
|
||||
@rank_zero_only
|
||||
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
|
||||
root = os.path.join(save_dir, "logs", "images", split)
|
||||
for k in images:
|
||||
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
||||
if self.rescale:
|
||||
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
||||
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
||||
grid = grid.numpy()
|
||||
grid = (grid * 255).astype(np.uint8)
|
||||
filename = (
|
||||
f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png"
|
||||
)
|
||||
path = os.path.join(root, filename)
|
||||
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
||||
Image.fromarray(grid).save(path)
|
||||
|
||||
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
||||
# always generate the concept label
|
||||
batch["txt"][0] = self.concept_label
|
||||
|
||||
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
|
||||
if self.log_all_val and split == "val":
|
||||
should_log = True
|
||||
else:
|
||||
should_log = self.check_frequency(check_idx)
|
||||
if (
|
||||
should_log
|
||||
and (batch_idx % self.batch_freq == 0)
|
||||
and hasattr(pl_module, "log_images")
|
||||
and callable(pl_module.log_images)
|
||||
and self.max_images > 0
|
||||
):
|
||||
logger = type(pl_module.logger)
|
||||
|
||||
is_train = pl_module.training
|
||||
if is_train:
|
||||
pl_module.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
images = pl_module.log_images(
|
||||
batch, split=split, **self.log_images_kwargs
|
||||
)
|
||||
|
||||
for k in images:
|
||||
N = min(images[k].shape[0], self.max_images)
|
||||
images[k] = images[k][:N]
|
||||
if isinstance(images[k], torch.Tensor):
|
||||
images[k] = images[k].detach().cpu()
|
||||
if self.clamp:
|
||||
images[k] = torch.clamp(images[k], -1.0, 1.0)
|
||||
|
||||
self.log_local(
|
||||
pl_module.logger.save_dir,
|
||||
split,
|
||||
images,
|
||||
pl_module.global_step,
|
||||
pl_module.current_epoch,
|
||||
batch_idx,
|
||||
)
|
||||
|
||||
logger_log_images = self.logger_log_images.get(
|
||||
logger, lambda *args, **kwargs: None
|
||||
)
|
||||
logger_log_images(pl_module, images, pl_module.global_step, split)
|
||||
|
||||
if is_train:
|
||||
pl_module.train()
|
||||
|
||||
def check_frequency(self, check_idx):
|
||||
if (check_idx % self.batch_freq) == 0 and (
|
||||
check_idx > 0 or self.log_first_step
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
|
||||
self.log_img(pl_module, batch, batch_idx, split="train")
|
||||
|
||||
def on_validation_batch_end(
|
||||
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
||||
):
|
||||
if not self.disabled and pl_module.global_step > 0:
|
||||
self.log_img(pl_module, batch, batch_idx, split="val")
|
||||
if (
|
||||
hasattr(pl_module, "calibrate_grad_norm")
|
||||
and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0)
|
||||
and batch_idx > 0
|
||||
):
|
||||
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
||||
|
||||
|
||||
class CUDACallback(Callback):
|
||||
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
# Reset the memory use counter
|
||||
if "cuda" in get_device():
|
||||
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
|
||||
torch.cuda.synchronize(trainer.strategy.root_device.index)
|
||||
self.start_time = time.time()
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
if "cuda" in get_device():
|
||||
torch.cuda.synchronize(trainer.strategy.root_device.index)
|
||||
max_memory = (
|
||||
torch.cuda.max_memory_allocated(trainer.strategy.root_device.index)
|
||||
/ 2**20
|
||||
)
|
||||
epoch_time = time.time() - self.start_time
|
||||
|
||||
try:
|
||||
max_memory = trainer.training_type_plugin.reduce(max_memory)
|
||||
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
|
||||
|
||||
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
|
||||
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def train_diffusion_model(
|
||||
concept_label,
|
||||
concept_images_dir,
|
||||
class_label,
|
||||
class_images_dir,
|
||||
weights_location=config.DEFAULT_MODEL_WEIGHTS,
|
||||
logdir="logs",
|
||||
learning_rate=1e-6,
|
||||
accumulate_grad_batches=32,
|
||||
resume=None,
|
||||
):
|
||||
"""
|
||||
Train a diffusion model on a single concept.
|
||||
|
||||
accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf
|
||||
"""
|
||||
if DDPStrategy is None:
|
||||
msg = "Please install pytorch-lightning>=1.6.0 to train a model"
|
||||
raise ImportError(msg)
|
||||
|
||||
batch_size = 1
|
||||
seed = 23
|
||||
num_workers = 1
|
||||
num_val_workers = 0
|
||||
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # noqa: DTZ005
|
||||
logdir = os.path.join(logdir, now)
|
||||
|
||||
ckpt_output_dir = os.path.join(logdir, "checkpoints")
|
||||
cfg_output_dir = os.path.join(logdir, "configs")
|
||||
seed_everything(seed)
|
||||
model = get_diffusion_model(
|
||||
weights_location=weights_location, half_mode=False, for_training=True
|
||||
)._model
|
||||
model.learning_rate = learning_rate * accumulate_grad_batches * batch_size
|
||||
|
||||
# add callback which sets up log directory
|
||||
default_callbacks_cfg = {
|
||||
"setup_callback": {
|
||||
"target": "imaginairy.train.SetupCallback",
|
||||
"params": {
|
||||
"resume": False,
|
||||
"now": now,
|
||||
"logdir": logdir,
|
||||
"ckptdir": ckpt_output_dir,
|
||||
"cfgdir": cfg_output_dir,
|
||||
},
|
||||
},
|
||||
"image_logger": {
|
||||
"target": "imaginairy.train.ImageLogger",
|
||||
"params": {
|
||||
"batch_frequency": 10,
|
||||
"max_images": 1,
|
||||
"clamp": True,
|
||||
"increase_log_steps": False,
|
||||
"log_first_step": True,
|
||||
"log_all_val": True,
|
||||
"concept_label": concept_label,
|
||||
"log_images_kwargs": {
|
||||
"use_ema_scope": True,
|
||||
"inpaint": False,
|
||||
"plot_progressive_rows": False,
|
||||
"plot_diffusion_rows": False,
|
||||
"N": 1,
|
||||
"unconditional_guidance_scale:": 7.5,
|
||||
"unconditional_guidance_label": [""],
|
||||
"ddim_steps": 20,
|
||||
},
|
||||
},
|
||||
},
|
||||
"learning_rate_logger": {
|
||||
"target": "imaginairy.train.LearningRateMonitor",
|
||||
"params": {
|
||||
"logging_interval": "step",
|
||||
# "log_momentum": True
|
||||
},
|
||||
},
|
||||
"cuda_callback": {"target": "imaginairy.train.CUDACallback"},
|
||||
}
|
||||
|
||||
default_modelckpt_cfg = {
|
||||
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
||||
"params": {
|
||||
"dirpath": ckpt_output_dir,
|
||||
"filename": "{epoch:06}",
|
||||
"verbose": True,
|
||||
"save_last": True,
|
||||
"every_n_train_steps": 50,
|
||||
"save_top_k": -1,
|
||||
"monitor": None,
|
||||
},
|
||||
}
|
||||
|
||||
modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg)
|
||||
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
|
||||
|
||||
callbacks_cfg = OmegaConf.create(default_callbacks_cfg)
|
||||
|
||||
dataset_config = {
|
||||
"concept_label": concept_label,
|
||||
"concept_images_dir": concept_images_dir,
|
||||
"class_label": class_label,
|
||||
"class_images_dir": class_images_dir,
|
||||
"image_transforms": [
|
||||
{
|
||||
"target": "torchvision.transforms.Resize",
|
||||
"params": {"size": 512, "interpolation": 3},
|
||||
},
|
||||
{"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
|
||||
],
|
||||
}
|
||||
|
||||
data_module_config = {
|
||||
"batch_size": batch_size,
|
||||
"num_workers": num_workers,
|
||||
"num_val_workers": num_val_workers,
|
||||
"train": {
|
||||
"target": "imaginairy.training_tools.single_concept.SingleConceptDataset",
|
||||
"params": dataset_config,
|
||||
},
|
||||
}
|
||||
trainer = Trainer(
|
||||
benchmark=True,
|
||||
num_sanity_val_steps=0,
|
||||
accumulate_grad_batches=accumulate_grad_batches,
|
||||
strategy=DDPStrategy(),
|
||||
callbacks=[instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg],
|
||||
gpus=1,
|
||||
default_root_dir=".",
|
||||
)
|
||||
trainer.logdir = logdir
|
||||
|
||||
data = DataModuleFromConfig(**data_module_config)
|
||||
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
||||
# calling these ourselves should not be necessary but it is.
|
||||
# lightning still takes care of proper multiprocessing though
|
||||
data.prepare_data()
|
||||
data.setup()
|
||||
|
||||
def melk(*args, **kwargs):
|
||||
if trainer.global_rank == 0:
|
||||
mod_logger.info("Summoning checkpoint.")
|
||||
ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt")
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
signal.signal(signal.SIGUSR1, melk)
|
||||
try:
|
||||
try:
|
||||
trainer.fit(model, data)
|
||||
except Exception:
|
||||
melk()
|
||||
raise
|
||||
finally:
|
||||
mod_logger.info(trainer.profiler.summary())
|
Loading…
Reference in New Issue
Block a user