feature: remove training feature

pull/413/head
Bryce 10 months ago committed by Bryce Drennan
parent ef0f44646e
commit db85f0898a

@@ -88,6 +88,7 @@ cutting edge features (SDXL, image prompts, etc) which will be added to imaginai
- deprecated: support for python 3.8, 3.9
- deprecated: support for torch 1.13
- deprecated: support for Stable Diffusion versions 1.4, 2.0, and 2.1
- deprecated: image training
- broken: most samplers, tile/details controlnet, and model memory management
### Run API server and StableStudio web interface (alpha)
@@ -476,8 +477,7 @@ a bowl full of gold bars sitting on a table
- Prompt metadata saved into image file metadata
- Have AI generate captions for images `aimg describe <filename-or-url>`
- Interactive prompt: just run `aimg`
- finetune your own image model, kind of like dreambooth. Read instructions on the ["Concept Training"](docs/concept-training.md) page
## How To
For full command line instructions run `aimg --help`

@@ -1,85 +0,0 @@
# Adding a concept to Stable Diffusion
You can use Imaginairy to teach the model a new concept (a person, thing, style, etc) using the `aimg train-concept`
command.
## Requirements
- Graphics card: 3090 or better
- Linux
- A working Imaginairy installation
- A folder of images of the concept you want to teach the model
## Background
To train the model, we show it many images of the concept we want to teach it. The problem is that the model can easily
overfit to the images we show it. To prevent this, we also show it images of the general class of thing being trained.
Imaginairy will generate the class images needed for this before running the training job.
Given a directory of concept images, a concept token, and a class token, this command will train the model
to generate images of that concept.
This happens in a 3-step process:
1. Cropping and resizing your training images. If `--person` is set we crop to include the face.
2. Generating a set of class images to train on. This helps prevent overfitting.
3. Training the model on the concept and class images.
The output of this command is a new model weights file that you can use with the `--model` option.
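To make this concrete, here's a minimal, hypothetical sketch of how a dataset can interleave concept and class images (illustrative only: the class name and the 50/50 interleaving are assumptions, not the repo's actual `SingleConceptDataset`):
```
# Hypothetical sketch of prior preservation: alternate between images of the
# new concept and generated images of the broader class, so the model keeps
# its general idea of (say) "man" while learning "billhamilton".
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class ConceptAndClassDataset(Dataset):  # hypothetical name
    def __init__(self, concept_dir, class_dir, concept_label, class_label, transform=None):
        self.concept_paths = sorted(Path(concept_dir).glob("*.jpg"))
        self.class_paths = sorted(Path(class_dir).glob("*.jpg"))
        self.concept_label = concept_label
        self.class_label = class_label
        self.transform = transform

    def __len__(self):
        # one concept sample for every class sample
        return 2 * max(len(self.concept_paths), len(self.class_paths))

    def __getitem__(self, idx):
        is_concept = idx % 2 == 0  # even indices -> concept, odd -> class
        paths = self.concept_paths if is_concept else self.class_paths
        img = Image.open(paths[(idx // 2) % len(paths)]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        # pair each image with its prompt text, like the {"image": ..., "txt": ...}
        # batches used by the training code removed later in this diff
        return {"image": img, "txt": self.concept_label if is_concept else self.class_label}
```
Each optimizer step then sees both kinds of samples, which is what keeps the model's general notion of the class from collapsing onto your handful of concept photos.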
## Instructions
1. Gather a set of images of the concept you want to train on. The images should show the subject from a variety of angles
and in a variety of situations.
2. Run `aimg train-concept` to train the model.
   - Concept label: For a person, `firstnamelastname` should be fine.
      - If all the training images are photos, you should add "a photo of" to the beginning of the concept label.
   - Class label: This is the category of the thing being trained on. For people this is typically "person", "man",
   or "woman".
      - If all the training images are photos, you should add "a photo of" to the beginning of the class label.
   - Class images will be generated for you if you do not provide them.
For example, if you were training on photos of a man named Bill Hamilton, you could run the following:
```
aimg train-concept \
    --person \
    --concept-label "photo of billhamilton man" \
    --concept-images-dir ./images/billhamilton \
    --class-label "photo of a man" \
    --class-images-dir ./images/man
```
3. Stop training before it overfits.
   - The training script will output checkpoint (`.ckpt`) files into the logs folder of wherever it is run from. You can also
   monitor generated images in the logs/images folder. They will be the ones named "sample".
   - I don't have great advice on when to stop training yet. I stopped mine at epoch 62 and it didn't seem quite good enough; at epoch 111 it
   produced my face correctly 50% of the time but also seemed overfit in some ways (always placing me in the same clothes or backgrounds as the training photos).
   - You can monitor model training progress in Tensorboard. Run `tensorboard --logdir lightning_logs` and open the link it gives you in your browser.
4. Prune the model to bring the size down from 11 GB to ~4 GB: `aimg prune-ckpt logs/2023-01-15T05-52-06/checkpoints/epoch\=000049.ckpt`. Copy it somewhere
and give it a meaningful name. (A rough sketch of what pruning does follows this list.)
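For the curious: pruning shrinks the file because a Lightning checkpoint stores optimizer state and other training bookkeeping alongside the model weights. A hypothetical sketch of the idea (the exact keys `aimg prune-ckpt` keeps or drops may differ):
```
# Hypothetical sketch of checkpoint pruning, not the actual aimg implementation:
# keep only the model weights and drop optimizer state / training bookkeeping.
import torch

ckpt = torch.load("logs/2023-01-15T05-52-06/checkpoints/epoch=000049.ckpt", map_location="cpu")
pruned = {"state_dict": ckpt["state_dict"]}  # drops optimizer_states, lr_schedulers, etc.
# optionally cast weights to float16 to roughly halve the size again:
# pruned["state_dict"] = {k: v.half() for k, v in pruned["state_dict"].items()}
torch.save(pruned, "my-models/billhamilton-man-e49.ckpt")
```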
## Using the new model
You can reference the model like this in imaginairy:
`imagine --model my-models/billhamilton-man-e111.ckpt`
When you use the model you should prompt with `firstnamelastname classname` (e.g. `billhamilton man`).
## Disclaimers
- The settings imaginairy uses to train the model are different from those used by other software projects, so you cannot follow
advice you may read in other tutorials regarding learning rate, epochs, steps, or batch size. They are not directly
comparable. In layman's terms, the "steps" are much bigger in imaginairy (see the arithmetic sketch after this list).
- I consider this training feature experimental and don't currently plan to offer support for it. Any further work will
be done at my leisure. As a result I may close any reported issues related to this feature.
- You can find a lot more relevant information here: https://github.com/JoePenna/Dreambooth-Stable-Diffusion
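On the "steps" point: the training code removed later in this diff accumulates gradients over many batches and scales the learning rate to match, so one imaginairy "step" corresponds to many ordinary batches. Using the defaults from `train_diffusion_model` below:
```
# Values taken from train_diffusion_model (removed later in this diff)
learning_rate = 1e-6
accumulate_grad_batches = 32  # one optimizer step spans 32 batches
batch_size = 1

effective_batch_size = accumulate_grad_batches * batch_size  # 32

# the training code scales the base learning rate the same way:
# model.learning_rate = learning_rate * accumulate_grad_batches * batch_size
effective_lr = learning_rate * accumulate_grad_batches * batch_size
print(effective_batch_size, effective_lr)  # 32 3.2e-05
```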
## Todo
- figure out how to improve consistency of quality from the trained model
- train on the depth-guided model instead of SD 1.5 since that will enable more consistent output
- figure out a metric to use for stopping training
- possibly swap out and randomize backgrounds on training photos so overfitting does not occur

@@ -34,9 +34,7 @@ from imaginairy.cli.shared import _imagine_cmd, add_options, common_options
)
@click.option(
"--control-strength",
help=(
"Strength of the control signal."
),
help=("Strength of the control signal."),
multiple=True,
)
@click.option(

@@ -9,7 +9,6 @@ from imaginairy.cli.edit import edit_cmd
from imaginairy.cli.edit_demo import edit_demo_cmd
from imaginairy.cli.imagine import imagine_cmd
from imaginairy.cli.run_api import run_server_cmd
from imaginairy.cli.train import prep_images_cmd, prune_ckpt_cmd, train_concept_cmd
from imaginairy.cli.upscale import upscale_cmd
from imaginairy.cli.videogen import videogen_cmd
@@ -46,9 +45,8 @@ aimg.add_command(describe_cmd, name="describe")
aimg.add_command(edit_cmd, name="edit")
aimg.add_command(edit_demo_cmd, name="edit-demo")
aimg.add_command(imagine_cmd, name="imagine")
aimg.add_command(prep_images_cmd, name="prep-images")
aimg.add_command(prune_ckpt_cmd, name="prune-ckpt")
aimg.add_command(train_concept_cmd, name="train-concept")
# aimg.add_command(prep_images_cmd, name="prep-images")
# aimg.add_command(prune_ckpt_cmd, name="prune-ckpt")
aimg.add_command(upscale_cmd, name="upscale")
aimg.add_command(run_server_cmd, name="server")
aimg.add_command(videogen_cmd, name="videogen")

@@ -22,7 +22,6 @@ class ModelConfig:
config_path: str
weights_url: str
default_image_size: int
weights_url_full: str = None
forced_attn_precision: str = "default"
default_negative_prompt: str = DEFAULT_NEGATIVE_PROMPT
alias: str = None
@@ -36,7 +35,6 @@ MODEL_CONFIGS = [
short_name="SD-1.5",
config_path="configs/stable-diffusion-v1.yaml",
weights_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt",
weights_url_full="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned.ckpt",
default_image_size=512,
alias="sd15",
),

@@ -154,7 +154,6 @@ def get_diffusion_model(
control_weights_locations=None,
half_mode=None,
for_inpainting=False,
for_training=False,
):
"""
Load a diffusion model.
@@ -168,7 +167,6 @@ def get_diffusion_model(
half_mode,
for_inpainting,
control_weights_locations=control_weights_locations,
for_training=for_training,
)
except HuggingFaceAuthorizationError as e:
if for_inpainting:
@@ -180,7 +178,6 @@ def get_diffusion_model(
config_path,
half_mode,
for_inpainting=False,
for_training=for_training,
control_weights_locations=control_weights_locations,
)
raise
@@ -191,7 +188,6 @@ def _get_diffusion_model(
config_path="configs/stable-diffusion-v1.yaml",
half_mode=None,
for_inpainting=False,
for_training=False,
control_weights_locations=None,
):
"""
@@ -211,7 +207,6 @@ def _get_diffusion_model(
config_path=config_path,
control_weights_paths=control_weights_locations,
for_inpainting=for_inpainting,
for_training=for_training,
)
# some models need the attention calculated in float32
if model_config is not None:
@@ -222,7 +217,6 @@ def _get_diffusion_model(
config_path=config_path,
weights_location=weights_location,
half_mode=half_mode,
for_training=for_training,
)
MOST_RECENTLY_LOADED_MODEL = diffusion_model
if control_weights_locations:
@@ -240,7 +234,6 @@ def get_diffusion_model_refiners(
control_weights_locations=None,
dtype=None,
for_inpainting=False,
for_training=False,
):
"""
Load a diffusion model.
@@ -254,7 +247,6 @@ def get_diffusion_model_refiners(
for_inpainting,
dtype=dtype,
control_weights_locations=control_weights_locations,
for_training=for_training,
)
except HuggingFaceAuthorizationError as e:
if for_inpainting:
@@ -266,7 +258,6 @@ def get_diffusion_model_refiners(
config_path,
dtype=dtype,
for_inpainting=False,
for_training=for_training,
control_weights_locations=control_weights_locations,
)
raise
@@ -276,7 +267,6 @@ def _get_diffusion_model_refiners(
weights_location=iconfig.DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
for_inpainting=False,
for_training=False,
control_weights_locations=None,
device=None,
dtype=torch.float16,
@@ -291,7 +281,6 @@ def _get_diffusion_model_refiners(
weights_location=weights_location,
config_path=config_path,
for_inpainting=for_inpainting,
for_training=for_training,
device=device,
dtype=dtype,
)
@@ -304,7 +293,6 @@ def _get_diffusion_model_refiners_only(
weights_location=iconfig.DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
for_inpainting=False,
for_training=False,
control_weights_locations=None,
device=None,
dtype=torch.float16,
@@ -334,7 +322,6 @@ def _get_diffusion_model_refiners_only(
config_path=config_path,
control_weights_paths=control_weights_locations,
for_inpainting=for_inpainting,
for_training=for_training,
)
# some models need the attention calculated in float32
if model_config is not None:
@@ -378,11 +365,8 @@ def _get_diffusion_model_refiners_only(
@memory_managed_model("stable-diffusion", memory_usage_mb=1951)
def _load_diffusion_model(config_path, weights_location, half_mode, for_training):
def _load_diffusion_model(config_path, weights_location, half_mode):
model_config = OmegaConf.load(f"{PKG_ROOT}/{config_path}")
if for_training:
model_config.use_ema = True
# model_config.use_scheduler = True
# only run half-mode on cuda. run it by default
half_mode = half_mode is None and get_device() == "cuda"
@@ -443,7 +427,6 @@ def resolve_model_paths(
config_path=None,
control_weights_paths=None,
for_inpainting=False,
for_training=False,
):
"""Resolve weight and config path if they happen to be shortcuts."""
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_path, None)
@@ -466,13 +449,8 @@ def resolve_model_paths(
if model_metadata_w:
if config_path is None:
config_path = model_metadata_w.config_path
if for_training:
weights_path = model_metadata_w.weights_url_full
if weights_path is None:
msg = "No full training weights configured for this model. Edit the code or subimt a github issue."
raise ValueError(msg)
else:
weights_path = model_metadata_w.weights_url
weights_path = model_metadata_w.weights_url
if model_metadata_c:
config_path = model_metadata_c.config_path

@@ -1,533 +0,0 @@
import datetime
import logging
import os
import signal
import time
from functools import partial
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
try:
from pytorch_lightning.strategies import DDPStrategy
except ImportError:
# let's not break all of imaginairy just because a training import doesn't exist in an older version of PL
# Use >= 1.6.0 to make this work
DDPStrategy = None
import contextlib
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.utils.data import DataLoader, Dataset
from imaginairy import config
from imaginairy.model_manager import get_diffusion_model
from imaginairy.training_tools.single_concept import SingleConceptDataset
from imaginairy.utils import get_device, instantiate_from_config
mod_logger = logging.getLogger(__name__)
referenced_by_string = [LearningRateMonitor]
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset."""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
if isinstance(dataset, SingleConceptDataset):
# split_size = dataset.num_records // worker_info.num_workers
# reset num_records to the true number to retain reliable length information
# dataset.sample_ids = dataset.valid_ids[
# worker_id * split_size : (worker_id + 1) * split_size
# ]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
return np.random.seed(np.random.get_state()[1][0] + worker_id)
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=False,
num_val_workers=0,
):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = {}
self.num_workers = num_workers if num_workers is not None else batch_size * 2
if num_val_workers is None:
self.num_val_workers = self.num_workers
else:
self.num_val_workers = num_val_workers
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(
self._val_dataloader, shuffle=shuffle_val_dataloader
)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(
self._test_dataloader, shuffle=shuffle_test_loader
)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.wrap = wrap
self.datasets = None
def prepare_data(self):
for data_cfg in self.dataset_configs.values():
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = {
k: instantiate_from_config(c) for k, c in self.dataset_configs.items()
}
if self.wrap:
self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()}
def _train_dataloader(self):
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
if is_iterable_dataset or self.use_worker_init_fn:
pass
else:
pass
return DataLoader(
self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
worker_init_fn=worker_init_fn,
)
def _val_dataloader(self, shuffle=False):
if (
isinstance(self.datasets["validation"], SingleConceptDataset)
or self.use_worker_init_fn
):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_val_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
is_iterable_dataset = False
# do not shuffle dataloader for iterable dataset
shuffle = shuffle and (not is_iterable_dataset)
return DataLoader(
self.datasets["test"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _predict_dataloader(self, shuffle=False):
if (
isinstance(self.datasets["predict"], SingleConceptDataset)
or self.use_worker_init_fn
):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
)
class SetupCallback(Callback):
def __init__(
self,
resume,
now,
logdir,
ckptdir,
cfgdir,
):
super().__init__()
self.resume = resume
self.now = now
self.logdir = logdir
self.ckptdir = ckptdir
self.cfgdir = cfgdir
def on_keyboard_interrupt(self, trainer, pl_module):
if trainer.global_rank == 0:
mod_logger.info("Stopping execution and saving final checkpoint.")
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def on_fit_start(self, trainer, pl_module):
if trainer.global_rank == 0:
# Create logdirs and save configs
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
else:
# ModelCheckpoint callback created log directory --- remove it
if not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
with contextlib.suppress(FileNotFoundError):
os.rename(self.logdir, dst)
class ImageLogger(Callback):
def __init__(
self,
batch_frequency,
max_images,
clamp=True,
increase_log_steps=True,
rescale=True,
disabled=False,
log_on_batch_idx=False,
log_first_step=False,
log_images_kwargs=None,
log_all_val=False,
concept_label=None,
):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {}
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
self.log_all_val = log_all_val
self.concept_label = concept_label
@rank_zero_only
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "logs", "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = (
f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png"
)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
# always generate the concept label
batch["txt"][0] = self.concept_label
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if self.log_all_val and split == "val":
should_log = True
else:
should_log = self.check_frequency(check_idx)
if (
should_log
and (batch_idx % self.batch_freq == 0)
and hasattr(pl_module, "log_images")
and callable(pl_module.log_images)
and self.max_images > 0
):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(
batch, split=split, **self.log_images_kwargs
)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1.0, 1.0)
self.log_local(
pl_module.logger.save_dir,
split,
images,
pl_module.global_step,
pl_module.current_epoch,
batch_idx,
)
logger_log_images = self.logger_log_images.get(
logger, lambda *args, **kwargs: None
)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if (check_idx % self.batch_freq) == 0 and (
check_idx > 0 or self.log_first_step
):
return True
return False
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
self.log_img(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if (
hasattr(pl_module, "calibrate_grad_norm")
and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0)
and batch_idx > 0
):
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
def on_train_epoch_start(self, trainer, pl_module):
# Reset the memory use counter
if "cuda" in get_device():
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
torch.cuda.synchronize(trainer.strategy.root_device.index)
self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module):
if "cuda" in get_device():
torch.cuda.synchronize(trainer.strategy.root_device.index)
max_memory = (
torch.cuda.max_memory_allocated(trainer.strategy.root_device.index)
/ 2**20
)
epoch_time = time.time() - self.start_time
try:
max_memory = trainer.training_type_plugin.reduce(max_memory)
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass
def train_diffusion_model(
concept_label,
concept_images_dir,
class_label,
class_images_dir,
weights_location=config.DEFAULT_MODEL,
logdir="logs",
learning_rate=1e-6,
accumulate_grad_batches=32,
resume=None,
):
"""
Train a diffusion model on a single concept.
accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf
"""
if DDPStrategy is None:
msg = "Please install pytorch-lightning>=1.6.0 to train a model"
raise ImportError(msg)
batch_size = 1
seed = 23
num_workers = 1
num_val_workers = 0
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # noqa: DTZ005
logdir = os.path.join(logdir, now)
ckpt_output_dir = os.path.join(logdir, "checkpoints")
cfg_output_dir = os.path.join(logdir, "configs")
seed_everything(seed)
model = get_diffusion_model(
weights_location=weights_location, half_mode=False, for_training=True
)._model
model.learning_rate = learning_rate * accumulate_grad_batches * batch_size
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "imaginairy.train.SetupCallback",
"params": {
"resume": False,
"now": now,
"logdir": logdir,
"ckptdir": ckpt_output_dir,
"cfgdir": cfg_output_dir,
},
},
"image_logger": {
"target": "imaginairy.train.ImageLogger",
"params": {
"batch_frequency": 10,
"max_images": 1,
"clamp": True,
"increase_log_steps": False,
"log_first_step": True,
"log_all_val": True,
"concept_label": concept_label,
"log_images_kwargs": {
"use_ema_scope": True,
"inpaint": False,
"plot_progressive_rows": False,
"plot_diffusion_rows": False,
"N": 1,
"unconditional_guidance_scale:": 7.5,
"unconditional_guidance_label": [""],
"ddim_steps": 20,
},
},
},
"learning_rate_logger": {
"target": "imaginairy.train.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
},
},
"cuda_callback": {"target": "imaginairy.train.CUDACallback"},
}
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckpt_output_dir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
"every_n_train_steps": 50,
"save_top_k": -1,
"monitor": None,
},
}
modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg)
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
callbacks_cfg = OmegaConf.create(default_callbacks_cfg)
dataset_config = {
"concept_label": concept_label,
"concept_images_dir": concept_images_dir,
"class_label": class_label,
"class_images_dir": class_images_dir,
"image_transforms": [
{
"target": "torchvision.transforms.Resize",
"params": {"size": 512, "interpolation": 3},
},
{"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
],
}
data_module_config = {
"batch_size": batch_size,
"num_workers": num_workers,
"num_val_workers": num_val_workers,
"train": {
"target": "imaginairy.training_tools.single_concept.SingleConceptDataset",
"params": dataset_config,
},
}
trainer = Trainer(
benchmark=True,
num_sanity_val_steps=0,
accumulate_grad_batches=accumulate_grad_batches,
strategy=DDPStrategy(),
callbacks=[instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg],
gpus=1,
default_root_dir=".",
)
trainer.logdir = logdir
data = DataModuleFromConfig(**data_module_config)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
def melk(*args, **kwargs):
if trainer.global_rank == 0:
mod_logger.info("Summoning checkpoint.")
ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
signal.signal(signal.SIGUSR1, melk)
try:
try:
trainer.fit(model, data)
except Exception:
melk()
raise
finally:
mod_logger.info(trainer.profiler.summary())