feature: finetuning

- feature: finetuning your own image models
- feature: image prep command. crops to face or other interesting parts of photo
- fix: back-compat for hf_hub_download
- feature: add prune-ckpt command
- feature: allow specification of model config file
This commit is contained in:
Bryce 2023-01-01 14:54:49 -08:00 committed by Bryce Drennan
parent 4bc78b9be5
commit 5cc73f6087
29 changed files with 2489 additions and 50 deletions

4
.gitignore vendored
View File

@ -24,5 +24,5 @@ tests/vastai_cli.py
.unison*
*.kgrind
*.pyprof
/imaginairy/enhancers/.polyscope.ini
/imaginairy/enhancers/imgui.ini
**/.polyscope.ini
**/imgui.ini

View File

@ -180,6 +180,7 @@ a bowl full of gold bars sitting on a table
- Edit images by describing the part you want edited (see example above)
- Have AI generate captions for images `aimg describe <filename-or-url>`
- Interactive prompt: just run `aimg`
- 🎉 finetune your own image model. kind of like dreambooth. Read instructions on ["Concept Training"](docs/concept-training.md) page
## How To
@ -239,6 +240,14 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
[Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing)
## ChangeLog
**7.4.0**
- feature: 🎉 finetune your own image model. kind of like dreambooth. Read instructions on ["Concept Training"](docs/concept-training.md) page
- feature: image prep command. crops to face or other interesting parts of photo
- fix: back-compat for hf_hub_download
- feature: add prune-ckpt command
- feature: allow specification of model config file
**7.3.0**
- feature: 🎉 depth-based image-to-image generations (and inpainting)
- fix: k_euler_a produces more consistent images per seed (randomization respects the seed again)
@ -390,7 +399,6 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
## Not Supported
- a GUI. this is a python library
- training
- exploratory features that don't work well
## Todo
@ -403,7 +411,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- https://github.com/CompVis/stable-diffusion/pull/177
- https://github.com/huggingface/diffusers/pull/532/files
- https://github.com/HazyResearch/flash-attention
- xformers improvements https://www.photoroom.com/tech/stable-diffusion-100-percent-faster-with-memory-efficient-attention/
- xformers improvements https://www.photoroom.com/tech/stable-diffusion-100-percent-faster-with-memory-efficient-attention/
- Development Environment
- ✅ add tests
- ✅ set up ci (test/lint/format)
@ -424,9 +432,10 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- Compositional Visual Generation
- https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
- https://colab.research.google.com/github/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch/blob/main/notebooks/demo.ipynb#scrollTo=wt_j3uXZGFAS
- negative prompting
- negative prompting
- some syntax to allow it in a text string
- images as actual prompts instead of just init images. is this the same as textual inversion?
- 🚫 images as actual prompts instead of just init images.
- not directly possible due to model architecture.
- requires model fine-tuning since SD1.4 expects 77x768 text encoding input
- https://twitter.com/Buntworthy/status/1566744186153484288
- https://github.com/justinpinkney/stable-diffusion
@ -439,7 +448,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- ✅ inpainting
- https://github.com/Jack000/glid-3-xl-stable
- https://github.com/andreas128/RePaint
- img2img but keeps img stable
- img2img but keeps img stable
- https://www.reddit.com/r/StableDiffusion/comments/xboy90/a_better_way_of_doing_img2img_by_finding_the/
- https://gist.github.com/trygvebw/c71334dd127d537a15e9d59790f7f5e1
- https://github.com/pesser/stable-diffusion/commit/bbb52981460707963e2a62160890d7ecbce00e79
@ -478,6 +487,11 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- Other
- Enhancement pipelines
- text-to-3d https://dreamfusionpaper.github.io/
- https://shihmengli.github.io/3D-Photo-Inpainting/
- https://github.com/thygate/stable-diffusion-webui-depthmap-script/discussions/50
- Depth estimation
- what is SOTA for monocular depth estimation?
- https://github.com/compphoto/BoostingMonocularDepth
- make a video https://github.com/lucidrains/make-a-video-pytorch
- animations
- https://github.com/francislabountyjr/stable-diffusion/blob/main/inferencing_notebook.ipynb
@ -499,6 +513,14 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- ✅ deploy to pypi
- find similar images https://knn5.laion.ai/?back=https%3A%2F%2Fknn5.laion.ai%2F&index=laion5B&useMclip=false
- https://github.com/vicgalle/stable-diffusion-aesthetic-gradients
- Training
- Finetuning "dreambooth" style
- Textual Inversion
- Performance Improvements
- [ColoassalAI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) - almost got it working but it's not easy enough to install to merit inclusion in imaginairy. We should check back in on this.
- Xformers
- Deepspeed
-
## Notable Stable Diffusion Implementations
- https://github.com/ahrm/UnstableFusion

BIN
data/DejaVuSans.ttf Normal file

Binary file not shown.

85
docs/concept-training.md Normal file
View File

@ -0,0 +1,85 @@
# Adding a concept to Stable Diffusion
You can use Imaginairy to teach the model a new concept (a person, thing, style, etc) using the `aimg train-concept`
command.
## Requirements
- Graphics card: 3090 or better
- Linux
- A working Imaginairy installation
- a folder of images of the concept you want to teach the model
## Background
To train the model we show it a lot of images of the concept we want to teach it. The problem is the model can easily
overfit to the images we show it. To prevent this we also show it images of the class of thing that is being trained.
Imaginairy will generate the images needed for this before running the training job.
Provided a directory of concept images, a concept token, and a class token, this command will train the model
to generate images of that concept.
This happens in a 3-step process:
1. Cropping and resizing your training images. If --person is set we crop to include the face.
2. Generating a set of class images to train on. This helps prevent overfitting.
3. Training the model on the concept and class images.
The output of this command is a new model weights file that you can use with the --model option.
## Instructions
1. Gather a set of images of the concept you want to train on. The images should show the subject from a variety of angles
and in a variety of situations.
2. Run `aimg train-concept` to train the model.
- Concept label: For a person, firstnamelastname should be fine.
- If all the training images are photos you should add "a photo of" to the beginning of the concept label.
- Class label: This is the category of the things beings trained on. For people this is typically "person", "man"
or "woman".
- If all the training images are photos you should add "a photo of" to the beginning of the class label.
- CLass images will be generated for you if you do not provide them.
For example, if you were training on photos of a man named bill hamilton you could run the following:
```
aimg train-concept \\
--person \\
--concept-label "photo of billhamilton man" \\
--concept-images-dir ./images/billhamilton \\
--class-label "photo of a man" \\
--class-images-dir ./images/man
```
3. Stop training before it overfits.
- The training script will output checkpoint ckpt files into the logs folder of wherever it is run from. You can also
monitor generated images in the logs/images folder. They will be the ones named "sample"
- I don't have great advice on when to stop training yet. I stopped mine at epoch 62 at it didn't seem quite good enough, at epoch 111 it
produced my face correctly 50% of the time but also seemed overfit in some ways (always placing me in the same clothes or background as training photos).
- You can monitor model training progress in Tensorboard. Run `tensorboard --logdir lightning_logs` and open the link it gives you in your browser.
4. Prune the model to bring the size from 11gb to ~4gb: `aimg prune-ckpt logs/2023-01-15T05-52-06/checkpoints/epoch\=000049.ckpt`. Copy it somewhere
and give it a meaninful name.
## Using the new model
You can reference the model like this in imaginairy:
`imagine --model my-models/billhamilton-man-e111.ckpt`
When you use the model you should prompt with `firstnamelastname classname` (e.g. `billhamilton man`).
## Disclaimers
- The settings imaginairy uses to train the model are different than other software projects. As such you cannot follow
advice you may read from other tutorials regarding learning rate, epochs, steps, batch size. They are not directly
comparable. In laymans terms the "steps" are much bigger in imaginairy.
- I consider this training feature experimental and don't currently plan to offer support for it. Any further work will
be at my leisure. As a result I may close any reported issues related to this feature.
- You can find a lot more relevant information here: https://github.com/JoePenna/Dreambooth-Stable-Diffusion
## Todo
- figure out how to improve consistency of quality from trained model
- train on the depth guided model instead of SD 1.5 since that will enable more consistent output
- figure out metric to use for stopping training
- possibly swap out and randomize backgrounds on training photos so over-fitting does not occur

View File

@ -134,6 +134,7 @@ def imagine(
)
model = get_diffusion_model(
weights_location=prompt.model,
config_path=prompt.model_config_path,
half_mode=half_mode,
for_inpainting=prompt.mask_image or prompt.mask_prompt,
)
@ -182,6 +183,7 @@ def imagine(
max_height=prompt.height,
max_width=prompt.width,
)
except PIL.UnidentifiedImageError:
logger.warning(f"Could not load image: {prompt.init_image}")
continue
@ -194,9 +196,16 @@ def imagine(
)
elif prompt.mask_image:
mask_image = prompt.mask_image.convert("L")
mask_image = pillow_fit_image_within(
mask_image,
max_height=prompt.height,
max_width=prompt.width,
convert="L",
)
if mask_image is not None:
log_img(mask_image, "init mask")
if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
mask_image = ImageOps.invert(mask_image)
@ -355,6 +364,8 @@ def imagine(
shape=shape,
batch_size=1,
)
# from torch.nn.functional import interpolate
# samples = interpolate(samples, scale_factor=2, mode='nearest')
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

View File

