feature: finetuning
- feature: finetuning your own image models
- feature: image prep command. crops to face or other interesting parts of photo
- fix: back-compat for hf_hub_download
- feature: add prune-ckpt command
- feature: allow specification of model config file
parent
4bc78b9be5
commit
5cc73f6087
Binary file not shown.
@ -0,0 +1,85 @@
# Adding a concept to Stable Diffusion

You can use Imaginairy to teach the model a new concept (a person, thing, style, etc.) using the `aimg train-concept`
command.

## Requirements

- Graphics card: 3090 or better
- Linux
- A working Imaginairy installation
- A folder of images of the concept you want to teach the model

## Background

To train the model we show it many images of the concept we want to teach it. The problem is that the model can easily
overfit to the images we show it. To prevent this we also show it images of the class of thing that is being trained.
Imaginairy will generate the images needed for this before running the training job.

Provided a directory of concept images, a concept token, and a class token, this command will train the model
to generate images of that concept.

This happens in a 3-step process:

1. Cropping and resizing your training images. If `--person` is set we crop to include the face.
2. Generating a set of class images to train on. This helps prevent overfitting.
3. Training the model on the concept and class images.

The output of this command is a new model weights file that you can use with the `--model` option.

## Instructions

1. Gather a set of images of the concept you want to train on. The images should show the subject from a variety of angles
   and in a variety of situations.
2. Run `aimg train-concept` to train the model.
   - Concept label: For a person, `firstnamelastname` should be fine.
     - If all the training images are photos you should add "a photo of" to the beginning of the concept label.
   - Class label: This is the category of thing being trained on. For people this is typically "person", "man",
     or "woman".
     - If all the training images are photos you should add "a photo of" to the beginning of the class label.
     - Class images will be generated for you if you do not provide them.

   For example, if you were training on photos of a man named Bill Hamilton you could run the following:

   ```
   aimg train-concept \
       --person \
       --concept-label "photo of billhamilton man" \
       --concept-images-dir ./images/billhamilton \
       --class-label "photo of a man" \
       --class-images-dir ./images/man
   ```
3. Stop training before it overfits.
   - The training script will output checkpoint ckpt files into the logs folder of wherever it is run from. You can also
     monitor generated images in the logs/images folder. They will be the ones named "sample".
   - I don't have great advice on when to stop training yet. I stopped mine at epoch 62 and it didn't seem quite good enough; at epoch 111 it
     produced my face correctly 50% of the time but also seemed overfit in some ways (always placing me in the same clothes or background as the training photos).
   - You can monitor model training progress in TensorBoard. Run `tensorboard --logdir lightning_logs` and open the link it gives you in your browser.

4. Prune the model to bring the size from 11 GB down to ~4 GB: `aimg prune-ckpt logs/2023-01-15T05-52-06/checkpoints/epoch\=000049.ckpt`. Copy it somewhere
   and give it a meaningful name.

## Using the new model

You can reference the model like this in Imaginairy:

`imagine --model my-models/billhamilton-man-e111.ckpt`

When you use the model you should prompt with `firstnamelastname classname` (e.g. `billhamilton man`).
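
For example, to generate new images of the trained concept (the prompt text below is illustrative):

`imagine --model my-models/billhamilton-man-e111.ckpt "billhamilton man riding a bicycle"`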

## Disclaimers

- The settings Imaginairy uses to train the model are different from those of other software projects. As such you cannot follow
  advice you may read in other tutorials regarding learning rate, epochs, steps, or batch size. They are not directly
  comparable. In layman's terms the "steps" are much bigger in Imaginairy.
- I consider this training feature experimental and don't currently plan to offer support for it. Any further work will
  be at my leisure. As a result I may close any reported issues related to this feature.
- You can find a lot more relevant information here: https://github.com/JoePenna/Dreambooth-Stable-Diffusion

## Todo

- figure out how to improve consistency of quality from the trained model
- train on the depth-guided model instead of SD 1.5 since that will enable more consistent output
- figure out a metric to use for stopping training
- possibly swap out and randomize backgrounds on training photos so over-fitting does not occur
@ -0,0 +1,57 @@
import numpy as np

from imaginairy.enhancers.face_restoration_codeformer import face_restore_helper
from imaginairy.roi_utils import resize_roi_coordinates, square_roi_coordinate


def detect_faces(img):
    face_helper = face_restore_helper()
    face_helper.clean_all()

    image = img.convert("RGB")
    np_img = np.array(image, dtype=np.uint8)
    # reverse the channel order to convert RGB to BGR
    np_img = np_img[:, :, ::-1]

    face_helper.read_image(np_img)

    face_helper.get_face_landmarks_5(
        only_center_face=False, resize=640, eye_dist_threshold=5
    )
    face_helper.align_warp_face()
    faceboxes = []

    for x1, y1, x2, y2, scaling in face_helper.det_faces:
        # x1, y1, x2, y2 = x1 * scaling, y1 * scaling, x2 * scaling, y2 * scaling
        faceboxes.append((x1, y1, x2, y2))

    return faceboxes


def generate_face_crops(face_roi, max_width, max_height):
    """Returns bounding boxes at various zoom levels for faces in the image."""
    crops = []
    squared_roi = square_roi_coordinate(face_roi, max_width, max_height)

    crops.append(resize_roi_coordinates(squared_roi, 1.1, max_width, max_height))
    # 1.6 is generally enough to capture the entire face
    base_expanded_roi = resize_roi_coordinates(squared_roi, 1.6, max_width, max_height)

    crops.append(base_expanded_roi)
    current_width = base_expanded_roi[2] - base_expanded_roi[0]

    # some zoomed-out variations
    for n in range(2):
        factor = 1.25 + 0.4 * n

        expanded_roi = resize_roi_coordinates(
            base_expanded_roi, factor, max_width, max_height, expand_up=False
        )
        new_width = expanded_roi[2] - expanded_roi[0]
        if new_width <= current_width * 1.1:
            # if the zoomed-out crop isn't sufficiently larger (because there is nowhere to zoom out to), stop
            break
        crops.append(expanded_roi)
        current_width = new_width
    return crops
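
A minimal sketch of how these two helpers compose, mirroring what the image-prep command does (the image path is hypothetical):

```python
from imaginairy import LazyLoadingImage
from imaginairy.enhancers.facecrop import detect_faces, generate_face_crops

img = LazyLoadingImage(filepath="./images/billhamilton/photo1.jpg")  # hypothetical path
face_rois = detect_faces(img)
if len(face_rois) == 1:
    # one (x1, y1, x2, y2) box per zoom level, tightest first
    for crop_box in generate_face_crops(
        face_rois[0], max_width=img.width, max_height=img.height
    ):
        print(crop_box)
```
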
@ -0,0 +1,133 @@
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0.
    """

    def __init__(
        self,
        warm_up_steps,
        lr_min,
        lr_max,
        lr_start,
        max_decay_steps,
        verbosity_interval=0,
    ):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.0
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            # linear warmup from lr_start to lr_max
            lr = (
                self.lr_max - self.lr_start
            ) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        # cosine decay from lr_max to lr_min
        t = (n - self.lr_warm_up_steps) / (
            self.lr_max_decay_steps - self.lr_warm_up_steps
        )
        t = min(t, 1.0)
        lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
        self.last_lr = lr
        return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
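
A minimal sketch of plugging this multiplier schedule into a torch optimizer via `LambdaLR` (the model is a placeholder; `lr=1.0` matches the base_lr note above):

```python
import torch

model = torch.nn.Linear(4, 4)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=0.1, lr_max=1.0, lr_start=0.0, max_decay_steps=10_000
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)
# after each optimizer.step(), call lr_scheduler.step() to advance the multiplier
```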


class LambdaWarmUpCosineScheduler2:
    """
    Supports repeated iterations, configurable via lists.
    note: use with a base_lr of 1.0.
    """

    def __init__(
        self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
    ):
        assert (
            len(warm_up_steps)
            == len(f_min)
            == len(f_max)
            == len(f_start)
            == len(cycle_lengths)
        )
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.0
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1
        return None

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )
        if n < self.lr_warm_up_steps[cycle]:
            # linear warmup from f_start to f_max within this cycle
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f

        # cosine decay from f_max to f_min over the remainder of the cycle
        t = (n - self.lr_warm_up_steps[cycle]) / (
            self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
        )
        t = min(t, 1.0)
        f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
            1 + np.cos(t * np.pi)
        )
        self.last_f = f
        return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(
                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                    f"current cycle {cycle}"
                )

        if n < self.lr_warm_up_steps[cycle]:
            # linear warmup from f_start to f_max within this cycle
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                cycle
            ] * n + self.f_start[cycle]
            self.last_f = f
            return f

        # linear decay from f_max to f_min over the remainder of the cycle
        f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
            self.cycle_lengths[cycle] - n
        ) / (self.cycle_lengths[cycle])
        self.last_f = f
        return f
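
A minimal sketch of the list-based configuration (single cycle; the values are illustrative):

```python
sched = LambdaLinearScheduler(
    warm_up_steps=[100],
    f_min=[0.0],
    f_max=[1.0],
    f_start=[1e-6],
    cycle_lengths=[10_000_000],
)
print(sched(0))    # ~1e-6: start of the linear warmup
print(sched(100))  # ~1.0: warmup complete, linear decay begins
```
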
@ -0,0 +1,103 @@
import logging

logger = logging.getLogger(__name__)


def square_roi_coordinate(roi, max_width, max_height, best_effort=False):
    """Given a region of interest, returns a square region of interest."""
    x1, y1, x2, y2 = roi
    x1, y1, x2, y2 = int(round(x1)), int(round(y1)), int(round(x2)), int(round(y2))
    roi_width = x2 - x1
    roi_height = y2 - y1
    if roi_width < roi_height:
        # grow the narrower dimension until the ROI is square
        diff = roi_height - roi_width
        x1 -= int(round(diff / 2))
        x2 += roi_height - (x2 - x1)
    elif roi_height < roi_width:
        diff = roi_width - roi_height
        y1 -= int(round(diff / 2))
        y2 += roi_width - (y2 - y1)

    x1, y1, x2, y2 = move_roi_into_bounds(
        (x1, y1, x2, y2), max_width, max_height, force=best_effort
    )
    width = x2 - x1
    height = y2 - y1
    if not best_effort and width != height:
        raise RuntimeError(f"ROI is not square: {width}x{height}")
    return x1, y1, x2, y2


def resize_roi_coordinates(
    roi, expansion_factor, max_width, max_height, expand_up=True
):
    """
    Resize a region of interest while staying within the bounds.

    Setting expand_up to False will prevent the ROI from expanding upwards, which is useful when
    expanding something like a face ROI to capture more of the person instead of empty space above them.
    """
    x1, y1, x2, y2 = roi
    side_length_x = x2 - x1
    side_length_y = y2 - y1

    max_expansion_factor = min(max_height / side_length_y, max_width / side_length_x)
    expansion_factor = min(expansion_factor, max_expansion_factor)

    expansion_x = int(round(side_length_x * expansion_factor - side_length_x))
    expansion_x_a = int(round(expansion_x / 2))
    expansion_x_b = expansion_x - expansion_x_a
    x1 -= expansion_x_a
    x2 += expansion_x_b

    expansion_y = int(round(side_length_y * expansion_factor - side_length_y))
    if expand_up:
        expansion_y_a = int(round(expansion_y / 2))
        expansion_y_b = expansion_y - expansion_y_a
        y1 -= expansion_y_a
        y2 += expansion_y_b
    else:
        y2 += expansion_y

    x1, y1, x2, y2 = move_roi_into_bounds((x1, y1, x2, y2), max_width, max_height)

    return x1, y1, x2, y2
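
A quick worked example of the expansion logic above (a 10x10 ROI inside a 100x100 image, doubled in size):

```python
assert resize_roi_coordinates((10, 10, 20, 20), 2, 100, 100) == (5, 5, 25, 25)
# with expand_up=False all vertical growth goes downward
assert resize_roi_coordinates((10, 10, 20, 20), 2, 100, 100, expand_up=False) == (
    5,
    10,
    25,
    30,
)
```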


class RoiNotInBoundsError(ValueError):
    """Error raised when a ROI is not within the bounds of the image."""


def move_roi_into_bounds(roi, max_width, max_height, force=False):
    """Move a region of interest into the bounds of the image."""
    x1, y1, x2, y2 = roi

    # move the ROI within the image boundaries
    if x1 < 0:
        x2 -= x1
        x1 = 0
    if y1 < 0:
        y2 -= y1
        y1 = 0
    if x2 > max_width:
        x1 -= x2 - max_width
        x2 = max_width
    if y2 > max_height:
        y1 -= y2 - max_height
        y2 = max_height
    x1, y1, x2, y2 = int(round(x1)), int(round(y1)), int(round(x2)), int(round(y2))
    # Force ROI to fit within image boundaries (sacrificing size and aspect ratio of ROI)
    if force:
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(max_width, x2)
        y2 = min(max_height, y2)
    if x1 < 0 or y1 < 0 or x2 > max_width or y2 > max_height:
        roi_width = x2 - x1
        roi_height = y2 - y1
        raise RoiNotInBoundsError(
            f"Not possible to fit ROI into boundaries: {roi_width}x{roi_height} won't fit inside {max_width}x{max_height}"
        )

    return x1, y1, x2, y2
@ -0,0 +1,520 @@
import datetime
import logging
import os
import signal
import time
from functools import partial

import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.utils.data import DataLoader, Dataset

from imaginairy import config
from imaginairy.model_manager import get_diffusion_model
from imaginairy.training_tools.single_concept import SingleConceptDataset
from imaginairy.utils import get_device, instantiate_from_config

mod_logger = logging.getLogger(__name__)

referenced_by_string = [LearningRateMonitor]


class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset."""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def worker_init_fn(_):
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, SingleConceptDataset):
        # split_size = dataset.num_records // worker_info.num_workers
        # reset num_records to the true number to retain reliable length information
        # dataset.sample_ids = dataset.valid_ids[
        #     worker_id * split_size : (worker_id + 1) * split_size
        # ]
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)

    return np.random.seed(np.random.get_state()[1][0] + worker_id)


class DataModuleFromConfig(pl.LightningDataModule):
    def __init__(
        self,
        batch_size,
        train=None,
        validation=None,
        test=None,
        predict=None,
        wrap=False,
        num_workers=None,
        shuffle_test_loader=False,
        use_worker_init_fn=False,
        shuffle_val_dataloader=False,
        num_val_workers=0,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = {}
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        if num_val_workers is None:
            self.num_val_workers = self.num_workers
        else:
            self.num_val_workers = num_val_workers
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(
                self._val_dataloader, shuffle=shuffle_val_dataloader
            )
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(
                self._test_dataloader, shuffle=shuffle_test_loader
            )
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap
        self.datasets = None

    def prepare_data(self):
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        self.datasets = {
            k: instantiate_from_config(c) for k, c in self.dataset_configs.items()
        }
        if self.wrap:
            self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()}

    def _train_dataloader(self):
        # the custom worker_init_fn is always used for training
        return DataLoader(
            self.datasets["train"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            worker_init_fn=worker_init_fn,
        )

    def _val_dataloader(self, shuffle=False):
        if (
            isinstance(self.datasets["validation"], SingleConceptDataset)
            or self.use_worker_init_fn
        ):
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets["validation"],
            batch_size=self.batch_size,
            num_workers=self.num_val_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
        )

    def _test_dataloader(self, shuffle=False):
        is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
        if is_iterable_dataset or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        is_iterable_dataset = False

        # do not shuffle dataloader for iterable dataset
        shuffle = shuffle and (not is_iterable_dataset)

        return DataLoader(
            self.datasets["test"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            shuffle=shuffle,
        )

    def _predict_dataloader(self, shuffle=False):
        if (
            isinstance(self.datasets["predict"], SingleConceptDataset)
            or self.use_worker_init_fn
        ):
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets["predict"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
        )


class SetupCallback(Callback):
    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
    ):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir

    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            mod_logger.info("Stopping execution and saving final checkpoint.")
            ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass


class ImageLogger(Callback):
    def __init__(
        self,
        batch_frequency,
        max_images,
        clamp=True,
        increase_log_steps=True,
        rescale=True,
        disabled=False,
        log_on_batch_idx=False,
        log_first_step=False,
        log_images_kwargs=None,
        log_all_val=False,
        concept_label=None,
    ):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {}
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step
        self.log_all_val = log_all_val
        self.concept_label = concept_label

    @rank_zero_only
    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "logs", "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = (
                f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png"
            )
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        # always generate the concept label
        batch["txt"][0] = self.concept_label

        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if self.log_all_val and split == "val":
            should_log = True
        else:
            should_log = self.check_frequency(check_idx)
        if (
            should_log
            and (batch_idx % self.batch_freq == 0)
            and hasattr(pl_module, "log_images")
            and callable(pl_module.log_images)
            and self.max_images > 0
        ):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(
                    batch, split=split, **self.log_images_kwargs
                )

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
            )

            logger_log_images = self.logger_log_images.get(
                logger, lambda *args, **kwargs: None
            )
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if (check_idx % self.batch_freq) == 0 and (
            check_idx > 0 or self.log_first_step
        ):
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, "calibrate_grad_norm"):
            if (
                pl_module.calibrate_grad_norm and batch_idx % 25 == 0
            ) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)


class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        if "cuda" in get_device():
            torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
            torch.cuda.synchronize(trainer.strategy.root_device.index)
        self.start_time = time.time()  # noqa

    def on_train_epoch_end(self, trainer, pl_module):  # noqa
        if "cuda" in get_device():
            torch.cuda.synchronize(trainer.strategy.root_device.index)
            max_memory = (
                torch.cuda.max_memory_allocated(trainer.strategy.root_device.index)
                / 2**20
            )
            epoch_time = time.time() - self.start_time

            try:
                max_memory = trainer.training_type_plugin.reduce(max_memory)
                epoch_time = trainer.training_type_plugin.reduce(epoch_time)

                rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
                rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
            except AttributeError:
                pass


def train_diffusion_model(
    concept_label,
    concept_images_dir,
    class_label,
    class_images_dir,
    weights_location=config.DEFAULT_MODEL,
    logdir="logs",
    learning_rate=1e-6,
    accumulate_grad_batches=32,
    resume=None,
):
    batch_size = 1
    seed = 23
    num_workers = 1
    num_val_workers = 0
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    logdir = os.path.join(logdir, now)

    ckpt_output_dir = os.path.join(logdir, "checkpoints")
    cfg_output_dir = os.path.join(logdir, "configs")
    seed_everything(seed)
    model = get_diffusion_model(  # noqa
        weights_location=weights_location, half_mode=False, for_training=True
    )._model
    model.learning_rate = learning_rate * accumulate_grad_batches * batch_size

    # add callback which sets up log directory
    default_callbacks_cfg = {
        "setup_callback": {
            "target": "imaginairy.train.SetupCallback",
            "params": {
                "resume": False,
                "now": now,
                "logdir": logdir,
                "ckptdir": ckpt_output_dir,
                "cfgdir": cfg_output_dir,
            },
        },
        "image_logger": {
            "target": "imaginairy.train.ImageLogger",
            "params": {
                "batch_frequency": 10,
                "max_images": 1,
                "clamp": True,
                "increase_log_steps": False,
                "log_first_step": True,
                "log_all_val": True,
                "concept_label": concept_label,
                "log_images_kwargs": {
                    "use_ema_scope": True,
                    "inpaint": False,
                    "plot_progressive_rows": False,
                    "plot_diffusion_rows": False,
                    "N": 1,
                    "unconditional_guidance_scale": 7.5,
                    "unconditional_guidance_label": [""],
                    "ddim_steps": 20,
                },
            },
        },
        "learning_rate_logger": {
            "target": "imaginairy.train.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
                # "log_momentum": True
            },
        },
        "cuda_callback": {"target": "imaginairy.train.CUDACallback"},
    }

    default_modelckpt_cfg = {
        "target": "pytorch_lightning.callbacks.ModelCheckpoint",
        "params": {
            "dirpath": ckpt_output_dir,
            "filename": "{epoch:06}",
            "verbose": True,
            "save_last": True,
            "every_n_train_steps": 50,
            "save_top_k": -1,
            "monitor": None,
        },
    }

    modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg)
    default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})

    callbacks_cfg = OmegaConf.create(default_callbacks_cfg)

    dataset_config = {
        "concept_label": concept_label,
        "concept_images_dir": concept_images_dir,
        "class_label": class_label,
        "class_images_dir": class_images_dir,
        "image_transforms": [
            {
                "target": "torchvision.transforms.Resize",
                "params": {"size": 512, "interpolation": 3},
            },
            {"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
        ],
    }

    data_module_config = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "num_val_workers": num_val_workers,
        "train": {
            "target": "imaginairy.training_tools.single_concept.SingleConceptDataset",
            "params": dataset_config,
        },
    }
    trainer = Trainer(
        benchmark=True,
        num_sanity_val_steps=0,
        accumulate_grad_batches=accumulate_grad_batches,
        strategy=DDPStrategy(),
        callbacks=[
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg  # noqa
        ],
        gpus=1,
        default_root_dir=".",
    )
    trainer.logdir = logdir

    data = DataModuleFromConfig(**data_module_config)
    # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
    # calling these ourselves should not be necessary but it is.
    # lightning still takes care of proper multiprocessing though
    data.prepare_data()
    data.setup()

    def melk(*args, **kwargs):
        if trainer.global_rank == 0:
            mod_logger.info("Summoning checkpoint.")
            ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)

    signal.signal(signal.SIGUSR1, melk)
    try:
        try:
            trainer.fit(model, data)
        except Exception:
            melk()
            raise
    finally:
        mod_logger.info(trainer.profiler.summary())
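
A minimal sketch of calling this training entrypoint directly rather than through `aimg train-concept` (directory paths are hypothetical; the module path matches the callback targets above):

```python
from imaginairy.train import train_diffusion_model

train_diffusion_model(
    concept_label="photo of billhamilton man",
    concept_images_dir="./images/billhamilton/prepped-images",
    class_label="photo of a man",
    class_images_dir="./images/man",
)
```
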
@ -0,0 +1,155 @@
import logging
import os
import os.path
import re

from PIL import Image
from tqdm import tqdm

from imaginairy import ImaginePrompt, LazyLoadingImage, imagine
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.facecrop import detect_faces, generate_face_crops
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.vendored.smart_crop import SmartCrop

logger = logging.getLogger(__name__)


def get_image_filenames(folder):
    filenames = []
    for filename in os.listdir(folder):
        if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        if filename.startswith("."):
            continue
        filenames.append(filename)
    return filenames


def prep_images(
    images_dir, is_person=False, output_folder_name="prepped-images", target_size=512
):
    """
    Crops and resizes a directory of images in preparation for training.

    If is_person=True, it will detect the face and produce several crops at different zoom levels. For crops that
    are too small, it will use the face restoration model to enhance the faces.

    For non-person images, it will use the smartcrop algorithm to crop the image to the most interesting part. If the
    input image is too small it will be upscaled.

    Prep will go a lot faster if all the images are big enough to not require upscaling.
    """
    output_folder = os.path.join(images_dir, output_folder_name)
    os.makedirs(output_folder, exist_ok=True)
    logger.info(f"Prepping images in {images_dir} to {output_folder}")
    image_filenames = get_image_filenames(images_dir)
    pbar = tqdm(image_filenames)
    for filename in pbar:
        pbar.set_description(filename)

        input_path = os.path.join(images_dir, filename)
        img = LazyLoadingImage(filepath=input_path).convert("RGB")
        if is_person:
            face_rois = detect_faces(img)
            if len(face_rois) == 0:
                logger.info(f"No faces detected in image {filename}, skipping")
                continue
            if len(face_rois) > 1:
                logger.info(f"Multiple faces detected in image {filename}, skipping")
                continue
            face_roi = face_rois[0]
            face_roi_crops = generate_face_crops(
                face_roi, max_width=img.width, max_height=img.height
            )
            for n, face_roi_crop in enumerate(face_roi_crops):
                cropped_output_path = os.path.join(
                    output_folder, f"{filename}_[alt-{n:02d}].jpg"
                )
                if os.path.exists(cropped_output_path):
                    logger.debug(
                        f"Skipping {cropped_output_path} because it already exists"
                    )
                    continue
                x1, y1, x2, y2 = face_roi_crop
                crop_width = x2 - x1
                crop_height = y2 - y1
                if crop_width != crop_height:
                    logger.info(
                        f"Face ROI crop for {filename} {crop_width}x{crop_height} is not square, skipping"
                    )
                    continue
                cropped_img = img.crop(face_roi_crop)

                if crop_width < target_size:
                    logger.info(f"Upscaling {filename} {face_roi_crop}")
                    cropped_img = cropped_img.resize(
                        (target_size, target_size), resample=Image.Resampling.LANCZOS
                    )
                    cropped_img = enhance_faces(cropped_img, fidelity=1)
                else:
                    cropped_img = cropped_img.resize(
                        (target_size, target_size), resample=Image.Resampling.LANCZOS
                    )
                cropped_img.save(cropped_output_path, quality=95)
        else:
            n = 0
            cropped_output_path = os.path.join(output_folder, f"{filename}_{n}.jpg")
            if os.path.exists(cropped_output_path):
                logger.debug(
                    f"Skipping {cropped_output_path} because it already exists"
                )
                continue
            if img.width < target_size or img.height < target_size:
                # upscale the image if it's too small
                logger.info(f"Upscaling {filename}")
                img = upscale_image(img)

            if img.width > img.height:
                scale_factor = target_size / img.height
            else:
                scale_factor = target_size / img.width

            # scale the image so the shortest side is target_size
            new_width = int(round(img.width * scale_factor))
            new_height = int(round(img.height * scale_factor))
            img = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)

            result = SmartCrop().crop(img, width=target_size, height=target_size)

            box = (
                result["top_crop"]["x"],
                result["top_crop"]["y"],
                result["top_crop"]["width"] + result["top_crop"]["x"],
                result["top_crop"]["height"] + result["top_crop"]["y"],
            )

            cropped_image = img.crop(box)
            cropped_image.save(cropped_output_path, quality=95)
    logger.info(f"Image Prep complete. Review output at {output_folder}")


def prompt_normalized(prompt):
    return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "-", prompt)[:130]


def create_class_images(class_description, output_folder, num_images=200):
    """Generate images of class_description."""
    existing_images = get_image_filenames(output_folder)
    existing_image_count = len(existing_images)
    class_slug = prompt_normalized(class_description)

    while existing_image_count < num_images:
        prompt = ImaginePrompt(class_description, steps=20)
        result = list(imagine([prompt]))[0]
        if result.is_nsfw:
            continue
        dest = os.path.join(
            output_folder, f"{existing_image_count:03d}_{class_slug}.jpg"
        )
        result.save(dest)
        existing_image_count += 1
@ -0,0 +1,31 @@
import logging
import os

import torch

logger = logging.getLogger(__name__)


def prune_diffusion_ckpt(ckpt_path, dst_path=None):
    if dst_path is None:
        dst_path = f"{os.path.splitext(ckpt_path)[0]}-pruned.ckpt"

    data = torch.load(ckpt_path, map_location="cpu")

    new_data = prune_model_data(data)

    torch.save(new_data, dst_path)

    size_initial = os.path.getsize(ckpt_path)
    newsize = os.path.getsize(dst_path)
    msg = (
        f"New ckpt size: {newsize * 1e-9:.2f} GB. "
        f"Saved {(size_initial - newsize) * 1e-9:.2f} GB by removing optimizer states"
    )
    logger.info(msg)


def prune_model_data(data):
    # dropping the optimizer state is what shrinks the checkpoint from ~11 GB to ~4 GB
    skip_keys = {"optimizer_states"}
    new_data = {k: v for k, v in data.items() if k not in skip_keys}
    return new_data
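
Presumably the function behind the `aimg prune-ckpt` command; a direct invocation (hypothetical checkpoint path):

```python
prune_diffusion_ckpt("logs/2023-01-15T05-52-06/checkpoints/epoch=000049.ckpt")
# writes logs/2023-01-15T05-52-06/checkpoints/epoch=000049-pruned.ckpt
```
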
@ -0,0 +1,135 @@
import os
import random
import re
from abc import abstractmethod
from collections import defaultdict

from einops import rearrange
from omegaconf import ListConfig
from PIL import Image
from torch.utils.data import Dataset, IterableDataset
from torchvision.transforms import transforms

from imaginairy.utils import instantiate_from_config


class Txt2ImgIterableBaseDataset(IterableDataset):
    """
    Define an interface to make the IterableDatasets for text2img data chainable.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass


def _rearrange(x):
    return rearrange(x * 2.0 - 1.0, "c h w -> h w c")


class SingleConceptDataset(Dataset):
    """
    Dataset for finetuning a model on a single concept.

    Similar to "dreambooth".
    """

    def __init__(
        self,
        concept_label,
        class_label,
        concept_images_dir,
        class_images_dir,
        image_transforms=None,
    ):
        self.concept_label = concept_label
        self.class_label = class_label
        self.concept_images_dir = concept_images_dir
        self.class_images_dir = class_images_dir

        if isinstance(image_transforms, (ListConfig, list)):
            image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
            image_transforms.extend(
                [
                    transforms.ToTensor(),
                    transforms.Lambda(_rearrange),
                ]
            )
            image_transforms = transforms.Compose(image_transforms)

        self.image_transforms = image_transforms

        self._concept_image_filename_groups = None
        self._class_image_filenames = None

    def __len__(self):
        return len(self.concept_image_filename_groups) * 2

    def __getitem__(self, idx):
        # alternate between concept images (odd indexes) and class images (even indexes)
        if idx % 2:
            img_group = self.concept_image_filename_groups[idx // 2]
            img_filename = random.choice(img_group)
            img_path = os.path.join(self.concept_images_dir, img_filename)

            txt = self.concept_label
        else:
            img_path = os.path.join(
                self.class_images_dir, self.class_image_filenames[idx // 2]
            )
            txt = self.class_label
        try:
            image = Image.open(img_path).convert("RGB")
        except RuntimeError as e:
            raise RuntimeError(f"Could not read image {img_path}") from e
        image = self.image_transforms(image)
        data = {"image": image, "txt": txt}
        return data

    @property
    def concept_image_filename_groups(self):
        if self._concept_image_filename_groups is None:
            self._concept_image_filename_groups = _load_image_filenames_and_alts(
                self.concept_images_dir
            )
        return self._concept_image_filename_groups

    @property
    def class_image_filenames(self):
        if self._class_image_filenames is None:
            self._class_image_filenames = _load_image_filenames(self.class_images_dir)
        return self._class_image_filenames

    @property
    def num_records(self):
        return len(self)


def _load_image_filenames_and_alts(img_dir, image_extensions=(".jpg", ".jpeg", ".png")):
    """Loads images into groups (filenames tagged with `[alt-{n:02d}]` are grouped together)."""
    image_filenames = _load_image_filenames(img_dir, image_extensions)
    grouped_img_filenames = defaultdict(list)
    for filename in image_filenames:
        base_filename = re.sub(r"\[alt-\d*\]", "", filename)
        grouped_img_filenames[base_filename].append(filename)
    return list(grouped_img_filenames.values())
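
A quick illustration of the grouping rule with hypothetical filenames (the `[alt-NN]` tags match what `prep_images` writes):

```python
import re
from collections import defaultdict

names = ["dog.jpg_[alt-00].jpg", "dog.jpg_[alt-01].jpg", "cat.jpg_[alt-00].jpg"]
groups = defaultdict(list)
for name in names:
    # stripping the [alt-NN] tag gives both dog crops the same base name
    groups[re.sub(r"\[alt-\d*\]", "", name)].append(name)
assert list(groups.values()) == [
    ["dog.jpg_[alt-00].jpg", "dog.jpg_[alt-01].jpg"],
    ["cat.jpg_[alt-00].jpg"],
]
```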


def _load_image_filenames(img_dir, image_extensions=(".jpg", ".jpeg", ".png")):
    image_filenames = []
    for filename in os.listdir(img_dir):
        if filename.lower().endswith(image_extensions) and not filename.startswith("."):
            image_filenames.append(filename)
    random.shuffle(image_filenames)
    return image_filenames
@ -0,0 +1,375 @@
"""
Crops to the most interesting part of the image.

MIT License from https://github.com/smartcrop/smartcrop.py/commit/f5377045035abc7ae79d8d9ad40bbc7fce0f6ad7
"""
import math
import sys

import numpy as np
from PIL import Image, ImageDraw
from PIL.ImageFilter import Kernel


def saturation(image):
    r, g, b = image.split()
    r, g, b = np.array(r), np.array(g), np.array(b)
    r, g, b = r.astype(float), g.astype(float), b.astype(float)
    maximum = np.maximum(np.maximum(r, g), b)  # [0; 255]
    minimum = np.minimum(np.minimum(r, g), b)  # [0; 255]
    s = (maximum + minimum) / 255  # [0.0; 1.0]
    d = (maximum - minimum) / 255  # [0.0; 1.0]
    d[maximum == minimum] = 0  # if maximum == minimum:
    s[maximum == minimum] = 1  # -> saturation = 0 / 1 = 0
    mask = s > 1
    s[mask] = 2 - d[mask]
    return d / s  # [0.0; 1.0]


def thirds(x):
    """
    Given x in [0, 1], where 0 is the center of the picture,
    returns the rule-of-thirds weight in [0, 1].
    """
    x = ((x + 2 / 3) % 2 * 0.5 - 0.5) * 16
    return max(1 - x * x, 0)
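
A quick sanity check of the weighting above: the weight peaks when a point sits exactly on a thirds line and drops to zero at the center of the frame.

```python
assert abs(thirds(1 / 3) - 1.0) < 1e-9  # on the thirds line -> full weight
assert thirds(0.0) == 0  # dead center -> no bonus
```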


class SmartCrop:
    DEFAULT_SKIN_COLOR = [0.78, 0.57, 0.44]

    def __init__(
        self,
        detail_weight=0.2,
        edge_radius=0.4,
        edge_weight=-20,
        outside_importance=-0.5,
        rule_of_thirds=True,
        saturation_bias=0.2,
        saturation_brightness_max=0.9,
        saturation_brightness_min=0.05,
        saturation_threshold=0.4,
        saturation_weight=0.3,
        score_down_sample=8,
        skin_bias=0.01,
        skin_brightness_max=1,
        skin_brightness_min=0.2,
        skin_color=None,
        skin_threshold=0.8,
        skin_weight=1.8,
    ):
        self.detail_weight = detail_weight
        self.edge_radius = edge_radius
        self.edge_weight = edge_weight
        self.outside_importance = outside_importance
        self.rule_of_thirds = rule_of_thirds
        self.saturation_bias = saturation_bias
        self.saturation_brightness_max = saturation_brightness_max
        self.saturation_brightness_min = saturation_brightness_min
        self.saturation_threshold = saturation_threshold
        self.saturation_weight = saturation_weight
        self.score_down_sample = score_down_sample
        self.skin_bias = skin_bias
        self.skin_brightness_max = skin_brightness_max
        self.skin_brightness_min = skin_brightness_min
        self.skin_color = skin_color or self.DEFAULT_SKIN_COLOR
        self.skin_threshold = skin_threshold
        self.skin_weight = skin_weight

    def analyse(
        self,
        image,
        crop_width,
        crop_height,
        max_scale=1,
        min_scale=0.9,
        scale_step=0.1,
        step=8,
    ):
        """
        Analyze image and return some suggestions of crops (coordinates).
        This implementation / algorithm is really slow for large images.
        Use `crop()` which is pre-scaling the image before analyzing it.
        """
        cie_image = image.convert("L", (0.2126, 0.7152, 0.0722, 0))
        cie_array = np.array(cie_image)  # [0; 255]

        # R=skin G=edge B=saturation
        edge_image = self.detect_edge(cie_image)
        skin_image = self.detect_skin(cie_array, image)
        saturation_image = self.detect_saturation(cie_array, image)
        analyse_image = Image.merge("RGB", [skin_image, edge_image, saturation_image])

        del edge_image
        del skin_image
        del saturation_image

        score_image = analyse_image.copy()
        score_image.thumbnail(
            (
                int(math.ceil(image.size[0] / self.score_down_sample)),
                int(math.ceil(image.size[1] / self.score_down_sample)),
            ),
            Image.ANTIALIAS,
        )

        top_crop = None
        top_score = -sys.maxsize

        crops = self.crops(
            image,
            crop_width,
            crop_height,
            max_scale=max_scale,
            min_scale=min_scale,
            scale_step=scale_step,
            step=step,
        )

        for crop in crops:
            crop["score"] = self.score(score_image, crop)
            if crop["score"]["total"] > top_score:
                top_crop = crop
                top_score = crop["score"]["total"]

        return {"analyse_image": analyse_image, "crops": crops, "top_crop": top_crop}

    def crop(
        self,
        image,
        width,
        height,
        prescale=True,
        max_scale=1,
        min_scale=0.9,
        scale_step=0.1,
        step=8,
    ):
        """Not yet fully cleaned from https://github.com/hhatto/smartcrop.py."""
        scale = min(image.size[0] / width, image.size[1] / height)
        crop_width = int(math.floor(width * scale))
        crop_height = int(math.floor(height * scale))
        # img = 100x100, width = 95x95, scale = 100/95, 1/scale > min
        # don't set minscale smaller than 1/scale
        # -> don't pick crops that need upscaling
        min_scale = min(max_scale, max(1 / scale, min_scale))

        prescale_size = 1
        if prescale:
            prescale_size = 1 / scale / min_scale
            if prescale_size < 1:
                image = image.copy()
                image.thumbnail(
                    (
                        int(image.size[0] * prescale_size),
                        int(image.size[1] * prescale_size),
                    ),
                    Image.ANTIALIAS,
                )
                crop_width = int(math.floor(crop_width * prescale_size))
                crop_height = int(math.floor(crop_height * prescale_size))
            else:
                prescale_size = 1

        result = self.analyse(
            image,
            crop_width=crop_width,
            crop_height=crop_height,
            min_scale=min_scale,
            max_scale=max_scale,
            scale_step=scale_step,
            step=step,
        )

        for i in range(len(result["crops"])):
            crop = result["crops"][i]
            crop["x"] = int(math.floor(crop["x"] / prescale_size))
            crop["y"] = int(math.floor(crop["y"] / prescale_size))
            crop["width"] = int(math.floor(crop["width"] / prescale_size))
            crop["height"] = int(math.floor(crop["height"] / prescale_size))
            result["crops"][i] = crop
        return result

    def crops(
        self,
        image,
        crop_width,
        crop_height,
        max_scale=1,
        min_scale=0.9,
        scale_step=0.1,
        step=8,
    ):
        image_width, image_height = image.size
        crops = []
        for scale in (
            i / 100
            for i in range(
                int(max_scale * 100),
                int((min_scale - scale_step) * 100),
                -int(scale_step * 100),
            )
        ):
            for y in range(0, image_height, step):
                if not (y + crop_height * scale <= image_height):
                    break
                for x in range(0, image_width, step):
                    if not (x + crop_width * scale <= image_width):
                        break
                    crops.append(
                        {
                            "x": x,
                            "y": y,
                            "width": crop_width * scale,
                            "height": crop_height * scale,
                        }
                    )
        if not crops:
            raise ValueError(locals())
        return crops

    def debug_crop(self, analyse_image, crop):
        debug_image = analyse_image.copy()
        debug_pixels = debug_image.getdata()
        debug_crop_image = Image.new(
            "RGBA",
            (int(math.floor(crop["width"])), int(math.floor(crop["height"]))),
            (255, 0, 0, 25),
        )
        ImageDraw.Draw(debug_crop_image).rectangle(
            ((0, 0), (crop["width"], crop["height"])), outline=(255, 0, 0)
        )

        for y in range(analyse_image.size[1]):  # height
            for x in range(analyse_image.size[0]):  # width
                p = y * analyse_image.size[0] + x
                importance = self.importance(crop, x, y)
                if importance > 0:
                    # write via the image; the getdata() sequence is read-only
                    debug_image.putpixel(
                        (x, y),
                        (
                            debug_pixels[p][0],
                            int(debug_pixels[p][1] + importance * 32),
                            debug_pixels[p][2],
                        ),
                    )
                elif importance < 0:
                    debug_image.putpixel(
                        (x, y),
                        (
                            int(debug_pixels[p][0] + importance * -64),
                            debug_pixels[p][1],
                            debug_pixels[p][2],
                        ),
                    )
        debug_image.paste(
            debug_crop_image, (crop["x"], crop["y"]), debug_crop_image.split()[3]
        )
        return debug_image

    def detect_edge(self, cie_image):
        return cie_image.filter(Kernel((3, 3), (0, -1, 0, -1, 4, -1, 0, -1, 0), 1, 1))

    def detect_saturation(self, cie_array, source_image):
        threshold = self.saturation_threshold
        saturation_data = saturation(source_image)
        mask = (
            (saturation_data > threshold)
            & (cie_array >= self.saturation_brightness_min * 255)
            & (cie_array <= self.saturation_brightness_max * 255)
        )

        saturation_data[~mask] = 0
        saturation_data[mask] = (saturation_data[mask] - threshold) * (
            255 / (1 - threshold)
        )

        return Image.fromarray(saturation_data.astype("uint8"))

    def detect_skin(self, cie_array, source_image):
        r, g, b = source_image.split()
        r, g, b = np.array(r), np.array(g), np.array(b)
        r, g, b = r.astype(float), g.astype(float), b.astype(float)
        rd = np.ones_like(r) * -self.skin_color[0]
        gd = np.ones_like(g) * -self.skin_color[1]
        bd = np.ones_like(b) * -self.skin_color[2]

        mag = np.sqrt(r * r + g * g + b * b)
        mask = ~(abs(mag) < 1e-6)
        rd[mask] = r[mask] / mag[mask] - self.skin_color[0]
        gd[mask] = g[mask] / mag[mask] - self.skin_color[1]
        bd[mask] = b[mask] / mag[mask] - self.skin_color[2]

        skin = 1 - np.sqrt(rd * rd + gd * gd + bd * bd)
        mask = (
            (skin > self.skin_threshold)
            & (cie_array >= self.skin_brightness_min * 255)
            & (cie_array <= self.skin_brightness_max * 255)
        )

        skin_data = (skin - self.skin_threshold) * (255 / (1 - self.skin_threshold))
        skin_data[~mask] = 0

        return Image.fromarray(skin_data.astype("uint8"))

    def importance(self, crop, x, y):
        if (
            crop["x"] > x
            or x >= crop["x"] + crop["width"]
            or crop["y"] > y
            or y >= crop["y"] + crop["height"]
        ):
            return self.outside_importance

        x = (x - crop["x"]) / crop["width"]
        y = (y - crop["y"]) / crop["height"]
        px, py = abs(0.5 - x) * 2, abs(0.5 - y) * 2

        # distance from edge
        dx = max(px - 1 + self.edge_radius, 0)
        dy = max(py - 1 + self.edge_radius, 0)
        d = (dx * dx + dy * dy) * self.edge_weight
        s = 1.41 - math.sqrt(px * px + py * py)

        if self.rule_of_thirds:
            s += (max(0, s + d + 0.5) * 1.2) * (thirds(px) + thirds(py))

        return s + d

    def score(self, target_image, crop):
        score = {
            "detail": 0,
            "saturation": 0,
            "skin": 0,
            "total": 0,
        }
        target_data = target_image.getdata()
        target_width, target_height = target_image.size

        down_sample = self.score_down_sample
        inv_down_sample = 1 / down_sample
        target_width_down_sample = target_width * down_sample
        target_height_down_sample = target_height * down_sample

        for y in range(0, target_height_down_sample, down_sample):
            for x in range(0, target_width_down_sample, down_sample):
                p = int(
                    math.floor(y * inv_down_sample) * target_width
                    + math.floor(x * inv_down_sample)
                )
                importance = self.importance(crop, x, y)
                detail = target_data[p][1] / 255
                score["skin"] += (
                    target_data[p][0] / 255 * (detail + self.skin_bias) * importance
                )
                score["detail"] += detail * importance
                score["saturation"] += (
                    target_data[p][2]
                    / 255
                    * (detail + self.saturation_bias)
                    * importance
                )
        score["total"] = (
            score["detail"] * self.detail_weight
            + score["skin"] * self.skin_weight
            + score["saturation"] * self.saturation_weight
        ) / (crop["width"] * crop["height"])
        return score
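
A hedged usage sketch (the module path is taken from the image-prep imports above; the input file is hypothetical):

```python
from PIL import Image

from imaginairy.vendored.smart_crop import SmartCrop

img = Image.open("photo.jpg").convert("RGB")  # hypothetical input
result = SmartCrop().crop(img, width=512, height=512)
print(result["top_crop"])  # dict with x, y, width, height, and score keys
```
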
@ -0,0 +1,12 @@
import logging

from imaginairy import LazyLoadingImage
from imaginairy.enhancers.facecrop import generate_face_crops
from tests import TESTS_FOLDER

logger = logging.getLogger(__name__)


def test_facecrop():
    img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")
    generate_face_crops((50, 50, 150, 150), max_width=img.width, max_height=img.height)
@ -0,0 +1,69 @@
import itertools
import random

from imaginairy.roi_utils import (
    RoiNotInBoundsError,
    resize_roi_coordinates,
    square_roi_coordinate,
)


def test_square_roi_coordinate():
    img_sizes = (10, 100, 200, 511, 513, 1024)
    # iterate through all pairs of image sizes using itertools.product
    for img_width, img_height in itertools.product(img_sizes, img_sizes):
        # randomly generate regions of interest
        for _ in range(100):
            x1 = random.randint(0, img_width)
            y1 = random.randint(0, img_height)
            x2 = random.randint(x1, img_width)
            y2 = random.randint(y1, img_height)
            roi = x1, y1, x2, y2
            try:
                x1, y1, x2, y2 = square_roi_coordinate(roi, img_width, img_height)
            except RoiNotInBoundsError:
                continue
            assert (
                x2 - x1 == y2 - y1
            ), f"ROI is not square: img_width: {img_width}, img_height: {img_height}, roi: {roi}"


# resize_roi_coordinates


def test_square_resize_roi_coordinates():
    img_sizes = (10, 100, 200, 403, 511, 513, 604, 1024)
    # iterate through all pairs of image sizes using itertools.product
    img_sizes = list(itertools.product(img_sizes, img_sizes))

    for img_width, img_height in img_sizes:
        # randomly generate regions of interest
        rois = []
        for _ in range(100):
            x1 = random.randint(0 + 1, img_width - 1)
            y1 = random.randint(0 + 1, img_height - 1)
            x2 = random.randint(x1 + 1, img_width)
            y2 = random.randint(y1 + 1, img_height)
            roi = x1, y1, x2, y2
            rois.append(roi)
        rois.append((392, 85, 695, 389))
        for roi in rois:
            try:
                squared_roi = square_roi_coordinate(roi, img_width, img_height)
            except RoiNotInBoundsError:
                continue
            for n in range(10):
                factor = 1.25 + 0.3 * n
                x1, y1, x2, y2 = resize_roi_coordinates(
                    squared_roi, factor, img_width, img_height
                )
                assert (
                    x2 - x1 == y2 - y1
                ), f"ROI is not square: img_width: {img_width}, img_height: {img_height}, roi: {roi}"

                x1, y1, x2, y2 = resize_roi_coordinates(
                    squared_roi, factor, img_width, img_height, expand_up=False
                )
                assert (
                    x2 - x1 == y2 - y1
                ), f"ROI is not square: img_width: {img_width}, img_height: {img_height}, roi: {roi}"