Mirror of https://github.com/brycedrennan/imaginAIry, synced 2024-10-31 03:20:40 +00:00
feature: finetuning

- feature: finetuning your own image models
- feature: image prep command. crops to face or other interesting parts of photo
- fix: back-compat for hf_hub_download
- feature: add prune-ckpt command
- feature: allow specification of model config file
This commit is contained in:
parent 4bc78b9be5
commit 5cc73f6087
4  .gitignore  vendored
@@ -24,5 +24,5 @@ tests/vastai_cli.py
.unison*
*.kgrind
*.pyprof
/imaginairy/enhancers/.polyscope.ini
/imaginairy/enhancers/imgui.ini
**/.polyscope.ini
**/imgui.ini
32  README.md
@@ -180,6 +180,7 @@ a bowl full of gold bars sitting on a table
- Edit images by describing the part you want edited (see example above)
- Have AI generate captions for images `aimg describe <filename-or-url>`
- Interactive prompt: just run `aimg`
- 🎉 finetune your own image model. kind of like dreambooth. Read instructions on ["Concept Training"](docs/concept-training.md) page

## How To

@@ -239,6 +240,14 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
[Example Colab](https://colab.research.google.com/drive/1rOvQNs0Cmn_yU1bKWjCOHzGVDgZkaTtO?usp=sharing)

## ChangeLog

**7.4.0**
- feature: 🎉 finetune your own image model. kind of like dreambooth. Read instructions on ["Concept Training"](docs/concept-training.md) page
- feature: image prep command. crops to face or other interesting parts of photo
- fix: back-compat for hf_hub_download
- feature: add prune-ckpt command
- feature: allow specification of model config file

**7.3.0**
- feature: 🎉 depth-based image-to-image generations (and inpainting)
- fix: k_euler_a produces more consistent images per seed (randomization respects the seed again)
@@ -390,7 +399,6 @@ would be uncorrelated to the rest of the surrounding image. It created terrible

## Not Supported
- a GUI. this is a python library
- training
- exploratory features that don't work well

## Todo
@@ -403,7 +411,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- https://github.com/CompVis/stable-diffusion/pull/177
- https://github.com/huggingface/diffusers/pull/532/files
- https://github.com/HazyResearch/flash-attention
- xformers improvements https://www.photoroom.com/tech/stable-diffusion-100-percent-faster-with-memory-efficient-attention/
- ✅ xformers improvements https://www.photoroom.com/tech/stable-diffusion-100-percent-faster-with-memory-efficient-attention/
- Development Environment
- ✅ add tests
- ✅ set up ci (test/lint/format)
@@ -424,9 +432,10 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- Compositional Visual Generation
- https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
- https://colab.research.google.com/github/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch/blob/main/notebooks/demo.ipynb#scrollTo=wt_j3uXZGFAS
- negative prompting
- ✅ negative prompting
- some syntax to allow it in a text string
- images as actual prompts instead of just init images. is this the same as textual inversion?
- 🚫 images as actual prompts instead of just init images.
- not directly possible due to model architecture.
- requires model fine-tuning since SD1.4 expects 77x768 text encoding input
- https://twitter.com/Buntworthy/status/1566744186153484288
- https://github.com/justinpinkney/stable-diffusion
@@ -439,7 +448,7 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- ✅ inpainting
- https://github.com/Jack000/glid-3-xl-stable
- https://github.com/andreas128/RePaint
- img2img but keeps img stable
- ✅ img2img but keeps img stable
- https://www.reddit.com/r/StableDiffusion/comments/xboy90/a_better_way_of_doing_img2img_by_finding_the/
- https://gist.github.com/trygvebw/c71334dd127d537a15e9d59790f7f5e1
- https://github.com/pesser/stable-diffusion/commit/bbb52981460707963e2a62160890d7ecbce00e79
@@ -478,6 +487,11 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- Other
- Enhancement pipelines
- text-to-3d https://dreamfusionpaper.github.io/
- https://shihmengli.github.io/3D-Photo-Inpainting/
- https://github.com/thygate/stable-diffusion-webui-depthmap-script/discussions/50
- Depth estimation
- what is SOTA for monocular depth estimation?
- https://github.com/compphoto/BoostingMonocularDepth
- make a video https://github.com/lucidrains/make-a-video-pytorch
- animations
- https://github.com/francislabountyjr/stable-diffusion/blob/main/inferencing_notebook.ipynb
@@ -499,6 +513,14 @@ would be uncorrelated to the rest of the surrounding image. It created terrible
- ✅ deploy to pypi
- find similar images https://knn5.laion.ai/?back=https%3A%2F%2Fknn5.laion.ai%2F&index=laion5B&useMclip=false
- https://github.com/vicgalle/stable-diffusion-aesthetic-gradients
- Training
- Finetuning "dreambooth" style
- Textual Inversion
- Performance Improvements
- [ColossalAI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) - almost got it working but it's not easy enough to install to merit inclusion in imaginairy. We should check back in on this.
- Xformers
- Deepspeed
-

## Notable Stable Diffusion Implementations
- https://github.com/ahrm/UnstableFusion
BIN  data/DejaVuSans.ttf  Normal file
Binary file not shown.

85  docs/concept-training.md  Normal file
@@ -0,0 +1,85 @@
# Adding a concept to Stable Diffusion

You can use Imaginairy to teach the model a new concept (a person, thing, style, etc.) using the `aimg train-concept` command.

## Requirements
- Graphics card: 3090 or better
- Linux
- A working Imaginairy installation
- A folder of images of the concept you want to teach the model

## Background

To train the model we show it a lot of images of the concept we want to teach it. The problem is that the model can easily overfit to the images we show it. To prevent this we also show it images of the class of thing that is being trained. Imaginairy will generate the images needed for this before running the training job.

Provided a directory of concept images, a concept token, and a class token, this command will train the model to generate images of that concept.

This happens in a 3-step process:

1. Cropping and resizing your training images. If --person is set we crop to include the face.
2. Generating a set of class images to train on. This helps prevent overfitting (a sketch of this step follows below).
3. Training the model on the concept and class images.

The output of this command is a new model weights file that you can use with the --model option.
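The class images in step 2 are ordinary text-to-image generations of the class label. If you want to create them ahead of time, here is a minimal sketch using the helper this commit adds (`create_class_images` in `imaginairy.training_tools.image_prep`); the label, folder, and count below are illustrative:

```python
# Sketch: pre-generate regularization ("class") images, mirroring the
# create_class_images(...) call that `aimg train-concept` makes for you.
from imaginairy.training_tools.image_prep import create_class_images

create_class_images(
    class_description="photo of a man",  # same text as --class-label
    output_folder="./images/man",        # same folder as --class-images-dir
    num_images=300,                      # the default for --n-class-images
)
```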
## Instructions

1. Gather a set of images of the concept you want to train on. The images should show the subject from a variety of angles and in a variety of situations.
2. Run `aimg train-concept` to train the model.
   - Concept label: For a person, firstnamelastname should be fine.
     - If all the training images are photos you should add "a photo of" to the beginning of the concept label.
   - Class label: This is the category of the things being trained on. For people this is typically "person", "man", or "woman".
     - If all the training images are photos you should add "a photo of" to the beginning of the class label.
   - Class images will be generated for you if you do not provide them.

   For example, if you were training on photos of a man named bill hamilton you could run the following:

   ```
   aimg train-concept \
       --person \
       --concept-label "photo of billhamilton man" \
       --concept-images-dir ./images/billhamilton \
       --class-label "photo of a man" \
       --class-images-dir ./images/man
   ```
3. Stop training before it overfits.
   - The training script will output checkpoint (ckpt) files into the logs folder of wherever it is run from. You can also monitor generated images in the logs/images folder. They will be the ones named "sample".
   - I don't have great advice on when to stop training yet. I stopped mine at epoch 62 and it didn't seem quite good enough; at epoch 111 it produced my face correctly 50% of the time but also seemed overfit in some ways (always placing me in the same clothes or backgrounds as the training photos).
   - You can monitor model training progress in TensorBoard. Run `tensorboard --logdir lightning_logs` and open the link it gives you in your browser.

4. Prune the model to bring the size from 11 GB down to ~4 GB: `aimg prune-ckpt logs/2023-01-15T05-52-06/checkpoints/epoch\=000049.ckpt`. Copy it somewhere and give it a meaningful name.
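If you prefer to prune from Python, the same helper the `aimg prune-ckpt` command wraps can be called directly. A minimal sketch (the checkpoint path is illustrative):

```python
# Sketch: Python equivalent of `aimg prune-ckpt`. Strips training-only state
# (optimizer state, etc.) so only the weights needed for inference remain.
from imaginairy.training_tools.prune_model import prune_diffusion_ckpt

prune_diffusion_ckpt("logs/2023-01-15T05-52-06/checkpoints/epoch=000049.ckpt")
```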
## Using the new model
You can reference the model like this in imaginairy:

`imagine --model my-models/billhamilton-man-e111.ckpt`

When you use the model you should prompt with `firstnamelastname classname` (e.g. `billhamilton man`).
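You can also use the finetuned weights from the Python API. A minimal sketch, assuming the usual top-level imports and that `ImaginePrompt` accepts the same `model` value as the CLI's `--model` option (the prompt text and output folder are illustrative):

```python
# Sketch: generate with the finetuned checkpoint from Python.
from imaginairy import ImaginePrompt, imagine_image_files

prompt = ImaginePrompt(
    "billhamilton man as a renaissance oil painting",
    model="my-models/billhamilton-man-e111.ckpt",  # the pruned checkpoint
)
imagine_image_files([prompt], outdir="./outputs")
```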
## Disclaimers

- The settings imaginairy uses to train the model are different than those used by other software projects. As such you cannot follow advice you may read in other tutorials regarding learning rate, epochs, steps, or batch size; they are not directly comparable. In layman's terms, the "steps" are much bigger in imaginairy (the exact settings the command uses are sketched after this list).
- I consider this training feature experimental and don't currently plan to offer support for it. Any further work will be at my leisure. As a result I may close any reported issues related to this feature.
- You can find a lot more relevant information here: https://github.com/JoePenna/Dreambooth-Stable-Diffusion
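For reference, this is roughly the training call that `aimg train-concept` makes under the hood in this release; it fixes the learning rate and gradient accumulation, which is why one imaginairy "step" covers far more images than in other tools. The labels and paths are illustrative, and `concept_images_dir` should point at the prepped images:

```python
# Sketch: the train_diffusion_model(...) call made by `aimg train-concept`.
from imaginairy.train import train_diffusion_model

train_diffusion_model(
    concept_label="photo of billhamilton man",
    concept_images_dir="./images/billhamilton/prepped-images",
    class_label="photo of a man",
    class_images_dir="./images/man",
    weights_location="SD-1.5",   # model shortcut or a path to weights
    logdir="logs",               # checkpoints end up under logs/.../checkpoints
    learning_rate=1e-6,          # fixed by the CLI in this release
    accumulate_grad_batches=32,  # each optimizer step accumulates 32 batches
)
```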
## Todo
- figure out how to improve consistency of quality from the trained model
- train on the depth guided model instead of SD 1.5 since that will enable more consistent output
- figure out a metric to use for stopping training
- possibly swap out and randomize backgrounds on training photos so over-fitting does not occur
@ -134,6 +134,7 @@ def imagine(
|
||||
)
|
||||
model = get_diffusion_model(
|
||||
weights_location=prompt.model,
|
||||
config_path=prompt.model_config_path,
|
||||
half_mode=half_mode,
|
||||
for_inpainting=prompt.mask_image or prompt.mask_prompt,
|
||||
)
|
||||
@ -182,6 +183,7 @@ def imagine(
|
||||
max_height=prompt.height,
|
||||
max_width=prompt.width,
|
||||
)
|
||||
|
||||
except PIL.UnidentifiedImageError:
|
||||
logger.warning(f"Could not load image: {prompt.init_image}")
|
||||
continue
|
||||
@ -194,9 +196,16 @@ def imagine(
|
||||
)
|
||||
elif prompt.mask_image:
|
||||
mask_image = prompt.mask_image.convert("L")
|
||||
mask_image = pillow_fit_image_within(
|
||||
mask_image,
|
||||
max_height=prompt.height,
|
||||
max_width=prompt.width,
|
||||
convert="L",
|
||||
)
|
||||
|
||||
if mask_image is not None:
|
||||
log_img(mask_image, "init mask")
|
||||
|
||||
if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE:
|
||||
mask_image = ImageOps.invert(mask_image)
|
||||
|
||||
@ -355,6 +364,8 @@ def imagine(
|
||||
shape=shape,
|
||||
batch_size=1,
|
||||
)
|
||||
# from torch.nn.functional import interpolate
|
||||
# samples = interpolate(samples, scale_factor=2, mode='nearest')
|
||||
|
||||
x_samples = model.decode_first_stage(samples)
|
||||
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import math
|
||||
import os.path
|
||||
|
||||
import click
|
||||
from click_shell import shell
|
||||
@ -11,6 +12,13 @@ from imaginairy.enhancers.prompt_expansion import expand_prompts
|
||||
from imaginairy.log_utils import configure_logging
|
||||
from imaginairy.samplers import SAMPLER_TYPE_OPTIONS
|
||||
from imaginairy.schema import ImaginePrompt
|
||||
from imaginairy.train import train_diffusion_model
|
||||
from imaginairy.training_tools.image_prep import (
|
||||
create_class_images,
|
||||
get_image_filenames,
|
||||
prep_images,
|
||||
)
|
||||
from imaginairy.training_tools.prune_model import prune_diffusion_ckpt
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -183,6 +191,12 @@ logger = logging.getLogger(__name__)
|
||||
show_default=True,
|
||||
default=config.DEFAULT_MODEL,
|
||||
)
|
||||
@click.option(
|
||||
"--model-config-path",
|
||||
help="Model config file to use. If a model name is specified, the appropriate config will be used.",
|
||||
show_default=True,
|
||||
default=config.DEFAULT_MODEL,
|
||||
)
|
||||
@click.option(
|
||||
"--prompt-library-path",
|
||||
help="Path to folder containing phrase lists in txt files. Use txt filename in prompt: {_filename_}.",
|
||||
@ -221,6 +235,7 @@ def imagine_cmd(
|
||||
caption,
|
||||
precision,
|
||||
model_weights_path,
|
||||
model_config_path,
|
||||
prompt_library_path,
|
||||
):
|
||||
"""Have the AI generate images. alias:imagine."""
|
||||
@ -282,6 +297,7 @@ def imagine_cmd(
|
||||
fix_faces_fidelity=fix_faces_fidelity,
|
||||
tile_mode=_tile_mode,
|
||||
model=model_weights_path,
|
||||
model_config_path=model_config_path,
|
||||
)
|
||||
prompts.append(prompt)
|
||||
|
||||
@ -315,6 +331,240 @@ def describe(image_filepaths):
|
||||
print(generate_caption(img.copy()))
|
||||
|
||||
|
||||
@click.option(
|
||||
"--concept-label",
|
||||
help=(
|
||||
'The concept you are training on. Usually "a photo of [person or thing] [classname]" is what you should use.'
|
||||
),
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--concept-images-dir",
|
||||
type=click.Path(),
|
||||
required=True,
|
||||
help="Where to find the pre-processed concept images to train on.",
|
||||
)
|
||||
@click.option(
|
||||
"--class-label",
|
||||
help=(
|
||||
'What class of things does the concept belong to? For example, if you are training on "a painting of George Washington", '
|
||||
'you might use "a painting of a man" as the class label. We use this to prevent the model from overfitting.'
|
||||
),
|
||||
default="a photo of *",
|
||||
)
|
||||
@click.option(
|
||||
"--class-images-dir",
|
||||
type=click.Path(),
|
||||
required=True,
|
||||
help="Where to find the pre-processed class images to train on.",
|
||||
)
|
||||
@click.option(
|
||||
"--n-class-images",
|
||||
type=int,
|
||||
default=300,
|
||||
help="Number of class images to generate.",
|
||||
)
|
||||
@click.option(
|
||||
"--model-weights-path",
|
||||
"--model",
|
||||
"model",
|
||||
help=f"Model to use. Should be one of {', '.join(MODEL_SHORT_NAMES)}, or a path to custom weights.",
|
||||
show_default=True,
|
||||
default=config.DEFAULT_MODEL,
|
||||
)
|
||||
@click.option(
|
||||
"--person",
|
||||
"is_person",
|
||||
is_flag=True,
|
||||
help="Set if images are of a person. Will use face detection and enhancement.",
|
||||
)
|
||||
@click.option(
|
||||
"-y",
|
||||
"preconfirmed",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Bypass input confirmations.",
|
||||
)
|
||||
@click.option(
|
||||
"--skip-prep",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Skip the image preparation step.",
|
||||
)
|
||||
@click.option(
|
||||
"--skip-class-img-gen",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Skip the class image generation step.",
|
||||
)
|
||||
@aimg.command()
|
||||
def train_concept(
|
||||
concept_label,
|
||||
concept_images_dir,
|
||||
class_label,
|
||||
class_images_dir,
|
||||
n_class_images,
|
||||
model,
|
||||
is_person,
|
||||
preconfirmed,
|
||||
skip_prep,
|
||||
skip_class_img_gen,
|
||||
):
|
||||
"""
|
||||
Teach the model a new concept (a person, thing, style, etc).
|
||||
|
||||
Provided a directory of concept images, a concept token, and a class token, this command will train the model
|
||||
to generate images of that concept.
|
||||
|
||||
\b
|
||||
This happens in a 3-step process:
|
||||
1. Cropping and resizing your training images. If --person is set we crop to include the face.
|
||||
2. Generating a set of class images to train on. This helps prevent overfitting.
|
||||
3. Training the model on the concept and class images.
|
||||
|
||||
The output of this command is a new model weights file that you can use with the --model option.
|
||||
|
||||
\b
|
||||
## Instructions
|
||||
1. Gather a set of images of the concept you want to train on. They should show the subject from a variety of angles
|
||||
and in a variety of situations.
|
||||
2. Train the model.
|
||||
- Concept label: For a person, firstnamelastname should be fine.
|
||||
- If all the training images are photos you should add "a photo of" to the beginning of the concept label.
|
||||
- Class label: This is the category of the things being trained on. For people this is typically "person", "man"
|
||||
or "woman".
|
||||
- If all the training images are photos you should add "a photo of" to the beginning of the class label.
|
||||
- Class images will be generated for you if you do not provide them.
|
||||
3. Stop training before it overfits. I haven't figured this out yet.
|
||||
|
||||
|
||||
For example, if you were training on photos of a man named bill hamilton you could run the following:
|
||||
|
||||
\b
|
||||
aimg train-concept \\
|
||||
--person \\
|
||||
--concept-label "photo of billhamilton man" \\
|
||||
--concept-images-dir ./images/billhamilton \\
|
||||
--class-label "photo of a man" \\
|
||||
--class-images-dir ./images/man
|
||||
|
||||
When you use the model you should prompt with `firstnamelastname classname` (e.g. `billhamilton man`).
|
||||
|
||||
You can find a lot of relevant instructions here: https://github.com/JoePenna/Dreambooth-Stable-Diffusion
|
||||
"""
|
||||
configure_logging()
|
||||
target_size = 512
|
||||
# Step 1. Crop and enhance the training images
|
||||
prepped_images_path = os.path.join(concept_images_dir, "prepped-images")
|
||||
image_filenames = get_image_filenames(concept_images_dir)
|
||||
click.secho(
|
||||
f'\n🤖🧠 Training "{concept_label}" based on {len(image_filenames)} images.\n'
|
||||
)
|
||||
|
||||
if not skip_prep:
|
||||
msg = (
|
||||
f"Creating cropped copies of the {len(image_filenames)} concept images\n"
|
||||
f" Is Person: {is_person}\n"
|
||||
f" Source: {concept_images_dir}\n"
|
||||
f" Dest: {prepped_images_path}\n"
|
||||
)
|
||||
logger.info(msg)
|
||||
if not is_person:
|
||||
click.secho("⚠️ the --person flag was not set. ", fg="yellow")
|
||||
|
||||
if not preconfirmed and not click.confirm("Continue?"):
|
||||
return
|
||||
|
||||
prep_images(
|
||||
images_dir=concept_images_dir, is_person=is_person, target_size=target_size
|
||||
)
|
||||
concept_images_dir = prepped_images_path
|
||||
|
||||
if not skip_class_img_gen:
|
||||
# Step 2. Generate class images
|
||||
class_image_filenames = get_image_filenames(class_images_dir)
|
||||
images_needed = max(n_class_images - len(class_image_filenames), 0)
|
||||
logger.info(f"Generating {n_class_images} class images in {class_images_dir}")
|
||||
logger.info(
|
||||
f"{len(class_image_filenames)} existing class images found so only generating {images_needed}."
|
||||
)
|
||||
if not preconfirmed and not click.confirm("Continue?"):
|
||||
return
|
||||
create_class_images(
|
||||
class_description=class_label,
|
||||
output_folder=class_images_dir,
|
||||
num_images=n_class_images,
|
||||
)
|
||||
|
||||
logger.info("Training the model...")
|
||||
if not preconfirmed and not click.confirm("Continue?"):
|
||||
return
|
||||
|
||||
# Step 3. Train the model
|
||||
train_diffusion_model(
|
||||
concept_label=concept_label,
|
||||
concept_images_dir=concept_images_dir,
|
||||
class_label=class_label,
|
||||
class_images_dir=class_images_dir,
|
||||
weights_location=model,
|
||||
logdir="logs",
|
||||
learning_rate=1e-6,
|
||||
accumulate_grad_batches=32,
|
||||
)
|
||||
|
||||
|
||||
@click.argument(
|
||||
"images_dir",
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--person",
|
||||
"is_person",
|
||||
is_flag=True,
|
||||
help="Set if images are of a person. Will use face detection and enhancement.",
|
||||
)
|
||||
@click.option(
|
||||
"--target-size",
|
||||
default=512,
|
||||
type=int,
|
||||
show_default=True,
|
||||
)
|
||||
@aimg.command("prep-images")
|
||||
def prepare_images(images_dir, is_person, target_size):
|
||||
"""
|
||||
Prepare a folder of images for training.
|
||||
|
||||
Prepped images will be written to the `prepped-images` subfolder.
|
||||
|
||||
All images will be cropped and resized to (default) 512x512.
|
||||
Upscaling and face enhancement will be applied as needed to smaller images.
|
||||
|
||||
Examples:
|
||||
aimg prep-images --person ./images/selfies
|
||||
aimg prep-images ./images/toy-train
|
||||
"""
|
||||
configure_logging()
|
||||
prep_images(images_dir=images_dir, is_person=is_person, target_size=target_size)
|
||||
|
||||
|
||||
@click.argument("ckpt_paths", nargs=-1)
|
||||
@aimg.command("prune-ckpt")
|
||||
def prune_ckpt(ckpt_paths):
|
||||
"""
|
||||
Prune a checkpoint file.
|
||||
|
||||
This will remove the optimizer state from the checkpoint file.
|
||||
This is useful if you want to use the checkpoint file for inference and save a lot of disk space
|
||||
|
||||
Example:
|
||||
aimg prune-ckpt ./path/to/checkpoint.ckpt
|
||||
"""
|
||||
click.secho("Pruning checkpoint files...")
|
||||
configure_logging()
|
||||
for p in ckpt_paths:
|
||||
prune_diffusion_ckpt(p)
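The implementation of `prune_diffusion_ckpt` is not part of this diff, but conceptually the pruning boils down to keeping only the `state_dict` from the Lightning checkpoint and dropping the optimizer state, per the docstring above. A rough sketch of that idea (not the actual implementation; the function name here is hypothetical):

```python
import torch


def prune_checkpoint_sketch(ckpt_path: str, out_path: str) -> None:
    """Rough idea only: drop training-only state, keep the inference weights."""
    ckpt = torch.load(ckpt_path, map_location="cpu")
    pruned = {"state_dict": ckpt["state_dict"]}  # omit optimizer_states, lr_schedulers, etc.
    torch.save(pruned, out_path)
```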
|
||||
|
||||
|
||||
aimg.add_command(imagine_cmd, name="imagine")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -12,6 +12,8 @@ DEFAULT_NEGATIVE_PROMPT = (
|
||||
"grainy, blurred, blurry, writing, calligraphy, signature, text, watermark, bad art,"
|
||||
)
|
||||
|
||||
SPLITMEM_ENABLED = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
@ -19,6 +21,7 @@ class ModelConfig:
|
||||
config_path: str
|
||||
weights_url: str
|
||||
default_image_size: int
|
||||
weights_url_full: str = None
|
||||
forced_attn_precision: str = "default"
|
||||
|
||||
|
||||
@ -34,7 +37,8 @@ MODEL_CONFIGS = [
|
||||
ModelConfig(
|
||||
short_name="SD-1.5",
|
||||
config_path="configs/stable-diffusion-v1.yaml",
|
||||
weights_url="https://huggingface.co/acheong08/SD-V1-5-cloned/resolve/fc392f6bd4345b80fc2256fa8aded8766b6c629e/v1-5-pruned-emaonly.ckpt",
|
||||
weights_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt",
|
||||
weights_url_full="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned.ckpt",
|
||||
default_image_size=512,
|
||||
),
|
||||
ModelConfig(
|
||||
|
@ -7,7 +7,7 @@ model:
|
||||
num_timesteps_cond: 1
|
||||
log_every_t: 200
|
||||
timesteps: 1000
|
||||
first_stage_key: "jpg"
|
||||
first_stage_key: "image"
|
||||
cond_stage_key: "txt"
|
||||
image_size: 64
|
||||
channels: 4
|
||||
@ -20,7 +20,7 @@ model:
|
||||
scheduler_config: # 10000 warm-up steps
|
||||
target: imaginairy.lr_scheduler.LambdaLinearScheduler
|
||||
params:
|
||||
warm_up_steps: [ 10000 ]
|
||||
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||
f_start: [ 1.e-6 ]
|
||||
f_max: [ 1. ]
|
||||
|
@ -56,6 +56,7 @@ def enhance_faces(img, fidelity=0):
|
||||
net = codeformer_model()
|
||||
|
||||
face_helper = face_restore_helper()
|
||||
face_helper.clean_all()
|
||||
|
||||
image = img.convert("RGB")
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
|
57  imaginairy/enhancers/facecrop.py  Normal file
@ -0,0 +1,57 @@
|
||||
import numpy as np
|
||||
|
||||
from imaginairy.enhancers.face_restoration_codeformer import face_restore_helper
|
||||
from imaginairy.roi_utils import resize_roi_coordinates, square_roi_coordinate
|
||||
|
||||
|
||||
def detect_faces(img):
|
||||
face_helper = face_restore_helper()
|
||||
face_helper.clean_all()
|
||||
|
||||
image = img.convert("RGB")
|
||||
np_img = np.array(image, dtype=np.uint8)
|
||||
# rotate to BGR
|
||||
np_img = np_img[:, :, ::-1]
|
||||
|
||||
face_helper.read_image(np_img)
|
||||
|
||||
face_helper.get_face_landmarks_5(
|
||||
only_center_face=False, resize=640, eye_dist_threshold=5
|
||||
)
|
||||
face_helper.align_warp_face()
|
||||
faceboxes = []
|
||||
|
||||
for x1, y1, x2, y2, scaling in face_helper.det_faces:
|
||||
# x1, y1, x2, y2 = x1 * scaling, y1 * scaling, x2 * scaling, y2 * scaling
|
||||
faceboxes.append((x1, y1, x2, y2))
|
||||
|
||||
return faceboxes
|
||||
|
||||
|
||||
def generate_face_crops(face_roi, max_width, max_height):
|
||||
"""Returns bounding boxes at various zoom levels for faces in the image."""
|
||||
|
||||
crops = []
|
||||
squared_roi = square_roi_coordinate(face_roi, max_width, max_height)
|
||||
|
||||
crops.append(resize_roi_coordinates(squared_roi, 1.1, max_width, max_height))
|
||||
# 1.6 generally enough to capture entire face
|
||||
base_expanded_roi = resize_roi_coordinates(squared_roi, 1.6, max_width, max_height)
|
||||
|
||||
crops.append(base_expanded_roi)
|
||||
current_width = base_expanded_roi[2] - base_expanded_roi[0]
|
||||
|
||||
# some zoomed out variations
|
||||
for n in range(2):
|
||||
factor = 1.25 + 0.4 * n
|
||||
|
||||
expanded_roi = resize_roi_coordinates(
|
||||
base_expanded_roi, factor, max_width, max_height, expand_up=False
|
||||
)
|
||||
new_width = expanded_roi[2] - expanded_roi[0]
|
||||
if new_width <= current_width * 1.1:
|
||||
# if the zoomed out size isn't sufficiently larger (because there is nowhere to zoom out to), stop
|
||||
break
|
||||
crops.append(expanded_roi)
|
||||
current_width = new_width
|
||||
return crops
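A short usage sketch for the two helpers above (the image path is illustrative): `detect_faces` returns (x1, y1, x2, y2) boxes, and `generate_face_crops` expands a box into progressively zoomed-out square crops.

```python
# Sketch: crop around the first detected face at several zoom levels.
from PIL import Image

from imaginairy.enhancers.facecrop import detect_faces, generate_face_crops

img = Image.open("./images/selfies/photo-01.jpg")
faces = detect_faces(img)
if faces:
    crops = generate_face_crops(faces[0], max_width=img.width, max_height=img.height)
    for i, (x1, y1, x2, y2) in enumerate(crops):
        img.crop((int(x1), int(y1), int(x2), int(y2))).save(f"face-crop-{i}.jpg")
```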
|
@ -151,3 +151,34 @@ def get_random_non_repeating_combination( # noqa
|
||||
idx = idx // len(sequence)
|
||||
yield values
|
||||
n -= sub_n
|
||||
|
||||
|
||||
# future use
|
||||
prompt_templates = [
|
||||
# https://www.reddit.com/r/StableDiffusion/comments/ya4zxm/dreambooth_is_crazy_prompts_workflow_in_comments/
|
||||
"cinematic still of #prompt-token# as rugged warrior, threatening xenomorph, alien movie (1986),ultrarealistic",
|
||||
"colorful cinematic still of #prompt-token#, armor, cyberpunk,background made of brain cells, back light, organic, art by greg rutkowski, ultrarealistic, leica 30mm",
|
||||
"colorful cinematic still of #prompt-token#, armor, cyberpunk, with a xenonorph, in alien movie (1986),background made of brain cells, organic, ultrarealistic, leic 30mm",
|
||||
"colorful cinematic still of #prompt-token#, #prompt-token# with long hair, color lights, on stage, ultrarealistic",
|
||||
"colorful portrait of #prompt-token# with dark glasses as eminem, gold chain necklace, relfective puffer jacket, short white hair, in front of music shop,ultrarealistic, leica 30mm",
|
||||
"colorful photo of #prompt-token# as kurt cobain with glasses, on stage, lights, ultrarealistic, leica 30mm",
|
||||
"impressionist painting of ((#prompt-token#)) by Daniel F Gerhartz, ((#prompt-token# painted in an impressionist style)), nature, trees",
|
||||
"pencil sketch of #prompt-token#, #prompt-token#, #prompt-token#, inspired by greg rutkowski, digital art by artgem",
|
||||
"photo, colorful cinematic still of #prompt-token#, organic armor,cyberpunk,background brain cells mesh, art by greg rutkowski",
|
||||
"photo, colorful cinematic still of #prompt-token# with organic armor, cyberpunk background, #prompt-token#, greg rutkowski",
|
||||
"photo of #prompt-token# astronaut, astronaut, glasses, helmet in alien world abstract oil painting, greg rutkowski, detailed face",
|
||||
"photo of #prompt-token# as firefighter, helmet, ultrarealistic, leica 30mm",
|
||||
"photo of #prompt-token#, bowler hat, in django unchained movie, ultrarealistic, leica 30mm",
|
||||
"photo of #prompt-token# as serious spiderman with glasses, ultrarealistic, leica 30mm",
|
||||
"photo of #prompt-token# as steampunk warrior, neon organic vines, glasses, digital painting",
|
||||
"photo of #prompt-token# as supermario with glassesm mustach, blue overall, red short,#prompt-token#,#prompt-token#. ultrarealistic, leica 30mm",
|
||||
"photo of #prompt-token# as targaryen warrior with glasses, long white hair, armor, ultrarealistic, leica 30mm",
|
||||
"portrait of #prompt-token# as knight, with glasses white eyes, white mid hair, scar on face, handsome, elegant, intricate, headshot, highly detailed, digital",
|
||||
"portrait of #prompt-token# as hulk, handsome, elegant, intricate luminescent cyberpunk background, headshot, highly detailed, digital painting",
|
||||
"portrait of #prompt-token# as private eye detective, intricate, war torn, highly detailed, digital painting, concept art, smooth, sharp focus",
|
||||
# https://publicprompts.art/
|
||||
"Retro comic style artwork, highly detailed #prompt-token#, comic book cover, symmetrical, vibrant",
|
||||
"Closeup face portrait of #prompt-token# wearing crown, smooth soft skin, big dreamy eyes, beautiful intricate colored hair, symmetrical, anime wide eyes, soft lighting, detailed face, by makoto shinkai, stanley artgerm lau, wlop, rossdraws, concept art, digital painting, looking into camera"
|
||||
"highly detailed portrait brycedrennan man in gta v, unreal engine, fantasy art by greg rutkowski, loish, rhads, ferdinand knab, makoto shinkai and lois van baarle, ilya kuvshinov, rossdraws, tom bagshaw, global illumination, radiant light, detailed and intricate environment "
|
||||
"brycedrennan man: a highly detailed uncropped full-color epic corporate portrait headshot photograph. best portfolio photoraphy photo winner, meticulous detail, hyperrealistic, centered uncropped symmetrical beautiful masculine facial features, atmospheric, photorealistic texture, canon 5D mark III photo, professional studio lighting, aesthetic, very inspirational, motivational. ByKaren L Richard Photography, Photoweb, Splento, Americanoize, Lemonlight",
|
||||
]
|
||||
|
@ -9,8 +9,10 @@ from PIL import Image
|
||||
from imaginairy.utils import get_device
|
||||
|
||||
|
||||
def pillow_fit_image_within(image: PIL.Image.Image, max_height=512, max_width=512):
|
||||
image = image.convert("RGB")
|
||||
def pillow_fit_image_within(
|
||||
image: PIL.Image.Image, max_height=512, max_width=512, convert="RGB"
|
||||
):
|
||||
image = image.convert(convert)
|
||||
w, h = image.size
|
||||
resize_ratio = 1
|
||||
if w > max_width or h > max_height:
|
||||
|
133  imaginairy/lr_scheduler.py  Normal file
@ -0,0 +1,133 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler:
|
||||
"""
|
||||
note: use with a base_lr of 1.0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
warm_up_steps,
|
||||
lr_min,
|
||||
lr_max,
|
||||
lr_start,
|
||||
max_decay_steps,
|
||||
verbosity_interval=0,
|
||||
):
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.lr_start = lr_start
|
||||
self.lr_min = lr_min
|
||||
self.lr_max = lr_max
|
||||
self.lr_max_decay_steps = max_decay_steps
|
||||
self.last_lr = 0.0
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
||||
if n < self.lr_warm_up_steps:
|
||||
lr = (
|
||||
self.lr_max - self.lr_start
|
||||
) / self.lr_warm_up_steps * n + self.lr_start
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
t = (n - self.lr_warm_up_steps) / (
|
||||
self.lr_max_decay_steps - self.lr_warm_up_steps
|
||||
)
|
||||
t = min(t, 1.0)
|
||||
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
|
||||
self.last_lr = lr
|
||||
return lr
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler2:
|
||||
"""
|
||||
supports repeated iterations, configurable via lists
|
||||
note: use with a base_lr of 1.0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
|
||||
):
|
||||
assert (
|
||||
len(warm_up_steps)
|
||||
== len(f_min)
|
||||
== len(f_max)
|
||||
== len(f_start)
|
||||
== len(cycle_lengths)
|
||||
)
|
||||
self.lr_warm_up_steps = warm_up_steps
|
||||
self.f_start = f_start
|
||||
self.f_min = f_min
|
||||
self.f_max = f_max
|
||||
self.cycle_lengths = cycle_lengths
|
||||
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
||||
self.last_f = 0.0
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def find_in_interval(self, n):
|
||||
interval = 0
|
||||
for cl in self.cum_cycles[1:]:
|
||||
if n <= cl:
|
||||
return interval
|
||||
interval += 1
|
||||
return None
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(
|
||||
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}"
|
||||
)
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
||||
cycle
|
||||
] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
|
||||
t = (n - self.lr_warm_up_steps[cycle]) / (
|
||||
self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
|
||||
)
|
||||
t = min(t, 1.0)
|
||||
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||
1 + np.cos(t * np.pi)
|
||||
)
|
||||
self.last_f = f
|
||||
return f
|
||||
|
||||
def __call__(self, n, **kwargs):
|
||||
return self.schedule(n, **kwargs)
|
||||
|
||||
|
||||
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(
|
||||
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}"
|
||||
)
|
||||
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
||||
cycle
|
||||
] * n + self.f_start[cycle]
|
||||
self.last_f = f
|
||||
return f
|
||||
|
||||
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||
self.cycle_lengths[cycle] - n
|
||||
) / (self.cycle_lengths[cycle])
|
||||
self.last_f = f
|
||||
return f
|
@ -1,12 +1,15 @@
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import urllib.parse
|
||||
from functools import wraps
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, try_to_load_from_cache
|
||||
from huggingface_hub import hf_hub_download as _hf_hub_download
|
||||
from huggingface_hub import try_to_load_from_cache
|
||||
from omegaconf import OmegaConf
|
||||
from transformers.utils.hub import HfFolder
|
||||
|
||||
@ -30,11 +33,12 @@ class HuggingFaceAuthorizationError(RuntimeError):
|
||||
class MemoryAwareModel:
|
||||
"""Wraps a model to allow dynamic loading/unloading as needed."""
|
||||
|
||||
def __init__(self, config_path, weights_path, half_mode=None):
|
||||
def __init__(self, config_path, weights_path, half_mode=None, for_training=False):
|
||||
self._config_path = config_path
|
||||
self._weights_path = weights_path
|
||||
self._half_mode = half_mode
|
||||
self._model = None
|
||||
self._for_training = for_training
|
||||
|
||||
LOADED_MODELS[(self._config_path, self._weights_path)] = self
|
||||
|
||||
@ -47,9 +51,13 @@ class MemoryAwareModel:
|
||||
# unload all models in LOADED_MODELS
|
||||
for model in LOADED_MODELS.values():
|
||||
model.unload_model()
|
||||
model_config = OmegaConf.load(f"{PKG_ROOT}/{self._config_path}")
|
||||
if self._for_training:
|
||||
model_config.use_ema = True
|
||||
# model_config.use_scheduler = True
|
||||
|
||||
model = load_model_from_config(
|
||||
config=OmegaConf.load(f"{PKG_ROOT}/{self._config_path}"),
|
||||
config=model_config,
|
||||
weights_location=self._weights_path,
|
||||
)
|
||||
|
||||
@ -75,7 +83,7 @@ class MemoryAwareModel:
|
||||
|
||||
def load_model_from_config(config, weights_location):
|
||||
if weights_location.startswith("http"):
|
||||
ckpt_path = get_cached_url_path(weights_location)
|
||||
ckpt_path = get_cached_url_path(weights_location, category="weights")
|
||||
else:
|
||||
ckpt_path = weights_location
|
||||
logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
|
||||
@ -94,7 +102,7 @@ def load_model_from_config(config, weights_location):
|
||||
if weights_location.startswith("http"):
|
||||
logger.warning("Corrupt checkpoint. deleting and re-downloading...")
|
||||
os.remove(ckpt_path)
|
||||
ckpt_path = get_cached_url_path(weights_location)
|
||||
ckpt_path = get_cached_url_path(weights_location, category="weights")
|
||||
pl_sd = torch.load(ckpt_path, map_location="cpu")
|
||||
if pl_sd is None:
|
||||
raise e
|
||||
@ -119,6 +127,7 @@ def get_diffusion_model(
|
||||
config_path="configs/stable-diffusion-v1.yaml",
|
||||
half_mode=None,
|
||||
for_inpainting=False,
|
||||
for_training=False,
|
||||
):
|
||||
"""
|
||||
Load a diffusion model.
|
||||
@ -127,7 +136,11 @@ def get_diffusion_model(
|
||||
"""
|
||||
try:
|
||||
return _get_diffusion_model(
|
||||
weights_location, config_path, half_mode, for_inpainting
|
||||
weights_location,
|
||||
config_path,
|
||||
half_mode,
|
||||
for_inpainting,
|
||||
for_training=for_training,
|
||||
)
|
||||
except HuggingFaceAuthorizationError as e:
|
||||
if for_inpainting:
|
||||
@ -135,7 +148,11 @@ def get_diffusion_model(
|
||||
f"Failed to load inpainting model. Attempting to fall-back to standard model. {str(e)}"
|
||||
)
|
||||
return _get_diffusion_model(
|
||||
iconfig.DEFAULT_MODEL, config_path, half_mode, for_inpainting=False
|
||||
iconfig.DEFAULT_MODEL,
|
||||
config_path,
|
||||
half_mode,
|
||||
for_inpainting=False,
|
||||
for_training=for_training,
|
||||
)
|
||||
raise e
|
||||
|
||||
@ -145,6 +162,7 @@ def _get_diffusion_model(
|
||||
config_path="configs/stable-diffusion-v1.yaml",
|
||||
half_mode=None,
|
||||
for_inpainting=False,
|
||||
for_training=False,
|
||||
):
|
||||
"""
|
||||
Load a diffusion model.
|
||||
@ -152,25 +170,12 @@ def _get_diffusion_model(
|
||||
Weights location may also be shortcut name, e.g. "SD-1.5"
|
||||
"""
|
||||
global MOST_RECENTLY_LOADED_MODEL # noqa
|
||||
model_config = None
|
||||
if weights_location is None:
|
||||
weights_location = iconfig.DEFAULT_MODEL
|
||||
if (
|
||||
for_inpainting
|
||||
and f"{weights_location}-inpaint" in iconfig.MODEL_CONFIG_SHORTCUTS
|
||||
):
|
||||
model_config = iconfig.MODEL_CONFIG_SHORTCUTS[f"{weights_location}-inpaint"]
|
||||
config_path, weights_location = (
|
||||
model_config.config_path,
|
||||
model_config.weights_url,
|
||||
)
|
||||
elif weights_location in iconfig.MODEL_CONFIG_SHORTCUTS:
|
||||
model_config = iconfig.MODEL_CONFIG_SHORTCUTS[weights_location]
|
||||
config_path, weights_location = (
|
||||
model_config.config_path,
|
||||
model_config.weights_url,
|
||||
)
|
||||
|
||||
model_config, weights_location, config_path = resolve_model_paths(
|
||||
weights_path=weights_location,
|
||||
config_path=config_path,
|
||||
for_inpainting=for_inpainting,
|
||||
for_training=for_training,
|
||||
)
|
||||
# some models need the attention calculated in float32
|
||||
if model_config is not None:
|
||||
attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision
|
||||
@ -180,7 +185,10 @@ def _get_diffusion_model(
|
||||
key = (config_path, weights_location)
|
||||
if key not in LOADED_MODELS:
|
||||
MemoryAwareModel(
|
||||
config_path=config_path, weights_path=weights_location, half_mode=half_mode
|
||||
config_path=config_path,
|
||||
weights_path=weights_location,
|
||||
half_mode=half_mode,
|
||||
for_training=for_training,
|
||||
)
|
||||
|
||||
model = LOADED_MODELS[key]
|
||||
@ -190,6 +198,41 @@ def _get_diffusion_model(
|
||||
return model
|
||||
|
||||
|
||||
def resolve_model_paths(
|
||||
weights_path=iconfig.DEFAULT_MODEL,
|
||||
config_path=None,
|
||||
for_inpainting=False,
|
||||
for_training=False,
|
||||
):
|
||||
"""Resolve weight and config path if they happen to be shortcuts."""
|
||||
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_path, None)
|
||||
model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(config_path, None)
|
||||
if for_inpainting:
|
||||
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(
|
||||
f"{weights_path}-inpaint", model_metadata_w
|
||||
)
|
||||
model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(
|
||||
f"{config_path}-inpaint", model_metadata_c
|
||||
)
|
||||
|
||||
if model_metadata_w:
|
||||
if config_path is None:
|
||||
config_path = model_metadata_w.config_path
|
||||
if for_training:
|
||||
weights_path = model_metadata_w.weights_url_full
|
||||
if weights_path is None:
|
||||
raise ValueError(
|
||||
"No full training weights configured for this model. Edit the code or subimt a github issue."
|
||||
)
|
||||
else:
|
||||
weights_path = model_metadata_w.weights_url
|
||||
|
||||
if model_metadata_c:
|
||||
config_path = model_metadata_c.config_path
|
||||
model_metadata = model_metadata_w or model_metadata_c
|
||||
return model_metadata, weights_path, config_path
|
||||
|
||||
|
||||
def get_model_default_image_size(weights_location):
|
||||
model_config = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_location, None)
|
||||
if model_config:
|
||||
@ -209,12 +252,12 @@ def get_cache_dir():
|
||||
xdg_cache_home = os.path.join(user_home, ".cache")
|
||||
|
||||
if xdg_cache_home is not None:
|
||||
return os.path.join(xdg_cache_home, "imaginairy", "weights")
|
||||
return os.path.join(xdg_cache_home, "imaginairy")
|
||||
|
||||
return os.path.join(os.path.dirname(__file__), ".cached-downloads")
|
||||
return os.path.join(os.path.dirname(__file__), ".cached-aimg")
|
||||
|
||||
|
||||
def get_cached_url_path(url):
|
||||
def get_cached_url_path(url, category=None):
|
||||
"""
|
||||
Gets the contents of a url, but caches the response indefinitely.
|
||||
|
||||
@ -231,6 +274,8 @@ def get_cached_url_path(url):
|
||||
pass
|
||||
filename = url.split("/")[-1]
|
||||
dest = get_cache_dir()
|
||||
if category:
|
||||
dest = os.path.join(dest, category)
|
||||
os.makedirs(dest, exist_ok=True)
|
||||
dest_path = os.path.join(dest, filename)
|
||||
if os.path.exists(dest_path):
|
||||
@ -260,6 +305,20 @@ def check_huggingface_url_authorized(url):
|
||||
return None
|
||||
|
||||
|
||||
@wraps(_hf_hub_download)
|
||||
def hf_hub_download(*args, **kwargs):
|
||||
"""
|
||||
backwards compatible wrapper for huggingface's hf_hub_download.
|
||||
|
||||
they changed the argument name from `use_auth_token` to `token`
|
||||
"""
|
||||
arg_names = inspect.getfullargspec(_hf_hub_download)
|
||||
if "use_auth_token" in arg_names.args and "token" in kwargs:
|
||||
kwargs["use_auth_token"] = kwargs.pop("token")
|
||||
|
||||
return _hf_hub_download(*args, **kwargs)
|
||||
|
||||
|
||||
def huggingface_cached_path(url):
|
||||
# bypass all the HEAD calls done by the default `cached_path`
|
||||
repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url)
|
||||
|
@ -18,6 +18,7 @@ except ImportError:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
|
||||
|
||||
ALLOW_SPLITMEM = True
|
||||
ATTENTION_PRECISION_OVERRIDE = "default"
|
||||
|
||||
|
||||
@ -174,10 +175,11 @@ class CrossAttention(nn.Module):
|
||||
# mask = _global_mask_hack.to(torch.bool)
|
||||
|
||||
if get_device() == "cuda" or "mps" in get_device():
|
||||
if not XFORMERS_IS_AVAILBLE:
|
||||
if not XFORMERS_IS_AVAILBLE and ALLOW_SPLITMEM:
|
||||
return self.forward_splitmem(x, context=context, mask=mask)
|
||||
|
||||
h = self.heads
|
||||
# print(x.shape)
|
||||
|
||||
q = self.to_q(x)
|
||||
context = context if context is not None else x
|
||||
@ -193,7 +195,8 @@ class CrossAttention(nn.Module):
|
||||
sim = einsum("b i d, b j d -> b i j", q, k)
|
||||
else:
|
||||
sim = einsum("b i d, b j d -> b i j", q, k)
|
||||
|
||||
# print(sim.shape)
|
||||
# print("*" * 100)
|
||||
del q, k
|
||||
# if mask is not None:
|
||||
# if sim.shape[2] == 320 and False:
|
||||
|
@ -14,12 +14,17 @@ import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import ListConfig
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.modules.utils import _pair
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torchvision.utils import make_grid
|
||||
from tqdm import tqdm
|
||||
|
||||
from imaginairy.modules.attention import CrossAttention
|
||||
from imaginairy.modules.autoencoder import AutoencoderKL, IdentityFirstStage
|
||||
from imaginairy.modules.diffusion.util import (
|
||||
extract_into_tensor,
|
||||
make_beta_schedule,
|
||||
@ -27,12 +32,39 @@ from imaginairy.modules.diffusion.util import (
|
||||
)
|
||||
from imaginairy.modules.distributions import DiagonalGaussianDistribution
|
||||
from imaginairy.modules.ema import LitEma
|
||||
from imaginairy.samplers.kdiff import DPMPP2MSampler
|
||||
from imaginairy.utils import instantiate_from_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"}
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = []
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
lines = "\n".join(
|
||||
xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
|
||||
)
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def disabled_train(self):
|
||||
"""
|
||||
Overwrite model.train with this function to make sure train/eval mode
|
||||
@ -605,7 +637,9 @@ class DDPM(pl.LightningModule):
|
||||
return denoise_grid
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
||||
def log_images(
|
||||
self, batch, N=8, n_row=2, *, sample=True, return_keys=None, **kwargs
|
||||
):
|
||||
log = {}
|
||||
x = self.get_input(batch, self.first_stage_key)
|
||||
N = min(x.shape[0], N)
|
||||
@ -678,6 +712,7 @@ class LatentDiffusion(DDPM):
|
||||
conditioning_key=None,
|
||||
scale_factor=1.0,
|
||||
scale_by_std=False,
|
||||
unet_trainable=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_timesteps_cond = (
|
||||
@ -695,6 +730,7 @@ class LatentDiffusion(DDPM):
|
||||
super().__init__(conditioning_key=conditioning_key, **kwargs)
|
||||
self.concat_mode = concat_mode
|
||||
self.cond_stage_trainable = cond_stage_trainable
|
||||
self.unet_trainable = unet_trainable
|
||||
self.cond_stage_key = cond_stage_key
|
||||
try:
|
||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||
@ -726,6 +762,7 @@ class LatentDiffusion(DDPM):
|
||||
m._conv_forward = _TileModeConv2DConvForward.__get__( # noqa
|
||||
m, nn.Conv2d
|
||||
)
|
||||
self.tile_mode(tile_mode=False)
|
||||
|
||||
def tile_mode(self, tile_mode):
|
||||
"""For creating seamless tiles."""
|
||||
@ -1005,7 +1042,7 @@ class LatentDiffusion(DDPM):
|
||||
if cond_key is None:
|
||||
cond_key = self.cond_stage_key
|
||||
if cond_key != self.first_stage_key:
|
||||
if cond_key in ["caption", "coordinates_bbox"]:
|
||||
if cond_key in ["caption", "coordinates_bbox", "txt"]:
|
||||
xc = batch[cond_key]
|
||||
elif cond_key == "class_label":
|
||||
xc = batch
|
||||
@ -1100,6 +1137,24 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
return self.first_stage_model.encode(x)
|
||||
|
||||
def shared_step(self, batch, **kwargs):
|
||||
x, c = self.get_input(batch, self.first_stage_key)
|
||||
loss = self(x, c)
|
||||
return loss
|
||||
|
||||
def forward(self, x, c, *args, **kwargs):
|
||||
t = torch.randint(
|
||||
0, self.num_timesteps, (x.shape[0],), device=self.device
|
||||
).long()
|
||||
if self.model.conditioning_key is not None:
|
||||
assert c is not None
|
||||
if self.cond_stage_trainable:
|
||||
c = self.get_learned_conditioning(c)
|
||||
if self.shorten_cond_schedule: # TODO: drop this option
|
||||
tc = self.cond_ids[t].to(self.device)
|
||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||
return self.p_losses(x, c, t, *args, **kwargs)
|
||||
|
||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||
|
||||
if isinstance(cond, dict):
|
||||
@ -1244,6 +1299,42 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
return x_recon
|
||||
|
||||
def p_losses(self, x_start, cond, t, noise=None): # noqa
|
||||
noise = noise if noise is not None else torch.randn_like(x_start)
|
||||
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
||||
model_output = self.apply_model(x_noisy, t, cond)
|
||||
|
||||
loss_dict = {}
|
||||
prefix = "train" if self.training else "val"
|
||||
|
||||
if self.parameterization == "x0":
|
||||
target = x_start
|
||||
elif self.parameterization == "eps":
|
||||
target = noise
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
|
||||
loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()})
|
||||
|
||||
# t sometimes on wrong device. not sure why
|
||||
logvar_t = self.logvar[t.to(self.logvar.device)].to(self.device)
|
||||
loss = loss_simple / torch.exp(logvar_t) + logvar_t
|
||||
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
|
||||
if self.learn_logvar:
|
||||
loss_dict.update({f"{prefix}/loss_gamma": loss.mean()})
|
||||
loss_dict.update({"logvar": self.logvar.data.mean()})
|
||||
|
||||
loss = self.l_simple_weight * loss.mean()
|
||||
|
||||
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
|
||||
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
|
||||
loss_dict.update({f"{prefix}/loss_vlb": loss_vlb})
|
||||
loss += self.original_elbo_weight * loss_vlb
|
||||
loss_dict.update({f"{prefix}/loss": loss})
|
||||
|
||||
return loss, loss_dict
|
||||
|
||||
def p_mean_variance(
|
||||
self,
|
||||
x,
|
||||
@ -1346,6 +1437,287 @@ class LatentDiffusion(DDPM):
|
||||
* noise
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
||||
sampler = DPMPP2MSampler(self)
|
||||
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
||||
uncond = kwargs.get("unconditional_conditioning")
|
||||
if uncond is None:
|
||||
uncond = self.get_unconditional_conditioning(batch_size, "")
|
||||
|
||||
positive_conditioning = {
|
||||
"c_concat": [],
|
||||
"c_crossattn": [cond],
|
||||
}
|
||||
neutral_conditioning = {
|
||||
"c_concat": [],
|
||||
"c_crossattn": [uncond],
|
||||
}
|
||||
samples = sampler.sample(
|
||||
num_steps=ddim_steps,
|
||||
positive_conditioning=positive_conditioning,
|
||||
neutral_conditioning=neutral_conditioning,
|
||||
guidance_scale=kwargs.get("unconditional_guidance_scale", 5.0),
|
||||
shape=shape,
|
||||
batch_size=1,
|
||||
)
|
||||
|
||||
return samples, []
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, batch_size, null_label=None):
|
||||
if null_label is not None:
|
||||
xc = null_label
|
||||
if isinstance(xc, ListConfig):
|
||||
xc = list(xc)
|
||||
if isinstance(xc, (dict, list)):
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
if hasattr(xc, "to"):
|
||||
xc = xc.to(self.device)
|
||||
c = self.get_learned_conditioning(xc)
|
||||
else:
|
||||
# todo: get null label from cond_stage_model
|
||||
raise NotImplementedError()
|
||||
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
|
||||
return c
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(
|
||||
self,
|
||||
batch,
|
||||
N=8,
|
||||
n_row=4,
|
||||
sample=True,
|
||||
ddim_steps=50,
|
||||
ddim_eta=1.0,
|
||||
return_keys=None,
|
||||
quantize_denoised=True,
|
||||
inpaint=True,
|
||||
plot_denoise_rows=False,
|
||||
plot_progressive_rows=True,
|
||||
plot_diffusion_rows=True,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_guidance_label=None,
|
||||
use_ema_scope=True,
|
||||
**kwargs,
|
||||
):
|
||||
ema_scope = self.ema_scope if use_ema_scope else nullcontext
|
||||
use_ddim = ddim_steps is not None
|
||||
|
||||
log = {}
|
||||
z, c, x, xrec, xc = self.get_input(
|
||||
batch,
|
||||
self.first_stage_key,
|
||||
return_first_stage_outputs=True,
|
||||
force_c_encode=True,
|
||||
return_original_cond=True,
|
||||
bs=N,
|
||||
)
|
||||
N = min(x.shape[0], N)
|
||||
n_row = min(x.shape[0], n_row)
|
||||
log["inputs"] = x
|
||||
log["reconstruction"] = xrec
|
||||
if self.model.conditioning_key is not None:
|
||||
if hasattr(self.cond_stage_model, "decode"):
|
||||
xc = self.cond_stage_model.decode(c)
|
||||
log["conditioning"] = xc
|
||||
elif self.cond_stage_key in ["caption", "txt"]:
|
||||
xc = log_txt_as_img(
|
||||
(x.shape[2], x.shape[3]),
|
||||
batch[self.cond_stage_key],
|
||||
size=x.shape[2] // 25,
|
||||
)
|
||||
log["conditioning"] = xc
|
||||
elif self.cond_stage_key == "class_label":
|
||||
# xc = log_txt_as_img(
|
||||
# (x.shape[2], x.shape[3]),
|
||||
# batch["human_label"],
|
||||
# size=x.shape[2] // 25,
|
||||
# )
|
||||
log["conditioning"] = xc
|
||||
elif isimage(xc):
|
||||
log["conditioning"] = xc
|
||||
if ismap(xc):
|
||||
log["original_conditioning"] = self.to_rgb(xc)
|
||||
|
||||
if plot_diffusion_rows:
|
||||
# get diffusion row
|
||||
diffusion_row = []
|
||||
z_start = z[:n_row]
|
||||
for t in range(self.num_timesteps):
|
||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||
t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
|
||||
t = t.to(self.device).long()
|
||||
noise = torch.randn_like(z_start)
|
||||
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
||||
diffusion_row.append(self.decode_first_stage(z_noisy))
|
||||
|
||||
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
||||
diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
|
||||
diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
|
||||
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
||||
log["diffusion_row"] = diffusion_grid
|
||||
|
||||
if sample:
|
||||
# get denoise row
|
||||
with ema_scope("Sampling"):
|
||||
samples, z_denoise_row = self.sample_log(
|
||||
cond=c,
|
||||
batch_size=N,
|
||||
ddim=use_ddim,
|
||||
ddim_steps=ddim_steps,
|
||||
eta=ddim_eta,
|
||||
)
|
||||
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
|
||||
x_samples = self.decode_first_stage(samples)
|
||||
log["samples"] = x_samples
|
||||
if plot_denoise_rows:
|
||||
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
||||
log["denoise_row"] = denoise_grid
|
||||
|
||||
if (
|
||||
quantize_denoised
|
||||
and not isinstance(self.first_stage_model, AutoencoderKL)
|
||||
and not isinstance(self.first_stage_model, IdentityFirstStage)
|
||||
):
|
||||
# also display when quantizing x0 while sampling
|
||||
with ema_scope("Plotting Quantized Denoised"):
|
||||
samples, z_denoise_row = self.sample_log(
|
||||
cond=c,
|
||||
batch_size=N,
|
||||
ddim=use_ddim,
|
||||
ddim_steps=ddim_steps,
|
||||
eta=ddim_eta,
|
||||
quantize_denoised=True,
|
||||
)
|
||||
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
|
||||
# quantize_denoised=True)
|
||||
x_samples = self.decode_first_stage(samples.to(self.device))
|
||||
log["samples_x0_quantized"] = x_samples
|
||||
|
||||
if unconditional_guidance_scale > 1.0:
|
||||
uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
|
||||
# uc = torch.zeros_like(c)
|
||||
with ema_scope("Sampling with classifier-free guidance"):
|
||||
samples_cfg, _ = self.sample_log(
|
||||
cond=c,
|
||||
batch_size=N,
|
||||
ddim=use_ddim,
|
||||
ddim_steps=ddim_steps,
|
||||
eta=ddim_eta,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc,
|
||||
)
|
||||
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
||||
log[
|
||||
f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"
|
||||
] = x_samples_cfg
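# Note: the block above samples with classifier-free guidance. As an
# illustrative aside (a sketch of the standard formula, not necessarily this
# file's exact code path), the guided prediction is combined as:
#     eps = eps_uncond + unconditional_guidance_scale * (eps_cond - eps_uncond)
# so a scale of 1.0 reduces to the conditional prediction and larger values
# push samples harder toward the prompt.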
|
||||
|
||||
if inpaint:
|
||||
# make a simple center square
|
||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
||||
mask = torch.ones(N, h, w).to(self.device)
|
||||
# zeros will be filled in
|
||||
mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0
|
||||
mask = mask[:, None, ...]
|
||||
with ema_scope("Plotting Inpaint"):
|
||||
samples, _ = self.sample_log(
|
||||
cond=c,
|
||||
batch_size=N,
|
||||
ddim=use_ddim,
|
||||
eta=ddim_eta,
|
||||
ddim_steps=ddim_steps,
|
||||
x0=z[:N],
|
||||
mask=mask,
|
||||
)
|
||||
x_samples = self.decode_first_stage(samples.to(self.device))
|
||||
log["samples_inpainting"] = x_samples
|
||||
log["mask"] = mask
|
||||
|
||||
# outpaint
|
||||
mask = 1.0 - mask
|
||||
with ema_scope("Plotting Outpaint"):
|
||||
samples, _ = self.sample_log(
|
||||
cond=c,
|
||||
batch_size=N,
|
||||
ddim=use_ddim,
|
||||
eta=ddim_eta,
|
||||
ddim_steps=ddim_steps,
|
||||
x0=z[:N],
|
||||
mask=mask,
|
||||
)
|
||||
x_samples = self.decode_first_stage(samples.to(self.device))
|
||||
log["samples_outpainting"] = x_samples
|
||||
|
||||
if plot_progressive_rows:
|
||||
with ema_scope("Plotting Progressives"):
|
||||
img, progressives = self.progressive_denoising(
|
||||
c,
|
||||
shape=(self.channels, self.image_size, self.image_size),
|
||||
batch_size=N,
|
||||
)
|
||||
prog_row = self._get_denoise_row_from_list(
|
||||
progressives, desc="Progressive Generation"
|
||||
)
|
||||
log["progressive_row"] = prog_row
|
||||
|
||||
if return_keys:
|
||||
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
|
||||
return log
|
||||
|
||||
return {key: log[key] for key in return_keys}
|
||||
return log
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
params = []
|
||||
if self.unet_trainable == "attn":
|
||||
logger.info("Training only unet attention layers")
|
||||
for n, m in self.model.named_modules():
|
||||
if isinstance(m, CrossAttention) and n.endswith("attn2"):
|
||||
params.extend(m.parameters())
|
||||
elif self.unet_trainable is True or self.unet_trainable == "all":
|
||||
logger.info("Training the full unet")
|
||||
params = list(self.model.parameters())
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognised setting for unet_trainable: {self.unet_trainable}"
|
||||
)
|
||||
|
||||
if self.cond_stage_trainable:
|
||||
logger.info(
|
||||
f"{self.__class__.__name__}: Also optimizing conditioner params!"
|
||||
)
|
||||
params = params + list(self.cond_stage_model.parameters())
|
||||
if self.learn_logvar:
|
||||
logger.info("Diffusion model optimizing logvar")
|
||||
params.append(self.logvar)
|
||||
opt = torch.optim.AdamW(params, lr=lr)
|
||||
if self.use_scheduler:
|
||||
assert "target" in self.scheduler_config
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
logger.info("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
||||
"interval": "step",
|
||||
"frequency": 1,
|
||||
}
|
||||
]
|
||||
return [opt], scheduler
|
||||
return opt
|
||||
|
||||
@torch.no_grad()
|
||||
def to_rgb(self, x):
|
||||
x = x.float()
|
||||
if not hasattr(self, "colorize"):
|
||||
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) # noqa
|
||||
x = nn.functional.conv2d(x, weight=self.colorize)
|
||||
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
||||
return x
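# Note: to_rgb above projects an arbitrary-channel conditioning map down to
# three channels with a fixed random 1x1 convolution (created lazily on the
# first call) and rescales the result to [-1, 1] so it can be logged as an image.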
|
||||
|
||||
|
||||
class DiffusionWrapper(pl.LightningModule):
|
||||
def __init__(self, diff_model_config, conditioning_key):
|
||||
|
@ -12,7 +12,7 @@ class BaseModel(torch.nn.Module):  # noqa
        Args:
            path (str): file path
        """
        ckpt_path = get_cached_url_path(config.midas_url)
        ckpt_path = get_cached_url_path(config.midas_url, category="weights")
        parameters = torch.load(ckpt_path, map_location=torch.device("cpu"))

        if "optimizer" in parameters:
103
imaginairy/roi_utils.py
Normal file
@ -0,0 +1,103 @@
import logging

logger = logging.getLogger(__name__)


def square_roi_coordinate(roi, max_width, max_height, best_effort=False):
    """Given a region of interest, returns a square region of interest."""
    x1, y1, x2, y2 = roi
    x1, y1, x2, y2 = int(round(x1)), int(round(y1)), int(round(x2)), int(round(y2))
    roi_width = x2 - x1
    roi_height = y2 - y1
    if roi_width < roi_height:
        diff = roi_height - roi_width
        x1 -= int(round(diff / 2))
        x2 += roi_height - (x2 - x1)
    elif roi_height < roi_width:
        diff = roi_width - roi_height
        y1 -= int(round(diff / 2))
        y2 += roi_width - (y2 - y1)

    x1, y1, x2, y2 = move_roi_into_bounds(
        (x1, y1, x2, y2), max_width, max_height, force=best_effort
    )
    width = x2 - x1
    height = y2 - y1
    if not best_effort and width != height:
        raise RuntimeError(f"ROI is not square: {width}x{height}")
    return x1, y1, x2, y2


def resize_roi_coordinates(
    roi, expansion_factor, max_width, max_height, expand_up=True
):
    """
    Resize a region of interest while staying within the bounds.

    setting expand_up to False will prevent the ROI from expanding upwards, which is useful when
    expanding something like a face roi to capture more of the person instead of empty space above them.

    """
    x1, y1, x2, y2 = roi
    side_length_x = x2 - x1
    side_length_y = y2 - y1

    max_expansion_factor = min(max_height / side_length_y, max_width / side_length_x)
    expansion_factor = min(expansion_factor, max_expansion_factor)

    expansion_x = int(round(side_length_x * expansion_factor - side_length_x))
    expansion_x_a = int(round(expansion_x / 2))
    expansion_x_b = expansion_x - expansion_x_a
    x1 -= expansion_x_a
    x2 += expansion_x_b

    expansion_y = int(round(side_length_y * expansion_factor - side_length_y))
    if expand_up:
        expansion_y_a = int(round(expansion_y / 2))
        expansion_y_b = expansion_y - expansion_y_a
        y1 -= expansion_y_a
        y2 += expansion_y_b
    else:
        y2 += expansion_y

    x1, y1, x2, y2 = move_roi_into_bounds((x1, y1, x2, y2), max_width, max_height)

    return x1, y1, x2, y2


class RoiNotInBoundsError(ValueError):
    """Error raised when a ROI is not within the bounds of the image."""


def move_roi_into_bounds(roi, max_width, max_height, force=False):
    """Move a region of interest into the bounds of the image."""
    x1, y1, x2, y2 = roi

    # move the ROI within the image boundaries
    if x1 < 0:
        x2 -= x1
        x1 = 0
    if y1 < 0:
        y2 -= y1
        y1 = 0
    if x2 > max_width:
        x1 -= x2 - max_width
        x2 = max_width
    if y2 > max_height:
        y1 -= y2 - max_height
        y2 = max_height
    x1, y1, x2, y2 = int(round(x1)), int(round(y1)), int(round(x2)), int(round(y2))
    # Force ROI to fit within image boundaries (sacrificing size and aspect ratio of ROI)
    if force:
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(max_width, x2)
        y2 = min(max_height, y2)
    if x1 < 0 or y1 < 0 or x2 > max_width or y2 > max_height:
        roi_width = x2 - x1
        roi_height = y2 - y1
        raise RoiNotInBoundsError(
            f"Not possible to fit ROI into boundaries: {roi_width}x{roi_height} won't fit inside {max_width}x{max_height}"
        )

    return x1, y1, x2, y2
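# Example usage (an illustrative sketch): square a detector ROI and then widen
# it while staying inside a hypothetical 640x480 image.
from imaginairy.roi_utils import resize_roi_coordinates, square_roi_coordinate

face_roi = (100, 120, 180, 260)  # x1, y1, x2, y2 from a hypothetical face detector
squared = square_roi_coordinate(face_roi, max_width=640, max_height=480)
# expand_up=False grows the crop downward, capturing more of the person
expanded = resize_roi_coordinates(squared, 1.5, max_width=640, max_height=480, expand_up=False)
print(squared, expanded)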
@ -47,6 +47,13 @@ class DDIMSampler(ImageSampler):
        t_start=None,
        quantize_x0=False,
    ):
        # print("Sampling with DDIM")
        # print("num_steps", num_steps)
        # print("shape", shape)
        # print("neutral_conditioning", neutral_conditioning)
        # print("positive_conditioning", positive_conditioning)
        # print("guidance_scale", guidance_scale)
        # print("batch_size", batch_size)
        schedule = NoiseSchedule(
            model_num_timesteps=self.model.num_timesteps,
            model_alphas_cumprod=self.model.alphas_cumprod,
@ -56,14 +56,14 @@ class LazyLoadingImage:

        if self._lazy_filepath:
            self._img = Image.open(self._lazy_filepath)
            logger.info(
            logger.debug(
                f"Loaded input 🖼 of size {self._img.size} from {self._lazy_filepath}"
            )
        elif self._lazy_url:
            self._img = Image.open(
                requests.get(self._lazy_url, stream=True, timeout=60).raw
            )
            logger.info(
            logger.debug(
                f"Loaded input 🖼 of size {self._img.size} from {self._lazy_url}"
            )
        # fix orientation
@ -113,6 +113,7 @@ class ImaginePrompt:
        conditioning=None,
        tile_mode="",
        model=config.DEFAULT_MODEL,
        model_config_path=None,
    ):

        self.prompts = self.process_prompt_input(prompt)
@ -165,6 +166,7 @@ class ImaginePrompt:

        if self.model == "SD-2.0-v" and self.sampler_type == SamplerName.PLMS:
            raise ValueError("PLMS sampler is not supported for SD-2.0-v model.")
        self.model_config_path = model_config_path

    @property
    def prompt_text(self):
520
imaginairy/train.py
Normal file
@ -0,0 +1,520 @@
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.callbacks import Callback, LearningRateMonitor
|
||||
from pytorch_lightning.strategies import DDPStrategy
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from imaginairy import config
|
||||
from imaginairy.model_manager import get_diffusion_model
|
||||
from imaginairy.training_tools.single_concept import SingleConceptDataset
|
||||
from imaginairy.utils import get_device, instantiate_from_config
|
||||
|
||||
mod_logger = logging.getLogger(__name__)
|
||||
|
||||
referenced_by_string = [LearningRateMonitor]
|
||||
|
||||
|
||||
class WrappedDataset(Dataset):
|
||||
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset."""
|
||||
|
||||
def __init__(self, dataset):
|
||||
self.data = dataset
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx]
|
||||
|
||||
|
||||
def worker_init_fn(_):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
|
||||
dataset = worker_info.dataset
|
||||
worker_id = worker_info.id
|
||||
|
||||
if isinstance(dataset, SingleConceptDataset):
|
||||
# split_size = dataset.num_records // worker_info.num_workers
|
||||
# reset num_records to the true number to retain reliable length information
|
||||
# dataset.sample_ids = dataset.valid_ids[
|
||||
# worker_id * split_size : (worker_id + 1) * split_size
|
||||
# ]
|
||||
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
|
||||
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
|
||||
|
||||
return np.random.seed(np.random.get_state()[1][0] + worker_id)
|
||||
|
||||
|
||||
class DataModuleFromConfig(pl.LightningDataModule):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size,
|
||||
train=None,
|
||||
validation=None,
|
||||
test=None,
|
||||
predict=None,
|
||||
wrap=False,
|
||||
num_workers=None,
|
||||
shuffle_test_loader=False,
|
||||
use_worker_init_fn=False,
|
||||
shuffle_val_dataloader=False,
|
||||
num_val_workers=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.batch_size = batch_size
|
||||
self.dataset_configs = {}
|
||||
self.num_workers = num_workers if num_workers is not None else batch_size * 2
|
||||
if num_val_workers is None:
|
||||
self.num_val_workers = self.num_workers
|
||||
else:
|
||||
self.num_val_workers = num_val_workers
|
||||
self.use_worker_init_fn = use_worker_init_fn
|
||||
if train is not None:
|
||||
self.dataset_configs["train"] = train
|
||||
self.train_dataloader = self._train_dataloader
|
||||
if validation is not None:
|
||||
self.dataset_configs["validation"] = validation
|
||||
self.val_dataloader = partial(
|
||||
self._val_dataloader, shuffle=shuffle_val_dataloader
|
||||
)
|
||||
if test is not None:
|
||||
self.dataset_configs["test"] = test
|
||||
self.test_dataloader = partial(
|
||||
self._test_dataloader, shuffle=shuffle_test_loader
|
||||
)
|
||||
if predict is not None:
|
||||
self.dataset_configs["predict"] = predict
|
||||
self.predict_dataloader = self._predict_dataloader
|
||||
self.wrap = wrap
|
||||
self.datasets = None
|
||||
|
||||
def prepare_data(self):
|
||||
for data_cfg in self.dataset_configs.values():
|
||||
instantiate_from_config(data_cfg)
|
||||
|
||||
def setup(self, stage=None):
|
||||
self.datasets = {
|
||||
k: instantiate_from_config(c) for k, c in self.dataset_configs.items()
|
||||
}
|
||||
if self.wrap:
|
||||
self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()}
|
||||
|
||||
def _train_dataloader(self):
|
||||
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
return DataLoader(
|
||||
self.datasets["train"],
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
shuffle=True,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
def _val_dataloader(self, shuffle=False):
|
||||
if (
|
||||
isinstance(self.datasets["validation"], SingleConceptDataset)
|
||||
or self.use_worker_init_fn
|
||||
):
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
return DataLoader(
|
||||
self.datasets["validation"],
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_val_workers,
|
||||
worker_init_fn=init_fn,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
def _test_dataloader(self, shuffle=False):
|
||||
is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset)
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
is_iterable_dataset = False
|
||||
|
||||
# do not shuffle dataloader for iterable dataset
|
||||
shuffle = shuffle and (not is_iterable_dataset)
|
||||
|
||||
return DataLoader(
|
||||
self.datasets["test"],
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
worker_init_fn=init_fn,
|
||||
shuffle=shuffle,
|
||||
)
|
||||
|
||||
def _predict_dataloader(self, shuffle=False):
|
||||
if (
|
||||
isinstance(self.datasets["predict"], SingleConceptDataset)
|
||||
or self.use_worker_init_fn
|
||||
):
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
return DataLoader(
|
||||
self.datasets["predict"],
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
worker_init_fn=init_fn,
|
||||
)
|
||||
|
||||
|
||||
class SetupCallback(Callback):
|
||||
def __init__(
|
||||
self,
|
||||
resume,
|
||||
now,
|
||||
logdir,
|
||||
ckptdir,
|
||||
cfgdir,
|
||||
):
|
||||
super().__init__()
|
||||
self.resume = resume
|
||||
self.now = now
|
||||
self.logdir = logdir
|
||||
self.ckptdir = ckptdir
|
||||
self.cfgdir = cfgdir
|
||||
|
||||
def on_keyboard_interrupt(self, trainer, pl_module):
|
||||
if trainer.global_rank == 0:
|
||||
mod_logger.info("Stopping execution and saving final checkpoint.")
|
||||
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
def on_fit_start(self, trainer, pl_module):
|
||||
if trainer.global_rank == 0:
|
||||
# Create logdirs and save configs
|
||||
os.makedirs(self.logdir, exist_ok=True)
|
||||
os.makedirs(self.ckptdir, exist_ok=True)
|
||||
os.makedirs(self.cfgdir, exist_ok=True)
|
||||
|
||||
else:
|
||||
# ModelCheckpoint callback created log directory --- remove it
|
||||
if not self.resume and os.path.exists(self.logdir):
|
||||
dst, name = os.path.split(self.logdir)
|
||||
dst = os.path.join(dst, "child_runs", name)
|
||||
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
||||
try:
|
||||
os.rename(self.logdir, dst)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
class ImageLogger(Callback):
|
||||
def __init__(
|
||||
self,
|
||||
batch_frequency,
|
||||
max_images,
|
||||
clamp=True,
|
||||
increase_log_steps=True,
|
||||
rescale=True,
|
||||
disabled=False,
|
||||
log_on_batch_idx=False,
|
||||
log_first_step=False,
|
||||
log_images_kwargs=None,
|
||||
log_all_val=False,
|
||||
concept_label=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.rescale = rescale
|
||||
self.batch_freq = batch_frequency
|
||||
self.max_images = max_images
|
||||
self.logger_log_images = {}
|
||||
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
|
||||
if not increase_log_steps:
|
||||
self.log_steps = [self.batch_freq]
|
||||
self.clamp = clamp
|
||||
self.disabled = disabled
|
||||
self.log_on_batch_idx = log_on_batch_idx
|
||||
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
||||
self.log_first_step = log_first_step
|
||||
self.log_all_val = log_all_val
|
||||
self.concept_label = concept_label
|
||||
|
||||
@rank_zero_only
|
||||
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
|
||||
root = os.path.join(save_dir, "logs", "images", split)
|
||||
for k in images:
|
||||
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
||||
if self.rescale:
|
||||
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
||||
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
||||
grid = grid.numpy()
|
||||
grid = (grid * 255).astype(np.uint8)
|
||||
filename = (
|
||||
f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png"
|
||||
)
|
||||
path = os.path.join(root, filename)
|
||||
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
||||
Image.fromarray(grid).save(path)
|
||||
|
||||
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
||||
# always generate the concept label
|
||||
batch["txt"][0] = self.concept_label
|
||||
|
||||
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
|
||||
if self.log_all_val and split == "val":
|
||||
should_log = True
|
||||
else:
|
||||
should_log = self.check_frequency(check_idx)
|
||||
if (
|
||||
should_log
|
||||
and (batch_idx % self.batch_freq == 0)
|
||||
and hasattr(pl_module, "log_images")
|
||||
and callable(pl_module.log_images)
|
||||
and self.max_images > 0
|
||||
):
|
||||
logger = type(pl_module.logger)
|
||||
|
||||
is_train = pl_module.training
|
||||
if is_train:
|
||||
pl_module.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
images = pl_module.log_images(
|
||||
batch, split=split, **self.log_images_kwargs
|
||||
)
|
||||
|
||||
for k in images:
|
||||
N = min(images[k].shape[0], self.max_images)
|
||||
images[k] = images[k][:N]
|
||||
if isinstance(images[k], torch.Tensor):
|
||||
images[k] = images[k].detach().cpu()
|
||||
if self.clamp:
|
||||
images[k] = torch.clamp(images[k], -1.0, 1.0)
|
||||
|
||||
self.log_local(
|
||||
pl_module.logger.save_dir,
|
||||
split,
|
||||
images,
|
||||
pl_module.global_step,
|
||||
pl_module.current_epoch,
|
||||
batch_idx,
|
||||
)
|
||||
|
||||
logger_log_images = self.logger_log_images.get(
|
||||
logger, lambda *args, **kwargs: None
|
||||
)
|
||||
logger_log_images(pl_module, images, pl_module.global_step, split)
|
||||
|
||||
if is_train:
|
||||
pl_module.train()
|
||||
|
||||
def check_frequency(self, check_idx):
|
||||
if (check_idx % self.batch_freq) == 0 and (
|
||||
check_idx > 0 or self.log_first_step
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
||||
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
|
||||
self.log_img(pl_module, batch, batch_idx, split="train")
|
||||
|
||||
def on_validation_batch_end(
|
||||
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
||||
):
|
||||
if not self.disabled and pl_module.global_step > 0:
|
||||
self.log_img(pl_module, batch, batch_idx, split="val")
|
||||
if hasattr(pl_module, "calibrate_grad_norm"):
|
||||
if (
|
||||
pl_module.calibrate_grad_norm and batch_idx % 25 == 0
|
||||
) and batch_idx > 0:
|
||||
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
||||
|
||||
|
||||
class CUDACallback(Callback):
|
||||
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
||||
def on_train_epoch_start(self, trainer, pl_module):
|
||||
# Reset the memory use counter
|
||||
if "cuda" in get_device():
|
||||
torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
|
||||
torch.cuda.synchronize(trainer.strategy.root_device.index)
|
||||
self.start_time = time.time() # noqa
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module): # noqa
|
||||
if "cuda" in get_device():
|
||||
torch.cuda.synchronize(trainer.strategy.root_device.index)
|
||||
max_memory = (
|
||||
torch.cuda.max_memory_allocated(trainer.strategy.root_device.index)
|
||||
/ 2**20
|
||||
)
|
||||
epoch_time = time.time() - self.start_time
|
||||
|
||||
try:
|
||||
max_memory = trainer.training_type_plugin.reduce(max_memory)
|
||||
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
|
||||
|
||||
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
|
||||
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def train_diffusion_model(
|
||||
concept_label,
|
||||
concept_images_dir,
|
||||
class_label,
|
||||
class_images_dir,
|
||||
weights_location=config.DEFAULT_MODEL,
|
||||
logdir="logs",
|
||||
learning_rate=1e-6,
|
||||
accumulate_grad_batches=32,
|
||||
resume=None,
|
||||
):
|
||||
batch_size = 1
|
||||
seed = 23
|
||||
num_workers = 1
|
||||
num_val_workers = 0
|
||||
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
||||
logdir = os.path.join(logdir, now)
|
||||
|
||||
ckpt_output_dir = os.path.join(logdir, "checkpoints")
|
||||
cfg_output_dir = os.path.join(logdir, "configs")
|
||||
seed_everything(seed)
|
||||
model = get_diffusion_model( # noqa
|
||||
weights_location=weights_location, half_mode=False, for_training=True
|
||||
)._model
|
||||
model.learning_rate = learning_rate * accumulate_grad_batches * batch_size
|
||||
|
||||
# add callback which sets up log directory
|
||||
default_callbacks_cfg = {
|
||||
"setup_callback": {
|
||||
"target": "imaginairy.train.SetupCallback",
|
||||
"params": {
|
||||
"resume": False,
|
||||
"now": now,
|
||||
"logdir": logdir,
|
||||
"ckptdir": ckpt_output_dir,
|
||||
"cfgdir": cfg_output_dir,
|
||||
},
|
||||
},
|
||||
"image_logger": {
|
||||
"target": "imaginairy.train.ImageLogger",
|
||||
"params": {
|
||||
"batch_frequency": 10,
|
||||
"max_images": 1,
|
||||
"clamp": True,
|
||||
"increase_log_steps": False,
|
||||
"log_first_step": True,
|
||||
"log_all_val": True,
|
||||
"concept_label": concept_label,
|
||||
"log_images_kwargs": {
|
||||
"use_ema_scope": True,
|
||||
"inpaint": False,
|
||||
"plot_progressive_rows": False,
|
||||
"plot_diffusion_rows": False,
|
||||
"N": 1,
|
||||
"unconditional_guidance_scale:": 7.5,
|
||||
"unconditional_guidance_label": [""],
|
||||
"ddim_steps": 20,
|
||||
},
|
||||
},
|
||||
},
|
||||
"learning_rate_logger": {
|
||||
"target": "imaginairy.train.LearningRateMonitor",
|
||||
"params": {
|
||||
"logging_interval": "step",
|
||||
# "log_momentum": True
|
||||
},
|
||||
},
|
||||
"cuda_callback": {"target": "imaginairy.train.CUDACallback"},
|
||||
}
|
||||
|
||||
default_modelckpt_cfg = {
|
||||
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
|
||||
"params": {
|
||||
"dirpath": ckpt_output_dir,
|
||||
"filename": "{epoch:06}",
|
||||
"verbose": True,
|
||||
"save_last": True,
|
||||
"every_n_train_steps": 50,
|
||||
"save_top_k": -1,
|
||||
"monitor": None,
|
||||
},
|
||||
}
|
||||
|
||||
modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg)
|
||||
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
|
||||
|
||||
callbacks_cfg = OmegaConf.create(default_callbacks_cfg)
|
||||
|
||||
dataset_config = {
|
||||
"concept_label": concept_label,
|
||||
"concept_images_dir": concept_images_dir,
|
||||
"class_label": class_label,
|
||||
"class_images_dir": class_images_dir,
|
||||
"image_transforms": [
|
||||
{
|
||||
"target": "torchvision.transforms.Resize",
|
||||
"params": {"size": 512, "interpolation": 3},
|
||||
},
|
||||
{"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
|
||||
],
|
||||
}
|
||||
|
||||
data_module_config = {
|
||||
"batch_size": batch_size,
|
||||
"num_workers": num_workers,
|
||||
"num_val_workers": num_val_workers,
|
||||
"train": {
|
||||
"target": "imaginairy.training_tools.single_concept.SingleConceptDataset",
|
||||
"params": dataset_config,
|
||||
},
|
||||
}
|
||||
trainer = Trainer(
|
||||
benchmark=True,
|
||||
num_sanity_val_steps=0,
|
||||
accumulate_grad_batches=accumulate_grad_batches,
|
||||
strategy=DDPStrategy(),
|
||||
callbacks=[
|
||||
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg # noqa
|
||||
],
|
||||
gpus=1,
|
||||
default_root_dir=".",
|
||||
)
|
||||
trainer.logdir = logdir
|
||||
|
||||
data = DataModuleFromConfig(**data_module_config)
|
||||
# NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
|
||||
# calling these ourselves should not be necessary but it is.
|
||||
# lightning still takes care of proper multiprocessing though
|
||||
data.prepare_data()
|
||||
data.setup()
|
||||
|
||||
def melk(*args, **kwargs):
|
||||
if trainer.global_rank == 0:
|
||||
mod_logger.info("Summoning checkpoint.")
|
||||
ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt")
|
||||
trainer.save_checkpoint(ckpt_path)
|
||||
|
||||
signal.signal(signal.SIGUSR1, melk)
|
||||
try:
|
||||
|
||||
try:
|
||||
trainer.fit(model, data)
|
||||
except Exception:
|
||||
melk()
|
||||
raise
|
||||
finally:
|
||||
mod_logger.info(trainer.profiler.summary())
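# Illustrative invocation (a sketch; the directory paths and labels below are
# hypothetical placeholders, and the remaining arguments use the defaults above).
from imaginairy.train import train_diffusion_model

train_diffusion_model(
    concept_label="john smith person",
    concept_images_dir="concept-images/prepped-images",
    class_label="person",
    class_images_dir="class-images",
    logdir="logs",
    learning_rate=1e-6,
    accumulate_grad_batches=32,
)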
|
0
imaginairy/training_tools/__init__.py
Normal file
155
imaginairy/training_tools/image_prep.py
Normal file
@ -0,0 +1,155 @@
|
||||
import logging
|
||||
import os
|
||||
import os.path
|
||||
import re
|
||||
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from imaginairy import ImaginePrompt, LazyLoadingImage, imagine
|
||||
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
|
||||
from imaginairy.enhancers.facecrop import detect_faces, generate_face_crops
|
||||
from imaginairy.enhancers.upscale_realesrgan import upscale_image
|
||||
from imaginairy.vendored.smart_crop import SmartCrop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_image_filenames(folder):
|
||||
filenames = []
|
||||
for filename in os.listdir(folder):
|
||||
if not filename.lower().endswith((".jpg", ".jpeg", ".png")):
|
||||
continue
|
||||
if filename.startswith("."):
|
||||
continue
|
||||
filenames.append(filename)
|
||||
return filenames
|
||||
|
||||
|
||||
def prep_images(
|
||||
images_dir, is_person=False, output_folder_name="prepped-images", target_size=512
|
||||
):
|
||||
"""
|
||||
Crops and resizes a directory of images in preparation for training.
|
||||
|
||||
If is_person=True, it will detect the face and produce several crops at different zoom levels. For crops that
|
||||
are too small, it will use the face restoration model to enhance the faces.
|
||||
|
||||
For non-person images, it will use the smartcrop algorithm to crop the image to the most interesting part. If the
|
||||
input image is too small it will be upscaled.
|
||||
|
||||
Prep will go a lot faster if all the images are big enough to not require upscaling.
|
||||
|
||||
"""
|
||||
output_folder = os.path.join(images_dir, output_folder_name)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
logger.info(f"Prepping images in {images_dir} to {output_folder}")
|
||||
image_filenames = get_image_filenames(images_dir)
|
||||
pbar = tqdm(image_filenames)
|
||||
for filename in pbar:
|
||||
pbar.set_description(filename)
|
||||
|
||||
input_path = os.path.join(images_dir, filename)
|
||||
img = LazyLoadingImage(filepath=input_path).convert("RGB")
|
||||
if is_person:
|
||||
face_rois = detect_faces(img)
|
||||
if len(face_rois) == 0:
|
||||
logger.info(f"No faces detected in image {filename}, skipping")
|
||||
continue
|
||||
if len(face_rois) > 1:
|
||||
logger.info(f"Multiple faces detected in image {filename}, skipping")
|
||||
continue
|
||||
face_roi = face_rois[0]
|
||||
face_roi_crops = generate_face_crops(
|
||||
face_roi, max_width=img.width, max_height=img.height
|
||||
)
|
||||
for n, face_roi_crop in enumerate(face_roi_crops):
|
||||
cropped_output_path = os.path.join(
|
||||
output_folder, f"{filename}_[alt-{n:02d}].jpg"
|
||||
)
|
||||
if os.path.exists(cropped_output_path):
|
||||
logger.debug(
|
||||
f"Skipping {cropped_output_path} because it already exists"
|
||||
)
|
||||
continue
|
||||
x1, y1, x2, y2 = face_roi_crop
|
||||
crop_width = x2 - x1
|
||||
crop_height = y2 - y1
|
||||
if crop_width != crop_height:
|
||||
logger.info(
|
||||
f"Face ROI crop for {filename} {crop_width}x{crop_height} is not square, skipping"
|
||||
)
|
||||
continue
|
||||
cropped_img = img.crop(face_roi_crop)
|
||||
|
||||
if crop_width < target_size:
|
||||
logger.info(f"Upscaling {filename} {face_roi_crop}")
|
||||
cropped_img = cropped_img.resize(
|
||||
(target_size, target_size), resample=Image.Resampling.LANCZOS
|
||||
)
|
||||
cropped_img = enhance_faces(cropped_img, fidelity=1)
|
||||
else:
|
||||
cropped_img = cropped_img.resize(
|
||||
(target_size, target_size), resample=Image.Resampling.LANCZOS
|
||||
)
|
||||
cropped_img.save(cropped_output_path, quality=95)
|
||||
else:
|
||||
# scale image so that largest dimension is target_size
|
||||
n = 0
|
||||
cropped_output_path = os.path.join(output_folder, f"{filename}_{n}.jpg")
|
||||
if os.path.exists(cropped_output_path):
|
||||
logger.debug(
|
||||
f"Skipping {cropped_output_path} because it already exists"
|
||||
)
|
||||
continue
|
||||
if img.width < target_size or img.height < target_size:
|
||||
# upscale the image if it's too small
|
||||
logger.info(f"Upscaling {filename}")
|
||||
img = upscale_image(img)
|
||||
|
||||
if img.width > img.height:
|
||||
scale_factor = target_size / img.height
|
||||
else:
|
||||
scale_factor = target_size / img.width
|
||||
|
||||
# downscale so shortest side is target_size
|
||||
new_width = int(round(img.width * scale_factor))
|
||||
new_height = int(round(img.height * scale_factor))
|
||||
img = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
|
||||
|
||||
result = SmartCrop().crop(img, width=target_size, height=target_size)
|
||||
|
||||
box = (
|
||||
result["top_crop"]["x"],
|
||||
result["top_crop"]["y"],
|
||||
result["top_crop"]["width"] + result["top_crop"]["x"],
|
||||
result["top_crop"]["height"] + result["top_crop"]["y"],
|
||||
)
|
||||
|
||||
cropped_image = img.crop(box)
|
||||
cropped_image.save(cropped_output_path, quality=95)
|
||||
logger.info(f"Image Prep complete. Review output at {output_folder}")
|
||||
|
||||
|
||||
def prompt_normalized(prompt):
|
||||
return re.sub(r"[^a-zA-Z0-9.,\[\]-]+", "-", prompt)[:130]
|
||||
|
||||
|
||||
def create_class_images(class_description, output_folder, num_images=200):
|
||||
"""
|
||||
Generate images of class_description.
|
||||
"""
|
||||
existing_images = get_image_filenames(output_folder)
|
||||
existing_image_count = len(existing_images)
|
||||
class_slug = prompt_normalized(class_description)
|
||||
|
||||
while existing_image_count < num_images:
|
||||
prompt = ImaginePrompt(class_description, steps=20)
|
||||
result = list(imagine([prompt]))[0]
|
||||
if result.is_nsfw:
|
||||
continue
|
||||
dest = os.path.join(
|
||||
output_folder, f"{existing_image_count:03d}_{class_slug}.jpg"
|
||||
)
|
||||
result.save(dest)
|
||||
existing_image_count += 1
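# Example usage (a sketch): prepare a folder of photos of a person and generate
# regularization images of the broader class. Paths and labels are hypothetical.
import os

from imaginairy.training_tools.image_prep import create_class_images, prep_images

prep_images(images_dir="my-photos", is_person=True)
os.makedirs("class-images", exist_ok=True)
create_class_images("a photo of a person", output_folder="class-images", num_images=50)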
|
31
imaginairy/training_tools/prune_model.py
Normal file
@ -0,0 +1,31 @@
import logging
import os

import torch

logger = logging.getLogger(__name__)


def prune_diffusion_ckpt(ckpt_path, dst_path=None):
    if dst_path is None:
        dst_path = f"{os.path.splitext(ckpt_path)[0]}-pruned.ckpt"

    data = torch.load(ckpt_path, map_location="cpu")

    new_data = prune_model_data(data)

    torch.save(new_data, dst_path)

    size_initial = os.path.getsize(ckpt_path)
    newsize = os.path.getsize(dst_path)
    msg = (
        f"New ckpt size: {newsize * 1e-9:.2f} GB. "
        f"Saved {(size_initial - newsize) * 1e-9:.2f} GB by removing optimizer states"
    )
    logger.info(msg)


def prune_model_data(data):
    skip_keys = {"optimizer_states"}
    new_data = {k: v for k, v in data.items() if k not in skip_keys}
    return new_data
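# Example usage (a sketch): strip optimizer state from a finetuned checkpoint
# to shrink it for distribution. The checkpoint path is a hypothetical placeholder.
from imaginairy.training_tools.prune_model import prune_diffusion_ckpt

prune_diffusion_ckpt("logs/2023-01-01T00-00-00/checkpoints/last.ckpt")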
135
imaginairy/training_tools/single_concept.py
Normal file
@ -0,0 +1,135 @@
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
|
||||
from einops import rearrange
|
||||
from omegaconf import ListConfig
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset, IterableDataset
|
||||
from torchvision.transforms import transforms
|
||||
|
||||
from imaginairy.utils import instantiate_from_config
|
||||
|
||||
|
||||
class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
"""
|
||||
Define an interface to make the IterableDatasets for text2img data chainable.
|
||||
"""
|
||||
|
||||
def __init__(self, num_records=0, valid_ids=None, size=256):
|
||||
super().__init__()
|
||||
self.num_records = num_records
|
||||
self.valid_ids = valid_ids
|
||||
self.sample_ids = valid_ids
|
||||
self.size = size
|
||||
|
||||
print(f"{self.__class__.__name__} dataset contains {self.__len__()} examples.")
|
||||
|
||||
def __len__(self):
|
||||
return self.num_records
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self):
|
||||
pass
|
||||
|
||||
|
||||
def _rearrange(x):
|
||||
return rearrange(x * 2.0 - 1.0, "c h w -> h w c")
|
||||
|
||||
|
||||
class SingleConceptDataset(Dataset):
|
||||
"""
|
||||
Dataset for finetuning a model on a single concept.
|
||||
|
||||
Similar to "dreambooth"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
concept_label,
|
||||
class_label,
|
||||
concept_images_dir,
|
||||
class_images_dir,
|
||||
image_transforms=None,
|
||||
):
|
||||
self.concept_label = concept_label
|
||||
self.class_label = class_label
|
||||
self.concept_images_dir = concept_images_dir
|
||||
self.class_images_dir = class_images_dir
|
||||
|
||||
if isinstance(image_transforms, (ListConfig, list)):
|
||||
image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
|
||||
image_transforms.extend(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(_rearrange),
|
||||
]
|
||||
)
|
||||
image_transforms = transforms.Compose(image_transforms)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
|
||||
self._concept_image_filename_groups = None
|
||||
self._class_image_filenames = None
|
||||
|
||||
def __len__(self):
|
||||
return len(self.concept_image_filename_groups) * 2
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx % 2:
|
||||
img_group = self._concept_image_filename_groups[int(idx / 2)]
|
||||
img_filename = random.choice(img_group)
|
||||
img_path = os.path.join(self.concept_images_dir, img_filename)
|
||||
|
||||
txt = self.concept_label
|
||||
else:
|
||||
img_path = os.path.join(
|
||||
self.class_images_dir, self.class_image_filenames[int(idx / 2)]
|
||||
)
|
||||
txt = self.class_label
|
||||
try:
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(f"Could not read image {img_path}") from e
|
||||
image = self.image_transforms(image)
|
||||
data = {"image": image, "txt": txt}
|
||||
return data
|
||||
|
||||
@property
|
||||
def concept_image_filename_groups(self):
|
||||
if self._concept_image_filename_groups is None:
|
||||
self._concept_image_filename_groups = _load_image_filenames_and_alts(
|
||||
self.concept_images_dir
|
||||
)
|
||||
return self._concept_image_filename_groups
|
||||
|
||||
@property
|
||||
def class_image_filenames(self):
|
||||
if self._class_image_filenames is None:
|
||||
self._class_image_filenames = _load_image_filenames(self.class_images_dir)
|
||||
return self._class_image_filenames
|
||||
|
||||
@property
|
||||
def num_records(self):
|
||||
return len(self)
|
||||
|
||||
|
||||
def _load_image_filenames_and_alts(img_dir, image_extensions=(".jpg", ".jpeg", ".png")):
|
||||
"""Loads images into groups (filenames tagged with `[alt-{n:02d}]` are grouped together)."""
|
||||
image_filenames = _load_image_filenames(img_dir, image_extensions)
|
||||
grouped_img_filenames = defaultdict(list)
|
||||
for filename in image_filenames:
|
||||
base_filename = re.sub(r"\[alt-\d*\]", "", filename)
|
||||
grouped_img_filenames[base_filename].append(filename)
|
||||
return list(grouped_img_filenames.values())
|
||||
|
||||
|
||||
def _load_image_filenames(img_dir, image_extensions=(".jpg", ".jpeg", ".png")):
|
||||
image_filenames = []
|
||||
for filename in os.listdir(img_dir):
|
||||
if filename.lower().endswith(image_extensions) and not filename.startswith("."):
|
||||
image_filenames.append(filename)
|
||||
random.shuffle(image_filenames)
|
||||
return image_filenames
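# Example (a sketch): build the dataset the trainer consumes, mirroring the
# transform config constructed in imaginairy/train.py. Directories and labels
# are hypothetical placeholders.
from imaginairy.training_tools.single_concept import SingleConceptDataset

dataset = SingleConceptDataset(
    concept_label="john smith person",
    class_label="person",
    concept_images_dir="concept-images/prepped-images",
    class_images_dir="class-images",
    image_transforms=[
        {"target": "torchvision.transforms.Resize", "params": {"size": 512, "interpolation": 3}},
        {"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}},
    ],
)
print(len(dataset))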
|
375
imaginairy/vendored/smart_crop.py
Normal file
@ -0,0 +1,375 @@
|
||||
"""
|
||||
Crops to the most interesting part of the image.
|
||||
|
||||
MIT License from https://github.com/smartcrop/smartcrop.py/commit/f5377045035abc7ae79d8d9ad40bbc7fce0f6ad7
|
||||
"""
|
||||
import math
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
from PIL.ImageFilter import Kernel
|
||||
|
||||
|
||||
def saturation(image):
|
||||
r, g, b = image.split()
|
||||
r, g, b = np.array(r), np.array(g), np.array(b)
|
||||
r, g, b = r.astype(float), g.astype(float), b.astype(float)
|
||||
maximum = np.maximum(np.maximum(r, g), b) # [0; 255]
|
||||
minimum = np.minimum(np.minimum(r, g), b) # [0; 255]
|
||||
s = (maximum + minimum) / 255 # [0.0; 1.0]
|
||||
d = (maximum - minimum) / 255 # [0.0; 1.0]
|
||||
d[maximum == minimum] = 0 # if maximum == minimum:
|
||||
s[maximum == minimum] = 1 # -> saturation = 0 / 1 = 0
|
||||
mask = s > 1
|
||||
s[mask] = 2 - d[mask]
|
||||
return d / s # [0.0; 1.0]
|
||||
|
||||
|
||||
def thirds(x):
|
||||
"""gets value in the range of [0, 1] where 0 is the center of the pictures
|
||||
returns weight of rule of thirds [0, 1]."""
|
||||
x = ((x + 2 / 3) % 2 * 0.5 - 0.5) * 16
|
||||
return max(1 - x * x, 0)
|
||||
|
||||
|
||||
class SmartCrop:
|
||||
DEFAULT_SKIN_COLOR = [0.78, 0.57, 0.44]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detail_weight=0.2,
|
||||
edge_radius=0.4,
|
||||
edge_weight=-20,
|
||||
outside_importance=-0.5,
|
||||
rule_of_thirds=True,
|
||||
saturation_bias=0.2,
|
||||
saturation_brightness_max=0.9,
|
||||
saturation_brightness_min=0.05,
|
||||
saturation_threshold=0.4,
|
||||
saturation_weight=0.3,
|
||||
score_down_sample=8,
|
||||
skin_bias=0.01,
|
||||
skin_brightness_max=1,
|
||||
skin_brightness_min=0.2,
|
||||
skin_color=None,
|
||||
skin_threshold=0.8,
|
||||
skin_weight=1.8,
|
||||
):
|
||||
self.detail_weight = detail_weight
|
||||
self.edge_radius = edge_radius
|
||||
self.edge_weight = edge_weight
|
||||
self.outside_importance = outside_importance
|
||||
self.rule_of_thirds = rule_of_thirds
|
||||
self.saturation_bias = saturation_bias
|
||||
self.saturation_brightness_max = saturation_brightness_max
|
||||
self.saturation_brightness_min = saturation_brightness_min
|
||||
self.saturation_threshold = saturation_threshold
|
||||
self.saturation_weight = saturation_weight
|
||||
self.score_down_sample = score_down_sample
|
||||
self.skin_bias = skin_bias
|
||||
self.skin_brightness_max = skin_brightness_max
|
||||
self.skin_brightness_min = skin_brightness_min
|
||||
self.skin_color = skin_color or self.DEFAULT_SKIN_COLOR
|
||||
self.skin_threshold = skin_threshold
|
||||
self.skin_weight = skin_weight
|
||||
|
||||
def analyse(
|
||||
self,
|
||||
image,
|
||||
crop_width,
|
||||
crop_height,
|
||||
max_scale=1,
|
||||
min_scale=0.9,
|
||||
scale_step=0.1,
|
||||
step=8,
|
||||
):
|
||||
"""
|
||||
Analyze image and return some suggestions of crops (coordinates).
|
||||
This implementation / algorithm is really slow for large images.
|
||||
Use `crop()` which is pre-scaling the image before analyzing it.
|
||||
"""
|
||||
cie_image = image.convert("L", (0.2126, 0.7152, 0.0722, 0))
|
||||
cie_array = np.array(cie_image) # [0; 255]
|
||||
|
||||
# R=skin G=edge B=saturation
|
||||
edge_image = self.detect_edge(cie_image)
|
||||
skin_image = self.detect_skin(cie_array, image)
|
||||
saturation_image = self.detect_saturation(cie_array, image)
|
||||
analyse_image = Image.merge("RGB", [skin_image, edge_image, saturation_image])
|
||||
|
||||
del edge_image
|
||||
del skin_image
|
||||
del saturation_image
|
||||
|
||||
score_image = analyse_image.copy()
|
||||
score_image.thumbnail(
|
||||
(
|
||||
int(math.ceil(image.size[0] / self.score_down_sample)),
|
||||
int(math.ceil(image.size[1] / self.score_down_sample)),
|
||||
),
|
||||
Image.ANTIALIAS,
|
||||
)
|
||||
|
||||
top_crop = None
|
||||
top_score = -sys.maxsize
|
||||
|
||||
crops = self.crops(
|
||||
image,
|
||||
crop_width,
|
||||
crop_height,
|
||||
max_scale=max_scale,
|
||||
min_scale=min_scale,
|
||||
scale_step=scale_step,
|
||||
step=step,
|
||||
)
|
||||
|
||||
for crop in crops:
|
||||
crop["score"] = self.score(score_image, crop)
|
||||
if crop["score"]["total"] > top_score:
|
||||
top_crop = crop
|
||||
top_score = crop["score"]["total"]
|
||||
|
||||
return {"analyse_image": analyse_image, "crops": crops, "top_crop": top_crop}
|
||||
|
||||
def crop(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
prescale=True,
|
||||
max_scale=1,
|
||||
min_scale=0.9,
|
||||
scale_step=0.1,
|
||||
step=8,
|
||||
):
|
||||
"""Not yet fully cleaned from https://github.com/hhatto/smartcrop.py."""
|
||||
scale = min(image.size[0] / width, image.size[1] / height)
|
||||
crop_width = int(math.floor(width * scale))
|
||||
crop_height = int(math.floor(height * scale))
|
||||
# img = 100x100, width = 95x95, scale = 100/95, 1/scale > min
|
||||
# don't set minscale smaller than 1/scale
|
||||
# -> don't pick crops that need upscaling
|
||||
min_scale = min(max_scale, max(1 / scale, min_scale))
|
||||
|
||||
prescale_size = 1
|
||||
if prescale:
|
||||
prescale_size = 1 / scale / min_scale
|
||||
if prescale_size < 1:
|
||||
image = image.copy()
|
||||
image.thumbnail(
|
||||
(
|
||||
int(image.size[0] * prescale_size),
|
||||
int(image.size[1] * prescale_size),
|
||||
),
|
||||
Image.ANTIALIAS,
|
||||
)
|
||||
crop_width = int(math.floor(crop_width * prescale_size))
|
||||
crop_height = int(math.floor(crop_height * prescale_size))
|
||||
else:
|
||||
prescale_size = 1
|
||||
|
||||
result = self.analyse(
|
||||
image,
|
||||
crop_width=crop_width,
|
||||
crop_height=crop_height,
|
||||
min_scale=min_scale,
|
||||
max_scale=max_scale,
|
||||
scale_step=scale_step,
|
||||
step=step,
|
||||
)
|
||||
|
||||
for i in range(len(result["crops"])):
|
||||
crop = result["crops"][i]
|
||||
crop["x"] = int(math.floor(crop["x"] / prescale_size))
|
||||
crop["y"] = int(math.floor(crop["y"] / prescale_size))
|
||||
crop["width"] = int(math.floor(crop["width"] / prescale_size))
|
||||
crop["height"] = int(math.floor(crop["height"] / prescale_size))
|
||||
result["crops"][i] = crop
|
||||
return result
|
||||
|
||||
def crops(
|
||||
self,
|
||||
image,
|
||||
crop_width,
|
||||
crop_height,
|
||||
max_scale=1,
|
||||
min_scale=0.9,
|
||||
scale_step=0.1,
|
||||
step=8,
|
||||
):
|
||||
image_width, image_height = image.size
|
||||
crops = []
|
||||
for scale in (
|
||||
i / 100
|
||||
for i in range(
|
||||
int(max_scale * 100),
|
||||
int((min_scale - scale_step) * 100),
|
||||
-int(scale_step * 100),
|
||||
)
|
||||
):
|
||||
for y in range(0, image_height, step):
|
||||
if not (y + crop_height * scale <= image_height):
|
||||
break
|
||||
for x in range(0, image_width, step):
|
||||
if not (x + crop_width * scale <= image_width):
|
||||
break
|
||||
crops.append(
|
||||
{
|
||||
"x": x,
|
||||
"y": y,
|
||||
"width": crop_width * scale,
|
||||
"height": crop_height * scale,
|
||||
}
|
||||
)
|
||||
if not crops:
|
||||
raise ValueError(locals())
|
||||
return crops
|
||||
|
||||
def debug_crop(self, analyse_image, crop):
|
||||
debug_image = analyse_image.copy()
|
||||
debug_pixels = debug_image.getdata()
|
||||
debug_crop_image = Image.new(
|
||||
"RGBA",
|
||||
(int(math.floor(crop["width"])), int(math.floor(crop["height"]))),
|
||||
(255, 0, 0, 25),
|
||||
)
|
||||
ImageDraw.Draw(debug_crop_image).rectangle(
|
||||
((0, 0), (crop["width"], crop["height"])), outline=(255, 0, 0)
|
||||
)
|
||||
|
||||
for y in range(analyse_image.size[1]): # height
|
||||
for x in range(analyse_image.size[0]): # width
|
||||
p = y * analyse_image.size[0] + x
|
||||
importance = self.importance(crop, x, y)
|
||||
if importance > 0:
|
||||
debug_pixels.putpixel(
|
||||
(x, y),
|
||||
(
|
||||
debug_pixels[p][0],
|
||||
int(debug_pixels[p][1] + importance * 32),
|
||||
debug_pixels[p][2],
|
||||
),
|
||||
)
|
||||
elif importance < 0:
|
||||
debug_pixels.putpixel(
|
||||
(x, y),
|
||||
(
|
||||
int(debug_pixels[p][0] + importance * -64),
|
||||
debug_pixels[p][1],
|
||||
debug_pixels[p][2],
|
||||
),
|
||||
)
|
||||
debug_image.paste(
|
||||
debug_crop_image, (crop["x"], crop["y"]), debug_crop_image.split()[3]
|
||||
)
|
||||
return debug_image
|
||||
|
||||
def detect_edge(self, cie_image):
|
||||
return cie_image.filter(Kernel((3, 3), (0, -1, 0, -1, 4, -1, 0, -1, 0), 1, 1))
|
||||
|
||||
def detect_saturation(self, cie_array, source_image):
|
||||
threshold = self.saturation_threshold
|
||||
saturation_data = saturation(source_image)
|
||||
mask = (
|
||||
(saturation_data > threshold)
|
||||
& (cie_array >= self.saturation_brightness_min * 255)
|
||||
& (cie_array <= self.saturation_brightness_max * 255)
|
||||
)
|
||||
|
||||
saturation_data[~mask] = 0
|
||||
saturation_data[mask] = (saturation_data[mask] - threshold) * (
|
||||
255 / (1 - threshold)
|
||||
)
|
||||
|
||||
return Image.fromarray(saturation_data.astype("uint8"))
|
||||
|
||||
def detect_skin(self, cie_array, source_image):
|
||||
r, g, b = source_image.split()
|
||||
r, g, b = np.array(r), np.array(g), np.array(b)
|
||||
r, g, b = r.astype(float), g.astype(float), b.astype(float)
|
||||
rd = np.ones_like(r) * -self.skin_color[0]
|
||||
gd = np.ones_like(g) * -self.skin_color[1]
|
||||
bd = np.ones_like(b) * -self.skin_color[2]
|
||||
|
||||
mag = np.sqrt(r * r + g * g + b * b)
|
||||
mask = ~(abs(mag) < 1e-6)
|
||||
rd[mask] = r[mask] / mag[mask] - self.skin_color[0]
|
||||
gd[mask] = g[mask] / mag[mask] - self.skin_color[1]
|
||||
bd[mask] = b[mask] / mag[mask] - self.skin_color[2]
|
||||
|
||||
skin = 1 - np.sqrt(rd * rd + gd * gd + bd * bd)
|
||||
mask = (
|
||||
(skin > self.skin_threshold)
|
||||
& (cie_array >= self.skin_brightness_min * 255)
|
||||
& (cie_array <= self.skin_brightness_max * 255)
|
||||
)
|
||||
|
||||
skin_data = (skin - self.skin_threshold) * (255 / (1 - self.skin_threshold))
|
||||
skin_data[~mask] = 0
|
||||
|
||||
return Image.fromarray(skin_data.astype("uint8"))
|
||||
|
||||
def importance(self, crop, x, y):
|
||||
if (
|
||||
crop["x"] > x
|
||||
or x >= crop["x"] + crop["width"]
|
||||
or crop["y"] > y
|
||||
or y >= crop["y"] + crop["height"]
|
||||
):
|
||||
return self.outside_importance
|
||||
|
||||
x = (x - crop["x"]) / crop["width"]
|
||||
y = (y - crop["y"]) / crop["height"]
|
||||
px, py = abs(0.5 - x) * 2, abs(0.5 - y) * 2
|
||||
|
||||
# distance from edge
|
||||
dx = max(px - 1 + self.edge_radius, 0)
|
||||
dy = max(py - 1 + self.edge_radius, 0)
|
||||
d = (dx * dx + dy * dy) * self.edge_weight
|
||||
s = 1.41 - math.sqrt(px * px + py * py)
|
||||
|
||||
if self.rule_of_thirds:
|
||||
s += (max(0, s + d + 0.5) * 1.2) * (thirds(px) + thirds(py))
|
||||
|
||||
return s + d
|
||||
|
||||
def score(self, target_image, crop):
|
||||
score = {
|
||||
"detail": 0,
|
||||
"saturation": 0,
|
||||
"skin": 0,
|
||||
"total": 0,
|
||||
}
|
||||
target_data = target_image.getdata()
|
||||
target_width, target_height = target_image.size
|
||||
|
||||
down_sample = self.score_down_sample
|
||||
inv_down_sample = 1 / down_sample
|
||||
target_width_down_sample = target_width * down_sample
|
||||
target_height_down_sample = target_height * down_sample
|
||||
|
||||
for y in range(0, target_height_down_sample, down_sample):
|
||||
for x in range(0, target_width_down_sample, down_sample):
|
||||
p = int(
|
||||
math.floor(y * inv_down_sample) * target_width
|
||||
+ math.floor(x * inv_down_sample)
|
||||
)
|
||||
importance = self.importance(crop, x, y)
|
||||
detail = target_data[p][1] / 255
|
||||
score["skin"] += (
|
||||
target_data[p][0] / 255 * (detail + self.skin_bias) * importance
|
||||
)
|
||||
score["detail"] += detail * importance
|
||||
score["saturation"] += (
|
||||
target_data[p][2]
|
||||
/ 255
|
||||
* (detail + self.saturation_bias)
|
||||
* importance
|
||||
)
|
||||
score["total"] = (
|
||||
score["detail"] * self.detail_weight
|
||||
+ score["skin"] * self.skin_weight
|
||||
+ score["saturation"] * self.saturation_weight
|
||||
) / (crop["width"] * crop["height"])
|
||||
return score
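# Example (a sketch): crop an image to its most interesting 512x512 region,
# mirroring how imaginairy/training_tools/image_prep.py uses this class.
# "photo.jpg" is a hypothetical path assumed to be at least 512px on each side.
from PIL import Image

from imaginairy.vendored.smart_crop import SmartCrop

img = Image.open("photo.jpg").convert("RGB")
result = SmartCrop().crop(img, width=512, height=512)
top = result["top_crop"]
box = (top["x"], top["y"], top["x"] + top["width"], top["y"] + top["height"])
img.crop(box).save("photo-cropped.jpg", quality=95)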
|
12
tests/enhancers/test_facecrop.py
Normal file
@ -0,0 +1,12 @@
import logging

from imaginairy import LazyLoadingImage
from imaginairy.enhancers.facecrop import generate_face_crops
from tests import TESTS_FOLDER

logger = logging.getLogger(__name__)


def test_facecrop():
    img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png")
    generate_face_crops((50, 50, 150, 150), max_width=img.width, max_height=img.height)
@ -60,7 +60,7 @@ def test_clip_masking(filename_base_for_outputs):

    result = next(imagine(prompt))
    img_path = f"{filename_base_for_outputs}.png"
    assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=1000)
    assert_image_similar_to_expectation(result.img, img_path=img_path, threshold=1100)


boolean_mask_test_cases = [
69
tests/test_roi_utils.py
Normal file
@ -0,0 +1,69 @@
import itertools
import random

from imaginairy.roi_utils import (
    RoiNotInBoundsError,
    resize_roi_coordinates,
    square_roi_coordinate,
)


def test_square_roi_coordinate():
    img_sizes = (10, 100, 200, 511, 513, 1024)
    # iterate through all permutations of image sizes using itertools.product
    for img_width, img_height in itertools.product(img_sizes, img_sizes):
        # randomly generate a region of interest
        for _ in range(100):
            x1 = random.randint(0, img_width)
            y1 = random.randint(0, img_height)
            x2 = random.randint(x1, img_width)
            y2 = random.randint(y1, img_height)
            roi = x1, y1, x2, y2
            try:
                x1, y1, x2, y2 = square_roi_coordinate(roi, img_width, img_height)
            except RoiNotInBoundsError:
                continue
            assert (
                x2 - x1 == y2 - y1
            ), f"ROI is not square: img_width: {img_width}, img_height: {img_height}, roi: {roi}"


# resize_roi_coordinates


def test_square_resize_roi_coordinates():
    img_sizes = (10, 100, 200, 403, 511, 513, 604, 1024)
    # iterate through all permutations of image sizes using itertools.product
    img_sizes = list(itertools.product(img_sizes, img_sizes))

    for img_width, img_height in img_sizes:
        # randomly generate a region of interest
        rois = []
        for _ in range(100):
            x1 = random.randint(0 + 1, img_width - 1)
            y1 = random.randint(0 + 1, img_height - 1)
            x2 = random.randint(x1 + 1, img_width)
            y2 = random.randint(y1 + 1, img_height)
            roi = x1, y1, x2, y2
            rois.append(roi)
        rois.append((392, 85, 695, 389))
        for roi in rois:
            try:
                squared_roi = square_roi_coordinate(roi, img_width, img_height)
            except RoiNotInBoundsError:
                continue
            for n in range(10):
                factor = 1.25 + 0.3 * n
                x1, y1, x2, y2 = resize_roi_coordinates(
                    squared_roi, factor, img_width, img_height
                )
                assert (
                    x2 - x1 == y2 - y1
                ), f"ROI is not square: img_width: {img_width}, img_height: {img_height}, roi: {roi}"

                x1, y1, x2, y2 = resize_roi_coordinates(
                    squared_roi, factor, img_width, img_height, expand_up=False
                )
                assert (
                    x2 - x1 == y2 - y1
                ), f"ROI is not square: img_width: {img_width}, img_height: {img_height}, roi: {roi}"