@ -1,5 +1,6 @@
import logging
import math
import os.path
import click
from click_shell import shell
@ -11,6 +12,13 @@ from imaginairy.enhancers.prompt_expansion import expand_prompts
from imaginairy.log_utils import configure_logging
from imaginairy.samplers import SAMPLER_TYPE_OPTIONS
from imaginairy.schema import ImaginePrompt
from imaginairy.train import train_diffusion_model
from imaginairy.training_tools.image_prep import (
create_class_images,
get_image_filenames,
prep_images,
)
from imaginairy.training_tools.prune_model import prune_diffusion_ckpt
logger = logging.getLogger(__name__)
@ -183,6 +191,12 @@ logger = logging.getLogger(__name__)
show_default=True,
default=config.DEFAULT_MODEL,
)
@click.option(
"--model-config-path",
help="Model config file to use. If a model name is specified, the appropriate config will be used.",
show_default=True,
default=config.DEFAULT_MODEL,
)
@click.option(
"--prompt-library-path",
help="Path to folder containing phrase lists in txt files. Use txt filename in prompt: {_filename_}.",
@ -221,6 +235,7 @@ def imagine_cmd(
caption,
precision,
model_weights_path,
model_config_path,
prompt_library_path,
):
"""Have the AI generate images. alias:imagine."""
@ -282,6 +297,7 @@ def imagine_cmd(
fix_faces_fidelity=fix_faces_fidelity,
tile_mode=_tile_mode,
model=model_weights_path,
model_config_path=model_config_path,
)
prompts.append(prompt)
@ -315,6 +331,240 @@ def describe(image_filepaths):
print(generate_caption(img.copy()))
@click.option(
"--concept-label",
help=(
'The concept you are training on. Usually "a photo of [person or thing] [classname]" is what you should use.'
),
required=True,
)
@click.option(
"--concept-images-dir",
type=click.Path(),
required=True,
help="Where to find the pre-processed concept images to train on.",
)
@click.option(
"--class-label",
help=(
'What class of things does the concept belong to. For example, if you are training on "a painting of a George Washington", '
'you might use "a painting of a man" as the class label. We use this to prevent the model from overfitting.'
),
default="a photo of *",
)
@click.option(
"--class-images-dir",
type=click.Path(),
required=True,
help="Where to find the pre-processed class images to train on.",
)
@click.option(
"--n-class-images",
type=int,
default=300,
help="Number of class images to generate.",
)
@click.option(
"--model-weights-path",
"--model",
"model",
help=f"Model to use. Should be one of {', '.join(MODEL_SHORT_NAMES)}, or a path to custom weights.",
show_default=True,
default=config.DEFAULT_MODEL,
)
@click.option(
"--person",
"is_person",
is_flag=True,
help="Set if images are of a person. Will use face detection and enhancement.",
)
@click.option(
"-y",
"preconfirmed",
is_flag=True,
default=False,
help="Bypass input confirmations.",
)
@click.option(
"--skip-prep",
is_flag=True,
default=False,
help="Skip the image preparation step.",
)
@click.option(
"--skip-class-img-gen",
is_flag=True,
default=False,
help="Skip the class image generation step.",
)
@aimg.command()
def train_concept(
concept_label,
concept_images_dir,
class_label,
class_images_dir,
n_class_images,
model,
is_person,
preconfirmed,
skip_prep,
skip_class_img_gen,
):
"""
Teach the model a new concept (a person, thing, style, etc).
Provided a directory of concept images, a concept token, and a class token, this command will train the model
to generate images of that concept.
\b
This happens in a 3-step process:
1. Cropping and resizing your training images. If --person is set we crop to include the face.
2. Generating a set of class images to train on. This helps prevent overfitting.
3. Training the model on the concept and class images.
The output of this command is a new model weights file that you can use with the --model option.
\b
## Instructions
1. Gather a set of images of the concept you want to train on. They should show the subject from a variety of angles
and in a variety of situations.
2. Train the model.
- Concept label: For a person, firstnamelastname should be fine.
- If all the training images are photos you should add "a photo of" to the beginning of the concept label.
- Class label: This is the category of the things beings trained on. For people this is typically "person", "man"
or "woman".
- If all the training images are photos you should add "a photo of" to the beginning of the class label.
- CLass images will be generated for you if you do not provide them.
3. Stop training before it overfits. I haven't figured this out yet.
For example, if you were training on photos of a man named bill hamilton you could run the following:
\b
aimg train-concept \\
--person \\
--concept-label "photo of billhamilton man" \\
--concept-images-dir ./images/billhamilton \\
--class-label "photo of a man" \\
--class-images-dir ./images/man
When you use the model you should prompt with `firstnamelastname classname` (e.g. `billhamilton man`).
You can find a lot of relevant instructions here: https://github.com/JoePenna/Dreambooth-Stable-Diffusion
"""
configure_logging()
target_size = 512
# Step 1. Crop and enhance the training images
prepped_images_path = os.path.join(concept_images_dir, "prepped-images")
image_filenames = get_image_filenames(concept_images_dir)
click.secho(
f'\n🤖🧠 Training "{concept_label}" based on {len(image_filenames)} images.\n'
)
if not skip_prep:
msg = (
f"Creating cropped copies of the {len(image_filenames)} concept images\n"
f" Is Person: {is_person}\n"
f" Source: {concept_images_dir}\n"
f" Dest: {prepped_images_path}\n"
)
logger.info(msg)
if not is_person:
click.secho("⚠️ the --person flag was not set. ", fg="yellow")
if not preconfirmed and not click.confirm("Continue?"):
return
prep_images(
images_dir=concept_images_dir, is_person=is_person, target_size=target_size
)
concept_images_dir = prepped_images_path
if not skip_class_img_gen:
# Step 2. Generate class images
class_image_filenames = get_image_filenames(class_images_dir)
images_needed = max(n_class_images - len(class_image_filenames), 0)
logger.info(f"Generating {n_class_images} class images in {class_images_dir}")
logger.info(
f"{len(class_image_filenames)} existing class images found so only generating {images_needed}."
)
if not preconfirmed and not click.confirm("Continue?"):
return
create_class_images(
class_description=class_label,
output_folder=class_images_dir,
num_images=n_class_images,
)
logger.info("Training the model...")
if not preconfirmed and not click.confirm("Continue?"):
return
# Step 3. Train the model
train_diffusion_model(
concept_label=concept_label,
concept_images_dir=concept_images_dir,
class_label=class_label,
class_images_dir=class_images_dir,
weights_location=model,
logdir="logs",
learning_rate=1e-6,
accumulate_grad_batches=32,
)
@click.argument(
"images_dir",
required=True,
)
@click.option(
"--person",
"is_person",
is_flag=True,
help="Set if images are of a person. Will use face detection and enhancement.",
)
@click.option(
"--target-size",
default=512,
type=int,
show_default=True,
)
@aimg.command("prep-images")
def prepare_images(images_dir, is_person, target_size):
"""
Prepare a folder of images for training.
Prepped images will be written to the `prepped-images` subfolder.
All images will be cropped and resized to (default) 512x512.
Upscaling and face enhancement will be applied as needed to smaller images.
Examples:
aimg prep-images --person ./images/selfies
aimg prep-images ./images/toy-train
"""
configure_logging()
prep_images(images_dir=images_dir, is_person=is_person, target_size=target_size)
@click.argument("ckpt_paths", nargs=-1)
@aimg.command("prune-ckpt")
def prune_ckpt(ckpt_paths):
"""
Prune a checkpoint file.
This will remove the optimizer state from the checkpoint file.
This is useful if you want to use the checkpoint file for inference and save a lot of disk space
Example:
aimg prune-ckpt ./path/to/checkpoint.ckpt
"""
click.secho("Pruning checkpoint files...")
configure_logging()
for p in ckpt_paths:
prune_diffusion_ckpt(p)
aimg.add_command(imagine_cmd, name="imagine")
if __name__ == "__main__":

View File

@ -12,6 +12,8 @@ DEFAULT_NEGATIVE_PROMPT = (
"grainy, blurred, blurry, writing, calligraphy, signature, text, watermark, bad art,"
)
SPLITMEM_ENABLED = False
@dataclass
class ModelConfig:
@ -19,6 +21,7 @@ class ModelConfig:
config_path: str
weights_url: str
default_image_size: int
weights_url_full: str = None
forced_attn_precision: str = "default"
@ -34,7 +37,8 @@ MODEL_CONFIGS = [
ModelConfig(
short_name="SD-1.5",
config_path="configs/stable-diffusion-v1.yaml",
weights_url="https://huggingface.co/acheong08/SD-V1-5-cloned/resolve/fc392f6bd4345b80fc2256fa8aded8766b6c629e/v1-5-pruned-emaonly.ckpt",
weights_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt",
weights_url_full="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned.ckpt",
default_image_size=512,
),
ModelConfig(

View File

@ -7,7 +7,7 @@ model:
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
first_stage_key: "image"
cond_stage_key: "txt"
image_size: 64
channels: 4
@ -20,7 +20,7 @@ model:
scheduler_config: # 10000 warm-up steps
target: imaginairy.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [ 10000 ]
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
f_start: [ 1.e-6 ]
f_max: [ 1. ]

View File

@ -56,6 +56,7 @@ def enhance_faces(img, fidelity=0):
net = codeformer_model()
face_helper = face_restore_helper()
face_helper.clean_all()
image = img.convert("RGB")
np_img = np.array(image, dtype=np.uint8)

View File

@ -0,0 +1,57 @@
import numpy as np
from imaginairy.enhancers.face_restoration_codeformer import face_restore_helper
from imaginairy.roi_utils import resize_roi_coordinates, square_roi_coordinate
def detect_faces(img):
face_helper = face_restore_helper()
face_helper.clean_all()
image = img.convert("RGB")
np_img = np.array(image, dtype=np.uint8)
# rotate to BGR
np_img = np_img[:, :, ::-1]
face_helper.read_image(np_img)
face_helper.get_face_landmarks_5(
only_center_face=False, resize=640, eye_dist_threshold=5
)
face_helper.align_warp_face()
faceboxes = []
for x1, y1, x2, y2, scaling in face_helper.det_faces:
# x1, y1, x2, y2 = x1 * scaling, y1 * scaling, x2 * scaling, y2 * scaling
faceboxes.append((x1, y1, x2, y2))
return faceboxes
def generate_face_crops(face_roi, max_width, max_height):
"""Returns bounding boxes at various zoom levels for faces in the image."""
crops = []
squared_roi = square_roi_coordinate(face_roi, max_width, max_height)
crops.append(resize_roi_coordinates(squared_roi, 1.1, max_width, max_height))
# 1.6 generally enough to capture entire face
base_expanded_roi = resize_roi_coordinates(squared_roi, 1.6, max_width, max_height)
crops.append(base_expanded_roi)
current_width = base_expanded_roi[2] - base_expanded_roi[0]
# some zoomed out variations
for n in range(2):
factor = 1.25 + 0.4 * n
expanded_roi = resize_roi_coordinates(
base_expanded_roi, factor, max_width, max_height, expand_up=False
)
new_width = expanded_roi[2] - expanded_roi[0]
if new_width <= current_width * 1.1:
# if the zoomed out size isn't suffienctly larger (because there is nowhere to zoom out to), stop
break
crops.append(expanded_roi)
current_width = new_width
return crops

View File

@ -151,3 +151,34 @@ def get_random_non_repeating_combination( # noqa
idx = idx // len(sequence)
yield values
n -= sub_n
# future use
prompt_templates = [
# https://www.reddit.com/r/StableDiffusion/comments/ya4zxm/dreambooth_is_crazy_prompts_workflow_in_comments/
"cinematic still of #prompt-token# as rugged warrior, threatening xenomorph, alien movie (1986),ultrarealistic",
"colorful cinematic still of #prompt-token#, armor, cyberpunk,background made of brain cells, back light, organic, art by greg rutkowski, ultrarealistic, leica 30mm",
"colorful cinematic still of #prompt-token#, armor, cyberpunk, with a xenonorph, in alien movie (1986),background made of brain cells, organic, ultrarealistic, leic 30mm",
"colorful cinematic still of #prompt-token#, #prompt-token# with long hair, color lights, on stage, ultrarealistic",
"colorful portrait of #prompt-token# with dark glasses as eminem, gold chain necklace, relfective puffer jacket, short white hair, in front of music shop,ultrarealistic, leica 30mm",
"colorful photo of #prompt-token# as kurt cobain with glasses, on stage, lights, ultrarealistic, leica 30mm",
"impressionist painting of ((#prompt-token#)) by Daniel F Gerhartz, ((#prompt-token# painted in an impressionist style)), nature, trees",
"pencil sketch of #prompt-token#, #prompt-token#, #prompt-token#, inspired by greg rutkowski, digital art by artgem",
"photo, colorful cinematic still of #prompt-token#, organic armor,cyberpunk,background brain cells mesh, art by greg rutkowski",
"photo, colorful cinematic still of #prompt-token# with organic armor, cyberpunk background, #prompt-token#, greg rutkowski",
"photo of #prompt-token# astronaut, astronaut, glasses, helmet in alien world abstract oil painting, greg rutkowski, detailed face",
"photo of #prompt-token# as firefighter, helmet, ultrarealistic, leica 30mm",
"photo of #prompt-token#, bowler hat, in django unchained movie, ultrarealistic, leica 30mm",
"photo of #prompt-token# as serious spiderman with glasses, ultrarealistic, leica 30mm",
"photo of #prompt-token# as steampunk warrior, neon organic vines, glasses, digital painting",
"photo of #prompt-token# as supermario with glassesm mustach, blue overall, red short,#prompt-token#,#prompt-token#. ultrarealistic, leica 30mm",
"photo of #prompt-token# as targaryen warrior with glasses, long white hair, armor, ultrarealistic, leica 30mm",
"portrait of #prompt-token# as knight, with glasses white eyes, white mid hair, scar on face, handsome, elegant, intricate, headshot, highly detailed, digital",
"portrait of #prompt-token# as hulk, handsome, elegant, intricate luminescent cyberpunk background, headshot, highly detailed, digital painting",
"portrait of #prompt-token# as private eye detective, intricate, war torn, highly detailed, digital painting, concept art, smooth, sharp focus",
# https://publicprompts.art/
"Retro comic style artwork, highly detailed #prompt-token#, comic book cover, symmetrical, vibrant",
"Closeup face portrait of #prompt-token# wearing crown, smooth soft skin, big dreamy eyes, beautiful intricate colored hair, symmetrical, anime wide eyes, soft lighting, detailed face, by makoto shinkai, stanley artgerm lau, wlop, rossdraws, concept art, digital painting, looking into camera"
"highly detailed portrait brycedrennan man in gta v, unreal engine, fantasy art by greg rutkowski, loish, rhads, ferdinand knab, makoto shinkai and lois van baarle, ilya kuvshinov, rossdraws, tom bagshaw, global illumination, radiant light, detailed and intricate environment "
"brycedrennan man: a highly detailed uncropped full-color epic corporate portrait headshot photograph. best portfolio photoraphy photo winner, meticulous detail, hyperrealistic, centered uncropped symmetrical beautiful masculine facial features, atmospheric, photorealistic texture, canon 5D mark III photo, professional studio lighting, aesthetic, very inspirational, motivational. ByKaren L Richard Photography, Photoweb, Splento, Americanoize, Lemonlight",
]

View File

@ -9,8 +9,10 @@ from PIL import Image
from imaginairy.utils import get_device
def pillow_fit_image_within(image: PIL.Image.Image, max_height=512, max_width=512):
image = image.convert("RGB")
def pillow_fit_image_within(
image: PIL.Image.Image, max_height=512, max_width=512, convert="RGB"
):
image = image.convert(convert)
w, h = image.size
resize_ratio = 1
if w > max_width or h > max_height:

133
imaginairy/lr_scheduler.py Normal file
View File

@ -0,0 +1,133 @@
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0.
"""
def __init__(
self,
warm_up_steps,
lr_min,
lr_max,
lr_start,
max_decay_steps,
verbosity_interval=0,
):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.0
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (
self.lr_max - self.lr_start
) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
t = (n - self.lr_warm_up_steps) / (
self.lr_max_decay_steps - self.lr_warm_up_steps
)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
):
assert (
len(warm_up_steps)
== len(f_min)
== len(f_max)
== len(f_start)
== len(cycle_lengths)
)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.0
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
return None
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
t = (n - self.lr_warm_up_steps[cycle]) / (
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
)
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi)
)
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0:
print(
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}"
)
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
cycle
] * n + self.f_start[cycle]
self.last_f = f
return f
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
self.cycle_lengths[cycle] - n
) / (self.cycle_lengths[cycle])
self.last_f = f
return f

View File

@ -1,12 +1,15 @@
import gc
import inspect
import logging
import os
import sys
import urllib.parse
from functools import wraps
import requests
import torch
from huggingface_hub import hf_hub_download, try_to_load_from_cache
from huggingface_hub import hf_hub_download as _hf_hub_download
from huggingface_hub import try_to_load_from_cache
from omegaconf import OmegaConf
from transformers.utils.hub import HfFolder
@ -30,11 +33,12 @@ class HuggingFaceAuthorizationError(RuntimeError):
class MemoryAwareModel:
"""Wraps a model to allow dynamic loading/unloading as needed."""
def __init__(self, config_path, weights_path, half_mode=None):
def __init__(self, config_path, weights_path, half_mode=None, for_training=False):
self._config_path = config_path
self._weights_path = weights_path
self._half_mode = half_mode
self._model = None
self._for_training = for_training
LOADED_MODELS[(self._config_path, self._weights_path)] = self
@ -47,9 +51,13 @@ class MemoryAwareModel:
# unload all models in LOADED_MODELS
for model in LOADED_MODELS.values():
model.unload_model()
model_config = OmegaConf.load(f"{PKG_ROOT}/{self._config_path}")
if self._for_training:
model_config.use_ema = True
# model_config.use_scheduler = True
model = load_model_from_config(
config=OmegaConf.load(f"{PKG_ROOT}/{self._config_path}"),
config=model_config,
weights_location=self._weights_path,
)
@ -75,7 +83,7 @@ class MemoryAwareModel:
def load_model_from_config(config, weights_location):
if weights_location.startswith("http"):
ckpt_path = get_cached_url_path(weights_location)
ckpt_path = get_cached_url_path(weights_location, category="weights")
else:
ckpt_path = weights_location
logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
@ -94,7 +102,7 @@ def load_model_from_config(config, weights_location):
if weights_location.startswith("http"):
logger.warning("Corrupt checkpoint. deleting and re-downloading...")
os.remove(ckpt_path)
ckpt_path = get_cached_url_path(weights_location)
ckpt_path = get_cached_url_path(weights_location, category="weights")
pl_sd = torch.load(ckpt_path, map_location="cpu")
if pl_sd is None:
raise e
@ -119,6 +127,7 @@ def get_diffusion_model(
config_path="configs/stable-diffusion-v1.yaml",
half_mode=None,
for_inpainting=False,
for_training=False,
):
"""
Load a diffusion model.
@ -127,7 +136,11 @@ def get_diffusion_model(
"""
try:
return _get_diffusion_model(
weights_location, config_path, half_mode, for_inpainting
weights_location,
config_path,
half_mode,
for_inpainting,
for_training=for_training,
)
except HuggingFaceAuthorizationError as e:
if for_inpainting:
@ -135,7 +148,11 @@ def get_diffusion_model(
f"Failed to load inpainting model. Attempting to fall-back to standard model. {str(e)}"
)
return _get_diffusion_model(
iconfig.DEFAULT_MODEL, config_path, half_mode, for_inpainting=False
iconfig.DEFAULT_MODEL,
config_path,
half_mode,
for_inpainting=False,
for_training=for_training,
)
raise e
@ -145,6 +162,7 @@ def _get_diffusion_model(
config_path="configs/stable-diffusion-v1.yaml",
half_mode=None,
for_inpainting=False,
for_training=False,
):
"""
Load a diffusion model.
@ -152,25 +170,12 @@ def _get_diffusion_model(
Weights location may also be shortcut name, e.g. "SD-1.5"
"""
global MOST_RECENTLY_LOADED_MODEL # noqa
model_config = None
if weights_location is None:
weights_location = iconfig.DEFAULT_MODEL
if (
for_inpainting
and f"{weights_location}-inpaint" in iconfig.MODEL_CONFIG_SHORTCUTS
):
model_config = iconfig.MODEL_CONFIG_SHORTCUTS[f"{weights_location}-inpaint"]
config_path, weights_location = (
model_config.config_path,
model_config.weights_url,
)
elif weights_location in iconfig.MODEL_CONFIG_SHORTCUTS:
model_config = iconfig.MODEL_CONFIG_SHORTCUTS[weights_location]
config_path, weights_location = (
model_config.config_path,
model_config.weights_url,
)
model_config, weights_location, config_path = resolve_model_paths(
weights_path=weights_location,
config_path=config_path,
for_inpainting=for_inpainting,
for_training=for_training,
)
# some models need the attention calculated in float32
if model_config is not None:
attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision
@ -180,7 +185,10 @@ def _get_diffusion_model(
key = (config_path, weights_location)
if key not in LOADED_MODELS:
MemoryAwareModel(
config_path=config_path, weights_path=weights_location, half_mode=half_mode
config_path=config_path,
weights_path=weights_location,
half_mode=half_mode,
for_training=for_training,
)
model = LOADED_MODELS[key]
@ -190,6 +198,41 @@ def _get_diffusion_model(
return model
def resolve_model_paths(
weights_path=iconfig.DEFAULT_MODEL,
config_path=None,
for_inpainting=False,
for_training=False,
):
"""Resolve weight and config path if they happen to be shortcuts."""
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_path, None)
model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(config_path, None)
if for_inpainting:
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(
f"{weights_path}-inpaint", model_metadata_w
)
model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(
f"{config_path}-inpaint", model_metadata_c
)
if model_metadata_w:
if config_path is None:
config_path = model_metadata_w.config_path
if for_training:
weights_path = model_metadata_w.weights_url_full
if weights_path is None:
raise ValueError(
"No full training weights configured for this model. Edit the code or subimt a github issue."
)
else:
weights_path = model_metadata_w.weights_url
if model_metadata_c:
config_path = model_metadata_c.config_path
model_metadata = model_metadata_w or model_metadata_c
return model_metadata, weights_path, config_path
def get_model_default_image_size(weights_location):
model_config = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_location, None)
if model_config:
@ -209,12 +252,12 @@ def get_cache_dir():
xdg_cache_home = os.path.join(user_home, ".cache")
if xdg_cache_home is not None:
return os.path.join(xdg_cache_home, "imaginairy", "weights")
return os.path.join(xdg_cache_home, "imaginairy")
return os.path.join(os.path.dirname(__file__), ".cached-downloads")
return os.path.join(os.path.dirname(__file__), ".cached-aimg")
def get_cached_url_path(url):
def get_cached_url_path(url, category=None):
"""
Gets the contents of a url, but caches the response indefinitely.
@ -231,6 +274,8 @@ def get_cached_url_path(url):
pass
filename = url.split("/")[-1]
dest = get_cache_dir()
if category:
dest = os.path.join(dest, category)
os.makedirs(dest, exist_ok=True)
dest_path = os.path.join(dest, filename)
if os.path.exists(dest_path):
@ -260,6 +305,20 @@ def check_huggingface_url_authorized(url):
return None
@wraps(_hf_hub_download)
def hf_hub_download(*args, **kwargs):
"""
backwards compatible wrapper for huggingface's hf_hub_download.
they changed ther argument name from `use_auth_token` to `token`
"""
arg_names = inspect.getfullargspec(_hf_hub_download)
if "use_auth_token" in arg_names.args and "token" in kwargs:
kwargs["use_auth_token"] = kwargs.pop("token")
return _hf_hub_download(*args, **kwargs)
def huggingface_cached_path(url):
# bypass all the HEAD calls done by the default `cached_path`
repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url)

View File

@ -18,6 +18,7 @@ except ImportError:
XFORMERS_IS_AVAILBLE = False
ALLOW_SPLITMEM = True
ATTENTION_PRECISION_OVERRIDE = "default"
@ -174,10 +175,11 @@ class CrossAttention(nn.Module):
# mask = _global_mask_hack.to(torch.bool)
if get_device() == "cuda" or "mps" in get_device():
if not XFORMERS_IS_AVAILBLE:
if not XFORMERS_IS_AVAILBLE and ALLOW_SPLITMEM:
return self.forward_splitmem(x, context=context, mask=mask)
h = self.heads
# print(x.shape)
q = self.to_q(x)
context = context if context is not None else x
@ -193,7 +195,8 @@ class CrossAttention(nn.Module):
sim = einsum("b i d, b j d -> b i j", q, k)
else:
sim = einsum("b i d, b j d -> b i j", q, k)
# print(sim.shape)
# print("*" * 100)
del q, k
# if mask is not None:
# if sim.shape[2] == 320 and False:

View File

@ -14,12 +14,17 @@ import numpy as np
import pytorch_lightning as pl
import torch
from einops import rearrange, repeat
from omegaconf import ListConfig
from PIL import Image, ImageDraw, ImageFont
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from torch.optim.lr_scheduler import LambdaLR
from torchvision.utils import make_grid
from tqdm import tqdm
from imaginairy.modules.attention import CrossAttention
from imaginairy.modules.autoencoder import AutoencoderKL, IdentityFirstStage
from imaginairy.modules.diffusion.util import (
extract_into_tensor,
make_beta_schedule,
@ -27,12 +32,39 @@ from imaginairy.modules.diffusion.util import (
)
from imaginairy.modules.distributions import DiagonalGaussianDistribution
from imaginairy.modules.ema import LitEma
from imaginairy.samplers.kdiff import DPMPP2MSampler
from imaginairy.utils import instantiate_from_config
logger = logging.getLogger(__name__)
__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = []
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
nc = int(40 * (wh[0] / 256))
lines = "\n".join(
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
)
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def disabled_train(self):
"""
Overwrite model.train with this function to make sure train/eval mode
@ -605,7 +637,9 @@ class DDPM(pl.LightningModule):
return denoise_grid
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
def log_images(
self, batch, N=8, n_row=2, *, sample=True, return_keys=None, **kwargs
):
log = {}
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
@ -678,6 +712,7 @@ class LatentDiffusion(DDPM):
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
unet_trainable=True,
**kwargs,
):
self.num_timesteps_cond = (
@ -695,6 +730,7 @@ class LatentDiffusion(DDPM):
super().__init__(conditioning_key=conditioning_key, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.unet_trainable = unet_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
@ -726,6 +762,7 @@ class LatentDiffusion(DDPM):
m._conv_forward = _TileModeConv2DConvForward.__get__( # noqa
m, nn.Conv2d
)
self.tile_mode(tile_mode=False)
def tile_mode(self, tile_mode):
"""For creating seamless tiles."""
@ -1005,7 +1042,7 @@ class LatentDiffusion(DDPM):
if cond_key is None:
cond_key = self.cond_stage_key
if cond_key != self.first_stage_key:
if cond_key in ["caption", "coordinates_bbox"]:
if cond_key in ["caption", "coordinates_bbox", "txt"]:
xc = batch[cond_key]
elif cond_key == "class_label":
xc = batch
@ -1100,6 +1137,24 @@ class LatentDiffusion(DDPM):
return self.first_stage_model.encode(x)
def shared_step(self, batch, **kwargs):
x, c = self.get_input(batch, self.first_stage_key)
loss = self(x, c)
return loss
def forward(self, x, c, *args, **kwargs):
t = torch.randint(
0, self.num_timesteps, (x.shape[0],), device=self.device
).long()
if self.model.conditioning_key is not None:
assert c is not None
if self.cond_stage_trainable:
c = self.get_learned_conditioning(c)
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
@ -1244,6 +1299,42 @@ class LatentDiffusion(DDPM):
return x_recon
def p_losses(self, x_start, cond, t, noise=None): # noqa
noise = noise if noise is not None else torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)
loss_dict = {}
prefix = "train" if self.training else "val"
if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
target = noise
else:
raise NotImplementedError()
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()})
# t sometimes on wrong device. not sure why
logvar_t = self.logvar[t.to(self.logvar.device)].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
if self.learn_logvar:
loss_dict.update({f"{prefix}/loss_gamma": loss.mean()})
loss_dict.update({"logvar": self.logvar.data.mean()})
loss = self.l_simple_weight * loss.mean()
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss_dict.update({f"{prefix}/loss_vlb": loss_vlb})
loss += self.original_elbo_weight * loss_vlb
loss_dict.update({f"{prefix}/loss": loss})
return loss, loss_dict
def p_mean_variance(
self,
x,
@ -1346,6 +1437,287 @@ class LatentDiffusion(DDPM):
* noise
)
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
sampler = DPMPP2MSampler(self)
shape = (batch_size, self.channels, self.image_size, self.image_size)
uncond = kwargs.get("unconditional_conditioning")
if uncond is None:
uncond = self.get_unconditional_conditioning(batch_size, "")
positive_conditioning = {
"c_concat": [],
"c_crossattn": [cond],
}
neutral_conditioning = {
"c_concat": [],
"c_crossattn": [uncond],
}
samples = sampler.sample(
num_steps=ddim_steps,
positive_conditioning=positive_conditioning,
neutral_conditioning=neutral_conditioning,
guidance_scale=kwargs.get("unconditional_guidance_scale", 5.0),
shape=shape,
batch_size=1,
)
return samples, []
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, (dict, list)):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
# todo: get null label from cond_stage_model
raise NotImplementedError()
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
@torch.no_grad()
def log_images(
self,
batch,
N=8,
n_row=4,
sample=True,
ddim_steps=50,
ddim_eta=1.0,
return_keys=None,
quantize_denoised=True,
inpaint=True,
plot_denoise_rows=False,
plot_progressive_rows=True,
plot_diffusion_rows=True,
unconditional_guidance_scale=1.0,
unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs,
):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = {}
z, c, x, xrec, xc = self.get_input(
batch,
self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=N,
)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption", "txt"]:
xc = log_txt_as_img(
(x.shape[2], x.shape[3]),
batch[self.cond_stage_key],
size=x.shape[2] // 25,
)
log["conditioning"] = xc
elif self.cond_stage_key == "class_label":
# xc = log_txt_as_img(
# (x.shape[2], x.shape[3]),
# batch["human_label"],
# size=x.shape[2] // 25,
# )
log["conditioning"] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with ema_scope("Sampling"):
samples, z_denoise_row = self.sample_log(
cond=c,
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if (
quantize_denoised
and not isinstance(self.first_stage_model, AutoencoderKL)
and not isinstance(self.first_stage_model, IdentityFirstStage)
):
# also display when quantizing x0 while sampling
with ema_scope("Plotting Quantized Denoised"):
samples, z_denoise_row = self.sample_log(
cond=c,
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
quantize_denoised=True,
)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
# quantize_denoised=True)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_x0_quantized"] = x_samples
if unconditional_guidance_scale > 1.0:
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
# uc = torch.zeros_like(c)
with ema_scope("Sampling with classifier-free guidance"):
samples_cfg, _ = self.sample_log(
cond=c,
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[
f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
] = x_samples_cfg
if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
mask = mask[:, None, ...]
with ema_scope("Plotting Inpaint"):
samples, _ = self.sample_log(
cond=c,
batch_size=N,
ddim=use_ddim,
eta=ddim_eta,
ddim_steps=ddim_steps,
x0=z[:N],
mask=mask,
)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_inpainting"] = x_samples
log["mask"] = mask
# outpaint
mask = 1.0 - mask
with ema_scope("Plotting Outpaint"):
samples, _ = self.sample_log(
cond=c,
batch_size=N,
ddim=use_ddim,
eta=ddim_eta,
ddim_steps=ddim_steps,
x0=z[:N],
mask=mask,
)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples
if plot_progressive_rows:
with ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(
c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N,
)
prog_row = self._get_denoise_row_from_list(
progressives, desc="Progressive Generation"
)
log["progressive_row"] = prog_row
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = []
if self.unet_trainable == "attn":
logger.info("Training only unet attention layers")
for n, m in self.model.named_modules():
if isinstance(m, CrossAttention) and n.endswith("attn2"):
params.extend(m.parameters())
elif self.unet_trainable is True or self.unet_trainable == "all":
logger.info("Training the full unet")
params = list(self.model.parameters())
else:
raise ValueError(
f"Unrecognised setting for unet_trainable: {self.unet_trainable}"
)
if self.cond_stage_trainable:
logger.info(
f"{self.__class__.__name__}: Also optimizing conditioner params!"
)
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
logger.info("Diffusion model optimizing logvar")
params.append(self.logvar)
opt = torch.optim.AdamW(params, lr=lr)
if self.use_scheduler:
assert "target" in self.scheduler_config
scheduler = instantiate_from_config(self.scheduler_config)
logger.info("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
"interval": "step",
"frequency": 1,
}
]
return [opt], scheduler
return opt
@torch.no_grad()
def to_rgb(self, x):
x = x.float()
if not hasattr(self, "colorize"):
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) # noqa
x = nn.functional.conv2d(x, weight=self.colorize)
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
return x
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):

View File

@ -12,7 +12,7 @@ class BaseModel(torch.nn.Module): # noqa
Args:
path (str): file path
"""
ckpt_path = get_cached_url_path(config.midas_url)
ckpt_path = get_cached_url_path(config.midas_url, category="weights")
parameters = torch.load(ckpt_path, map_location=torch.device("cpu"))
if "optimizer" in parameters:

103
imaginairy/roi_utils.py Normal file
View File

@ -0,0 +1,103 @@
import logging
logger = logging.getLogger(__name__)
def square_roi_coordinate(roi, max_width, max_height, best_effort=False):
"""Given a region of interest, returns a square region of interest."""
x1, y1, x2, y2 = roi
x1, y1, x2, y2 = int(round(x1)), int(round(y1)), int(round(x2)), int(round(y2))
roi_width = x2 - x1
roi_height = y2 - y1
if roi_width < roi_height:
diff = roi_height - roi_width
x1 -= int(round(diff / 2))
x2 += roi_height - (x2 - x1)
elif roi_height < roi_width:
diff = roi_width - roi_height
y1 -= int(round(diff / 2))
y2 += roi_width - (y2 - y1)
x1, y1, x2, y2 = move_roi_into_bounds(
(x1, y1, x2, y2), max_width, max_height, force=best_effort
)
width = x2 - x1
height = y2 - y1
if not best_effort and width != height:
raise RuntimeError(f"ROI is not square: {width}x{height}")
return x1, y1, x2, y2
def resize_roi_coordinates(
roi, expansion_factor, max_width, max_height, expand_up=True
):
"""
Resize a region of interest while staying within the bounds.
setting expand_up to False will prevent the ROI from expanding upwards, which is useful when
expanding something like a face roi to capture more of the person instead of empty space above them.
"""
x1, y1, x2, y2 = roi
side_length_x = x2 - x1
side_length_y = y2 - y1
max_expansion_factor = min(max_height / side_length_y, max_width / side_length_x)
expansion_factor = min(expansion_factor, max_expansion_factor)
expansion_x = int(round(side_length_x * expansion_factor - side_length_x))
expansion_x_a = int(round(expansion_x / 2))
expansion_x_b = expansion_x - expansion_x_a
x1 -= expansion_x_a
x2 += expansion_x_b
expansion_y = int(round(side_length_y * expansion_factor - side_length_y))
if expand_up:
expansion_y_a = int(round(expansion_y / 2))
expansion_y_b = expansion_y - expansion_y_a
y1 -= expansion_y_a
y2 += expansion_y_b
else:
y2 += expansion_y
x1, y1, x2, y2 = move_roi_into_bounds((x1, y1, x2, y2), max_width, max_height)
return x1, y1, x2, y2
class RoiNotInBoundsError(ValueError):
"""Error raised when a ROI is not within the bounds of the image."""
def move_roi_into_bounds(roi, max_width, max_height, force=False):
"""Move a region of interest into the bounds of the image."""
x1, y1, x2, y2 = roi
# move the ROI within the image boundaries
if x1 < 0:
x2 -= x1
x1 = 0
if y1 < 0:
y2 -= y1
y1 = 0
if x2 > max_width:
x1 -= x2 - max_width
x2 = max_width
if y2 > max_height:
y1 -= y2 - max_height
y2 = max_height
x1, y1, x2, y2 = int(round(x1)), int(round(y1)), int(round(x2)), int(round(y2))
# Force ROI to fit within image boundaries (sacrificing size and aspect ratio of ROI)
if force:
x1 = max(0, x1)
y1 = max(0, y1)
x2 = min(max_width, x2)
y2 = min(max_height, y2)
if x1 < 0 or y1 < 0 or x2 > max_width or y2 > max_height:
roi_width = x2 - x1
roi_height = y2 - y1
raise RoiNotInBoundsError(
f"Not possible to fit ROI into boundaries: {roi_width}x{roi_height} won't fit inside {max_width}x{max_height}"
)
return x1, y1, x2, y2

View File

@ -47,6 +47,13 @@ class DDIMSampler(ImageSampler):
t_start=None,
quantize_x0=False,
):
# print("Sampling with DDIM")
# print("num_steps", num_steps)
# print("shape", shape)
# print("neutral_conditioning", neutral_conditioning)
# print("positive_conditioning", positive_conditioning)
# print("guidance_scale", guidance_scale)
# print("batch_size", batch_size)
schedule = NoiseSchedule(
model_num_timesteps=self.model.num_timesteps,
model_alphas_cumprod=self.model.alphas_cumprod,

View File

@ -56,14 +56,14 @@ class LazyLoadingImage:
if self._lazy_filepath:
self._img = Image.open(self._lazy_filepath)
logger.info(
logger.debug(
f"Loaded input 🖼 of size {self._img.size} from {self._lazy_filepath}"
)
elif self._lazy_url:
self._img = Image.open(
requests.get(self._lazy_url, stream=True, timeout=60).raw
)
logger.info(
logger.debug(
f"Loaded input 🖼 of size {self._img.size} from {self._lazy_url}"
)
# fix orientation
@ -113,6 +113,7 @@ class ImaginePrompt:
conditioning=None,
tile_mode="",
model=config.DEFAULT_MODEL,
model_config_path=None,
):
self.prompts = self.process_prompt_input(prompt)
@ -165,6 +166,7 @@ class ImaginePrompt:
if self.model == "SD-2.0-v" and self.sampler_type == SamplerName.PLMS:
raise ValueError("PLMS sampler is not supported for SD-2.0-v model.")
self.model_config_path = model_config_path
@property
def prompt_text(self):

520
imaginairy/train.py Normal file
View File

@ -0,0 +1,520 @@
import datetime
import logging
import os
import signal
import time
from functools import partial
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.utils.data import DataLoader, Dataset
from imaginairy import config
from imaginairy.model_manager import get_diffusion_model
from imaginairy.training_tools.single_concept import SingleConceptDataset
from imaginairy.utils import get_device, instantiate_from_config
mod_logger = logging.getLogger(__name__)
referenced_by_string = [LearningRateMonitor]
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset."""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
if isinstance(dataset, SingleConceptDataset):
# split_size = dataset.num_records // worker_info.num_workers
# reset num_records to the true number to retain reliable length information
# dataset.sample_ids = dataset.valid_ids[
# worker_id * split_size : (worker_id + 1) * split_size
# ]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
return np.random.seed(np.random.get_state()[1][0] + worker_id)
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(
self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=False,
num_val_workers=0,
):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = {}
self.num_workers = num_workers if num_workers is not None else batch_size * 2
if num_val_workers is None:
self.num_val_workers = self.num_workers
else:
self.num_val_workers = num_val_workers
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(
self._val_dataloader, shuffle=shuffle_val_dataloader
)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(
self._test_dataloader, shuffle=shuffle_test_loader
)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.wrap = wrap
self.datasets = None
def prepare_data(self):
for data_cfg in self.dataset_configs.values():
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = {
k: instantiate_from_config(c) for k, c in self.dataset_configs.items()
}
if self.wrap:
self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()}
def _train_dataloader(self):
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
if is_iterable_dataset or self.use_worker_init_fn:
pass
else:
pass
return DataLoader(
self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=True,
worker_init_fn=worker_init_fn,
)
def _val_dataloader(self, shuffle=False):
if (
isinstance(self.datasets["validation"], SingleConceptDataset)
or self.use_worker_init_fn
):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_val_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
is_iterable_dataset = False
# do not shuffle dataloader for iterable dataset
shuffle = shuffle and (not is_iterable_dataset)
return DataLoader(
self.datasets["test"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
shuffle=shuffle,
)
def _predict_dataloader(self, shuffle=False):
if (
isinstance(self.datasets["predict"], SingleConceptDataset)
or self.use_worker_init_fn
):
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
)
class SetupCallback(Callback):
def __init__(
self,
resume,
now,
logdir,
ckptdir,
cfgdir,
):
super().__init__()
self.resume = resume
self.now = now
self.logdir = logdir
self.ckptdir = ckptdir
self.cfgdir = cfgdir
def on_keyboard_interrupt(self, trainer, pl_module):
if trainer.global_rank == 0:
mod_logger.info("Stopping execution and saving final checkpoint.")
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def on_fit_start(self, trainer, pl_module):
if trainer.global_rank == 0:
# Create logdirs and save configs
os.makedirs(self.logdir, exist_ok=True)
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
else:
# ModelCheckpoint callback created log directory --- remove it
if not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
try:
os.rename(self.logdir, dst)
except FileNotFoundError:
pass
class ImageLogger(Callback):
def __init__(
self,
batch_frequency,
max_images,
clamp=True,
increase_log_steps=True,
rescale=True,
disabled=False,
log_on_batch_idx=False,
log_first_step=False,
log_images_kwargs=None,
log_all_val=False,
concept_label=None,
):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {}
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
self.disabled = disabled
self.log_on_batch_idx = log_on_batch_idx
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.log_first_step = log_first_step
self.log_all_val = log_all_val
self.concept_label = concept_label
@rank_zero_only
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "logs", "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = (
f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png"
)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split="train"):
# always generate the concept label
batch["txt"][0] = self.concept_label
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if self.log_all_val and split == "val":
should_log = True
else:
should_log = self.check_frequency(check_idx)
if (
should_log
and (batch_idx % self.batch_freq == 0)
and hasattr(pl_module, "log_images")
and callable(pl_module.log_images)
and self.max_images > 0
):
logger = type(pl_module.logger)
is_train = pl_module.training
if is_train:
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(
batch, split=split, **self.log_images_kwargs
)
for k in images:
N = min(images[k].shape[0], self.max_images)
images[k] = images[k][:N]
if isinstance(images[k], torch.Tensor):
images[k] = images[k].detach().cpu()
if self.clamp:
images[k] = torch.clamp(images[k], -1.0, 1.0)
self.log_local(
pl_module.logger.save_dir,
split,
images,
pl_module.global_step,
pl_module.current_epoch,
batch_idx,
)
logger_log_images = self.logger_log_images.get(
logger, lambda *args, **kwargs: None
)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if (check_idx % self.batch_freq) == 0 and (
check_idx > 0 or self.log_first_step
):
return True
return False
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
self.log_img(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if hasattr(pl_module, "calibrate_grad_norm"):
if (
pl_module.calibrate_grad_norm and batch_idx % 25 == 0
) and batch_idx > 0:
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
def on_train_epoch_start(self, trainer, pl_module):
# Reset the memory use counter
if "cuda" in get_device():
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
torch.cuda.synchronize(trainer.strategy.root_device.index)
self.start_time = time.time() # noqa
def on_train_epoch_end(self, trainer, pl_module): # noqa
if "cuda" in get_device():
torch.cuda.synchronize(trainer.strategy.root_device.index)
max_memory = (
torch.cuda.max_memory_allocated(trainer.strategy.root_device.index)
/ 2**20
)
epoch_time = time.time() - self.start_time
try:
max_memory = trainer.training_type_plugin.reduce(max_memory)
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass
def train_diffusion_model(
concept_label,
concept_images_dir,
class_label,
class_images_dir,
weights_location=config.DEFAULT_MODEL,
logdir="logs",
learning_rate=1e-6,
accumulate_grad_batches=32,
resume=None,
):
batch_size = 1
seed = 23
num_workers = 1
num_val_workers = 0
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
logdir = os.path.join(logdir, now)
ckpt_output_dir = os.path.join(logdir, "checkpoints")
cfg_output_dir = os.path.join(logdir, "configs")
seed_everything(seed)
model = get_diffusion_model( # noqa
weights_location=weights_location, half_mode=False, for_training=True
)._model
model.learning_rate = learning_rate * accumulate_grad_batches * batch_size
# add callback which sets up log directory
default_callbacks_cfg = {
"setup_callback": {
"target": "imaginairy.train.SetupCallback",
"params": {
"resume": False,
"now": now,
"logdir": logdir,
"ckptdir": ckpt_output_dir,
"cfgdir": cfg_output_dir,
},
},
"image_logger": {
"target": "imaginairy.train.ImageLogger",
"params": {
"batch_frequency": 10,
"max_images": 1,
"clamp": True,
"increase_log_steps": False,
"log_first_step": True,
"log_all_val": True,
"concept_label": concept_label,
"log_images_kwargs": {
"use_ema_scope": True,
"inpaint": False,
"plot_progressive_rows": False,
"plot_diffusion_rows": False,
"N": 1,
"unconditional_guidance_scale:": 7.5,
"unconditional_guidance_label": [""],
"ddim_steps": 20,
},
},
},
"learning_rate_logger": {
"target": "imaginairy.train.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
},
},
"cuda_callback": {"target": "imaginairy.train.CUDACallback"},
}
default_modelckpt_cfg = {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckpt_output_dir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
"every_n_train_steps": 50,
"save_top_k": -1,
"monitor": None,
},
}
modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg)
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
callbacks_cfg = OmegaConf.create(default_callbacks_cfg)
dataset_config = {
"concept_label": concept_label,
"concept_images_dir": concept_images_dir,
"class_label": class_label,
"class_images_dir": class_images_dir,
"image_transforms": [
{
"target": "torchvision.transforms.Resize",
"params": {"size": 512, "interpolation": 3},
},
{"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
],
}
data_module_config = {
"batch_size": batch_size,
"num_workers": num_workers,
"num_val_workers": num_val_workers,
"train": {
"target": "imaginairy.training_tools.single_concept.SingleConceptDataset",
"params": dataset_config,
},
}
trainer = Trainer(
benchmark=True,
num_sanity_val_steps=0,
accumulate_grad_batches=accumulate_grad_batches,
strategy=DDPStrategy(),
callbacks=[
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg # noqa
],
gpus=1,
default_root_dir=".",
)
trainer.logdir = logdir
data = DataModuleFromConfig(**data_module_config)
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
# calling these ourselves should not be necessary but it is.
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
def melk(*args, **kwargs):
if trainer.global_rank == 0:
mod_logger.info("Summoning checkpoint.")
ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
signal.signal(signal.SIGUSR1, melk)
try:
try:
trainer.fit(model, data)
except Exception:
melk()
raise
finally:
mod_logger.info(trainer.profiler.summary())

View File

View File

@ -0,0 +1,155 @@
import logging
import os
import os.path
import re
from PIL import Image
from tqdm import tqdm
from imaginairy import ImaginePrompt, LazyLoadingImage, imagine
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.facecrop import detect_faces, generate_face_crops
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.vendored.smart_crop import SmartCrop
logger = logging.getLogger(__name__)
def get_image_filenames(folder):
filenames = []
for filename in os.listdir(folder):
if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
continue
if filename.startswith("."):
continue
filenames.append(filename)
return filenames
def prep_images(
images_dir, is_person=False, output_folder_name="prepped-images", target_size=512
):
"""
Crops and resizes a directory of images in preparation for training.
If is_person=True, it will detect the face and produces several crops at different zoom levels. For crops that
are too small, it will use the face restoration model to enhance the faces.
For non-person images, it will use the smartcrop algorithm to crop the image to the most interesting part. If the
input image is too small it will be upscaled.
Prep will go a lot faster if all the images are big enough to not require upscaling.
"""
output_folder = os.path.join(images_dir, output_folder_name)
os.makedirs(output_folder, exist_ok=True)
logger.info(f"Prepping images in {images_dir} to {output_folder}")
image_filenames = get_image_filenames(images_dir)
pbar = tqdm(image_filenames)
for filename in pbar:
pbar.set_description(filename)
input_path = os.path.join(images_dir, filename)
img = LazyLoadingImage(filepath=input_path).convert("RGB")
if is_person:
face_rois = detect_faces(img)
if len(face_rois) == 0:
logger.info(f"No faces detected in image {filename}, skipping")
continue
if len(face_rois) > 1:
logger.info(f"Multiple faces detected in image {filename}, skipping")
continue
face_roi = face_rois[0]
face_roi_crops = generate_face_crops(
face_roi, max_width=img.width, max_height=img.height
)
for n, face_roi_crop in enumerate(face_roi_crops):
cropped_output_path = os.path.join(
output_folder, f"{filename}_[alt-{n:02d}].jpg"
)
if os.path.exists(cropped_output_path):
logger.debug(
f"Skipping {cropped_output_path} because it already exists"
)
continue
x1, y1, x2, y2 = face_roi_crop
crop_width = x2 - x1
crop_height = y2 - y1
if crop_width != crop_height:
logger.info(
f"Face ROI crop for {filename} {crop_width}x{crop_height} is not square, skipping"
)
continue
cropped_img = img.crop(face_roi_crop)
if crop_width < target_size:
logger.info(f"Upscaling {filename} {face_roi_crop}")
cropped_img = cropped_img.resize(
(target_size, target_size), resample=Image.Resampling.LANCZOS
)
cropped_img = enhance_faces(cropped_img, fidelity=1)
else:
cropped_img = cropped_img.resize(
(target_size, target_size), resample=Image.Resampling.LANCZOS
)
cropped_img.save(cropped_output_path, quality=95)
else:
# scale image so that largest dimension is target_size
n = 0
cropped_output_path = os.path.join(output_folder, f"{filename}_{n}.jpg")
if os.path.exists(cropped_output_path):
logger.debug(
f"Skipping {cropped_output_path} because it already exists"
)
continue
if img.width < target_size or img.height < target_size:
# upscale the image if it's too small
logger.info(f"Upscaling {filename}")
img = upscale_image(img)
if img.width > img.height:
scale_factor = target_size / img.height
else:
scale_factor = target_size / img.width
# downscale so shortest side is target_size
new_width = int(round(img.width * scale_factor))
new_height = int(round(img.height * scale_factor))
img = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
result = SmartCrop().crop(img, width=target_size, height=target_size)
box = (
result["top_crop"]["x"],
result["top_crop"]["y"],
result["top_crop"]["width"] + result["top_crop"]["x"],
result["top_crop"]["height"] + result["top_crop"]["y"],
)
cropped_image = img.crop(box)
cropped_image.save(cropped_output_path, quality=95)
logger.info(f"Image Prep complete. Review output at {output_folder}")
def prompt_normalized(prompt):
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "-", prompt)[:130]
def create_class_images(class_description, output_folder, num_images=200):
"""
Generate images of class_description.
"""
existing_images = get_image_filenames(output_folder)
existing_image_count = len(existing_images)
class_slug = prompt_normalized(class_description)
while existing_image_count < num_images:
prompt = ImaginePrompt(class_description, steps=20)
result = list(imagine([prompt]))[0]
if result.is_nsfw:
continue
dest = os.path.join(
output_folder, f"{existing_image_count:03d}_{class_slug}.jpg"
)
result.save(dest)
existing_image_count += 1

View File

@ -0,0 +1,31 @@
import logging
import os
import torch
logger = logging.getLogger(__name__)
def prune_diffusion_ckpt(ckpt_path, dst_path=None):
if dst_path is None:
dst_path = f"{os.path.splitext(ckpt_path)[0]}-pruned.ckpt"
data = torch.load(ckpt_path, map_location="cpu")
new_data = prune_model_data(data)
torch.save(new_data, dst_path)
size_initial = os.path.getsize(ckpt_path)
newsize = os.path.getsize(dst_path)
msg = (
f"New ckpt size: {newsize * 1e-9:.2f} GB. "
f"Saved {(size_initial - newsize) * 1e-9:.2f} GB by removing optimizer states"
)
logger.info(msg)
def prune_model_data(data):
skip_keys = {"optimizer_states"}
new_data = {k: v for k, v in data.items() if k not in skip_keys}
return new_data

View File

@ -0,0 +1,135 @@
import os
import random
import re
from abc import abstractmethod
from collections import defaultdict
from einops import rearrange
from omegaconf import ListConfig
from PIL import Image
from torch.utils.data import Dataset, IterableDataset
from torchvision.transforms import transforms
from imaginairy.utils import instantiate_from_config
class Txt2ImgIterableBaseDataset(IterableDataset):
"""
Define an interface to make the IterableDatasets for text2img data chainable.
"""
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
self.valid_ids = valid_ids
self.sample_ids = valid_ids
self.size = size
print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass
def _rearrange(x):
return rearrange(x * 2.0 - 1.0, "c h w -> h w c")
class SingleConceptDataset(Dataset):
"""
Dataset for finetuning a model on a single concept.
Similar to "dreambooth"
"""
def __init__(
self,
concept_label,
class_label,
concept_images_dir,
class_images_dir,
image_transforms=None,
):
self.concept_label = concept_label
self.class_label = class_label
self.concept_images_dir = concept_images_dir
self.class_images_dir = class_images_dir
if isinstance(image_transforms, (ListConfig, list)):
image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
image_transforms.extend(
[
transforms.ToTensor(),
transforms.Lambda(_rearrange),
]
)
image_transforms = transforms.Compose(image_transforms)
self.image_transforms = image_transforms
self._concept_image_filename_groups = None
self._class_image_filenames = None
def __len__(self):
return len(self.concept_image_filename_groups) * 2
def __getitem__(self, idx):
if idx % 2:
img_group = self._concept_image_filename_groups[int(idx / 2)]
img_filename = random.choice(img_group)
img_path = os.path.join(self.concept_images_dir, img_filename)
txt = self.concept_label
else:
img_path = os.path.join(
self.class_images_dir, self.class_image_filenames[int(idx / 2)]
)
txt = self.class_label
try:
image = Image.open(img_path).convert("RGB")
except RuntimeError as e:
raise RuntimeError(f"Could not read image {img_path}") from e
image = self.image_transforms(image)
data = {"image": image, "txt": txt}
return data
@property
def concept_image_filename_groups(self):
if self._concept_image_filename_groups is None:
self._concept_image_filename_groups = _load_image_filenames_and_alts(
self.concept_images_dir
)
return self._concept_image_filename_groups
@property
def class_image_filenames(self):
if self._class_image_filenames is None:
self._class_image_filenames = _load_image_filenames(self.class_images_dir)
return self._class_image_filenames
@property
def num_records(self):
return len(self)
def _load_image_filenames_and_alts(img_dir, image_extensions=(".jpg", ".jpeg", ".png")):
"""Loads images into groups (filenames tagged with `[alt-{n:02d}]` are grouped together)."""
image_filenames = _load_image_filenames(img_dir, image_extensions)
grouped_img_filenames = defaultdict(list)
for filename in image_filenames:
base_filename = re.sub(r"\[alt-\d*\]", "", filename)
grouped_img_filenames[base_filename].append(filename)
return list(grouped_img_filenames.values())
def _load_image_filenames(img_dir, image_extensions=(".jpg", ".jpeg", ".png")):
image_filenames = []
for filename in os.listdir(img_dir):
if filename.lower().endswith(image_extensions) and not filename.startswith("."):
image_filenames.append(filename)
random.shuffle(image_filenames)
return image_filenames

View File

@ -0,0 +1,375 @@
"""
Crops to the most interesting part of the image.
MIT License from https://github.com/smartcrop/smartcrop.py/commit/f5377045035abc7ae79d8d9ad40bbc7fce0f6ad7
"""
import math
import sys
import numpy as np
from PIL import Image, ImageDraw
from PIL.ImageFilter import Kernel
def saturation(image):
r, g, b = image.split()
r, g, b = np.array(r), np.array(g), np.array(b)
r, g, b = r.astype(float), g.astype(float), b.astype(float)
maximum = np.maximum(np.maximum(r, g), b) # [0; 255]
minimum = np.minimum(np.minimum(r, g), b) # [0; 255]
s = (maximum + minimum) / 255 # [0.0; 1.0]
d = (maximum - minimum) / 255 # [0.0; 1.0]
d[maximum == minimum] = 0 # if maximum == minimum:
s[maximum == minimum] = 1 # -> saturation = 0 / 1 = 0
mask = s > 1
s[mask] = 2 - d[mask]
return d / s # [0.0; 1.0]
def thirds(x):
"""gets value in the range of [0, 1] where 0 is the center of the pictures
returns weight of rule of thirds [0, 1]."""
x = ((x + 2 / 3) % 2 * 0.5 - 0.5) * 16
return max(1 - x * x, 0)
class SmartCrop:
DEFAULT_SKIN_COLOR = [0.78, 0.57, 0.44]
def __init__(
self,
detail_weight=0.2,
edge_radius=0.4,
edge_weight=-20,
outside_importance=-0.5,
rule_of_thirds=True,
saturation_bias=0.2,
saturation_brightness_max=0.9,
saturation_brightness_min=0.05,
saturation_threshold=0.4,
saturation_weight=0.3,
score_down_sample=8,
skin_bias=0.01,
skin_brightness_max=1,
skin_brightness_min=0.2,
skin_color=None,
skin_threshold=0.8,
skin_weight=1.8,
):
self.detail_weight = detail_weight
self.edge_radius = edge_radius
self.edge_weight = edge_weight
self.outside_importance = outside_importance
self.rule_of_thirds = rule_of_thirds
self.saturation_bias = saturation_bias
self.saturation_brightness_max = saturation_brightness_max
self.saturation_brightness_min = saturation_brightness_min
self.saturation_threshold = saturation_threshold
self.saturation_weight = saturation_weight
self.score_down_sample = score_down_sample
self.skin_bias = skin_bias
self.skin_brightness_max = skin_brightness_max
self.skin_brightness_min = skin_brightness_min
self.skin_color = skin_color or self.DEFAULT_SKIN_COLOR
self.skin_threshold = skin_threshold
self.skin_weight = skin_weight
def analyse(
self,
image,
crop_width,
crop_height,
max_scale=1,
min_scale=0.9,
scale_step=0.1,
step=8,
):
"""
Analyze image and return some suggestions of crops (coordinates).
This implementation / algorithm is really slow for large images.
Use `crop()` which is pre-scaling the image before analyzing it.
"""
cie_image = image.convert("L", (0.2126, 0.7152, 0.0722, 0))
cie_array = np.array(cie_image) # [0; 255]
# R=skin G=edge B=saturation
edge_image = self.detect_edge(cie_image)
skin_image = self.detect_skin(cie_array, image)
saturation_image = self.detect_saturation(cie_array, image)
analyse_image = Image.merge("RGB", [skin_image, edge_image, saturation_image])
del edge_image
del skin_image
del saturation_image
score_image = analyse_image.copy()
score_image.thumbnail(
(
int(math.ceil(image.size[0] / self.score_down_sample)),
int(math.ceil(image.size[1] / self.score_down_sample)),
),
Image.ANTIALIAS,
)
top_crop = None
top_score = -sys.maxsize
crops = self.crops(
image,
crop_width,
crop_height,
max_scale=max_scale,
min_scale=min_scale,
scale_step=scale_step,
step=step,
)
for crop in crops:
crop["score"] = self.score(score_image, crop)
if crop["score"]["total"] > top_score:
top_crop = crop
top_score = crop["score"]["total"]
return {"analyse_image": analyse_image, "crops": crops, "top_crop": top_crop}
def crop(
self,
image,
width,
height,
prescale=True,
max_scale=1,
min_scale=0.9,
scale_step=0.1,
step=8,
):
"""Not yet fully cleaned from https://github.com/hhatto/smartcrop.py."""
scale = min(image.size[0] / width, image.size[1] / height)
crop_width = int(math.floor(width * scale))
crop_height = int(math.floor(height * scale))
# img = 100x100, width = 95x95, scale = 100/95, 1/scale > min
# don't set minscale smaller than 1/scale
# -> don't pick crops that need upscaling
min_scale = min(max_scale, max(1 / scale, min_scale))
prescale_size = 1
if prescale:
prescale_size = 1 / scale / min_scale
if prescale_size < 1:
image = image.copy()
image.thumbnail(
(
int(image.size[0] * prescale_size),
int(image.size[1] * prescale_size),
),
Image.ANTIALIAS,
)
crop_width = int(math.floor(crop_width * prescale_size))
crop_height = int(math.floor(crop_height * prescale_size))
else:
prescale_size = 1
result = self.analyse(
image,
crop_width=crop_width,
crop_height=crop_height,
min_scale=min_scale,
max_scale=max_scale,
scale_step=scale_step,
step=step,
)
for i in range(len(result["crops"])):
crop = result["crops"][i]
crop["x"] = int(math.floor(crop["x"] / prescale_size))
crop["y"] = int(math.floor(crop["y"] / prescale_size))
crop["width"] = int(math.floor(crop["width"] / prescale_size))
crop["height"] = int(math.floor(crop["height"] / prescale_size))
result["crops"][i] = crop
return result
def crops(
self,
image,
crop_width,
crop_height,
max_scale=1,
min_scale=0.9,
scale_step=0.1,
step=8,
):
image_width, image_height = image.size
crops = []
for scale in (
i / 100
for i in range(
int(max_scale * 100),
int((min_scale - scale_step) * 100),
-int(scale_step * 100),
)
):
for y in range(0, image_height, step):
if not (y + crop_height * scale <= image_height):
break
for x in range(0, image_width, step):
if not (x + crop_width * scale <= image_width):
break
crops.append(
{
"x": x,
"y": y,
"width": crop_width * scale,
"height": crop_height * scale,
}
)
if not crops:
raise ValueError(locals())
return crops
def debug_crop(self, analyse_image, crop):
debug_image = analyse_image.copy()
debug_pixels = debug_image.getdata()
debug_crop_image = Image.new(
"RGBA",
(int(math.floor(crop["width"])), int(math.floor(crop["height"]))),
(255, 0, 0, 25),
)
ImageDraw.Draw(debug_crop_image).rectangle(
((0, 0), (crop["width"], crop["height"])), outline=(255, 0, 0)
)
for y in range(analyse_image.size[1]): # height
for x in range(analyse_image.size[0]): # width
p = y * analyse_image.size[0] + x
importance = self.importance(crop, x, y)
if importance > 0:
debug_pixels.putpixel(
(x, y),
(
debug_pixels[p][0],
int(debug_pixels[p][1] + importance * 32),
debug_pixels[p][2],
),
)
elif importance < 0:
debug_pixels.putpixel(
(x, y),
(
int(debug_pixels[p][0] + importance * -64),
debug_pixels[p][1],
debug_pixels[p][2],
),
)
debug_image.paste(
debug_crop_image, (crop["x"], crop["y"]), debug_crop_image.split()[3]
)
return debug_image
def detect_edge(self, cie_image):
return cie_image.filter(Kernel((3, 3), (0, -1, 0, -1, 4, -1, 0, -1, 0), 1, 1))
def detect_saturation(self, cie_array, source_image):
threshold = self.saturation_threshold
saturation_data = saturation(source_image)
mask = (
(saturation_data > threshold)
& (cie_array >= self.saturation_brightness_min * 255)
& (cie_array <= self.saturation_brightness_max * 255)
)
saturation_data[~mask] = 0
saturation_data[mask] = (saturation_data[mask] - threshold) * (
255 / (1 - threshold)
)
return Image.fromarray(saturation_data.astype("uint8"))
def detect_skin(self, cie_array, source_image):
r, g, b = source_image.split()
r, g, b = np.array(r), np.array(g), np.array(b)
r, g, b = r.astype(float), g.astype(float), b.astype(float)
rd = np.ones_like(r) * -self.skin_color[0]
gd = np.ones_like(g) * -self.skin_color[1]
bd = np.ones_like(b) * -self.skin_color[2]
mag = np.sqrt(r * r + g * g + b * b)
mask = ~(abs(mag) < 1e-6)
rd[mask] = r[mask] / mag[mask] - self.skin_color[0]
gd[mask] = g[mask] / mag[mask] - self.skin_color[1]
bd[mask] = b[mask] / mag[mask] - self.skin_color[2]
skin = 1 - np.sqrt(rd * rd + gd * gd + bd * bd)
mask = (
(skin > self.skin_threshold)
& (cie_array >= self.skin_brightness_min * 255)
& (cie_array <= self.skin_brightness_max * 255)
)
skin_data = (skin - self.skin_threshold) * (255 / (1 - self.skin_threshold))
skin_data[~mask] = 0
return Image.fromarray(skin_data.astype("uint8"))
def importance(self, crop, x, y):
if (
crop["x"] > x
or x >= crop["x"] + crop["width"]
or crop["y"] > y
or y >= crop["y"] + crop["height"]
):
return self.outside_importance
x = (x - crop["x"]) / crop["width"]
y = (y - crop["y"]) / crop["height"]
px, py = abs(0.5 - x) * 2, abs(0.5 - y) * 2
# distance from edge
dx = max(px - 1 + self.edge_radius, 0)
dy = max(py - 1 + self.edge_radius, 0)
d = (dx * dx + dy * dy) * self.edge_weight
s = 1.41 - math.sqrt(px * px + py * py)
if self.rule_of_thirds:
s += (max(0, s + d + 0.5) * 1.2) * (thirds(px) + thirds(py))
return s + d
def score(self, target_image, crop):
score = {
"detail": 0,
"saturation": 0,
"skin": 0,
"total": 0,
}
target_data = target_image.getdata()
target_width, target_height = target_image.size
down_sample = self.score_down_sample
inv_down_sample = 1 / down_sample
target_width_down_sample = target_width * down_sample
target_height_down_sample = target_height * down_sample
for y in range(0, target_height_down_sample, down_sample):
for x in range(0, target_width_down_sample, down_sample):
p = int(
math.floor(y * inv_down_sample) * target_width
+ math.floor(x * inv_down_sample)
)
importance = self.importance(crop, x, y)
detail = target_data[p][1] / 255
score["skin"] += (
target_data[p][0] / 255 * (detail + self.skin_bias) * importance
)
score["detail"] += detail * importance
score["saturation"] += (
target_data[p][2]
/ 255
* (detail + self.saturation_bias)
* importance
)
score["total"] = (
score["detail"] * self.detail_weight
+ score["skin"] * self.skin_weight
+ score["saturation"] * self.saturation_weight
) / (crop["width"] * crop["height"])
return score

View File

@ -0,0 +1,12 @@
import logging
from imaginairy import LazyLoadingImage
from imaginairy.enhancers.facecrop import generate_face_crops
from tests import TESTS_FOLDER
logger = logging.getLogger(__name__)
def test_facecrop():
img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")
generate_face_crops((50, 50, 150, 150), max_width=img.width, max_height=img.height)

View File

@ -60,7 +60,7 @@ def test_clip_masking(filename_base_for_outputs):
result = next(imagine(prompt))
img_path = f"{filename_base_for_outputs}.png"
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=1000)
assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=1100)
boolean_mask_test_cases = [

69
tests/test_roi_utils.py Normal file
View File

@ -0,0 +1,69 @@
import itertools
import random
from imaginairy.roi_utils import (
RoiNotInBoundsError,
resize_roi_coordinates,
square_roi_coordinate,
)
def test_square_roi_coordinate():
img_sizes = (10, 100, 200, 511, 513, 1024)
# iterate through all permutations of image sizes using itertools.product
for img_width, img_height in itertools.product(img_sizes, img_sizes):
# randomly generate a region of interest
for _ in range(100):
x1 = random.randint(0, img_width)
y1 = random.randint(0, img_height)
x2 = random.randint(x1, img_width)
y2 = random.randint(y1, img_height)
roi = x1, y1, x2, y2
try:
x1, y1, x2, y2 = square_roi_coordinate(roi, img_width, img_height)
except RoiNotInBoundsError:
continue
assert (
x2 - x1 == y2 - y1
), f"ROI is not square: img_width: {img_width}, img_height: {img_height}, roi: {roi}"
# resize_roi_coordinates
def test_square_resize_roi_coordinates():
img_sizes = (10, 100, 200, 403, 511, 513, 604, 1024)
# iterate through all permutations of image sizes using itertools.product
img_sizes = list(itertools.product(img_sizes, img_sizes))
for img_width, img_height in img_sizes:
# randomly generate a region of interest
rois = []
for _ in range(100):
x1 = random.randint(0 + 1, img_width - 1)
y1 = random.randint(0 + 1, img_height - 1)
x2 = random.randint(x1 + 1, img_width)
y2 = random.randint(y1 + 1, img_height)
roi = x1, y1, x2, y2
rois.append(roi)
rois.append((392, 85, 695, 389))
for roi in rois:
try:
squared_roi = square_roi_coordinate(roi, img_width, img_height)
except RoiNotInBoundsError:
continue
for n in range(10):
factor = 1.25 + 0.3 * n
x1, y1, x2, y2 = resize_roi_coordinates(
squared_roi, factor, img_width, img_height
)
assert (
x2 - x1 == y2 - y1
), f"ROI is not square: img_width: {img_width}, img_height: {img_height}, roi: {roi}"
x1, y1, x2, y2 = resize_roi_coordinates(
squared_roi, factor, img_width, img_height, expand_up=False
)
assert (
x2 - x1 == y2 - y1
), f"ROI is not square: img_width: {img_width}, img_height: {img_height}, roi: {roi}"