perf: improve startup time by doing some imports lazily (#233)

just running `aimg --help` or `aimg --version` was very slow due to all the imports being brought in eagerly

Before changes `aimg --help`
`2.24s user 4.05s system 184% cpu 3.416 total`

After changes:
`0.04s user 0.02s system 8% cpu 0.625 total`

Used `PYTHONPROFILEIMPORTTIME=1 aimg --help` to find time consuming imports.

Also switched to using `scripts` instead of `entrypoints` since the scripts are much faster.

Made duplicate SAMPLER_TYPE_OPTIONS that can be loaded without loading all the samplers themselves.

Likely a breaking change - not sure.
pull/237/head
Bryce Drennan 1 year ago committed by GitHub
parent b611a92b49
commit 9eacf5e7ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,7 +33,7 @@ jobs:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --disable-pip-version-check black isort
python -m pip install --disable-pip-version-check black==22.12.0 isort==5.11.4
- name: Autoformatter
run: |
black --diff .

@ -291,6 +291,8 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
- perf: cli now has minimal overhead such that `aimg --help` runs in ~650ms instead of ~3400ms
**8.3.1**
- fix: init-image-strength type

@ -1,19 +1,12 @@
import os.path
import os
# tells pytorch to allow MPS usage (for Mac M1 compatibility)
os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1")
import PIL.Image # noqa
from .api import imagine, imagine_image_files # noqa
from .enhancers.describe_image_blip import generate_caption # noqa
from .schema import ( # noqa
ImaginePrompt,
ImagineResult,
LazyLoadingImage,
WeightedPrompt,
)
from .version import __version__ # noqa
# https://stackoverflow.com/questions/71738218/module-pil-has-not-attribute-resampling
if not hasattr(PIL.Image, "Resampling"): # Pillow<9.0
PIL.Image.Resampling = PIL.Image

@ -2,41 +2,7 @@ import logging
import os
import re
import numpy as np
import torch
import torch.nn
from einops import rearrange, repeat
from PIL import Image, ImageDraw, ImageOps
from pytorch_lightning import seed_everything
from torch.cuda import OutOfMemoryError
from imaginairy.animations import make_bounce_animation
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
from imaginairy.log_utils import (
ImageLoggingContext,
log_conditioning,
log_img,
log_latent,
)
from imaginairy.model_manager import get_diffusion_model
from imaginairy.modules.midas.utils import AddMiDaS
from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
from imaginairy.safety import SafetyMode, create_safety_score
from imaginairy.samplers import SAMPLER_LOOKUP
from imaginairy.samplers.base import NoiseSchedule, noise_an_image
from imaginairy.samplers.editing import CFGEditingDenoiser
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import (
fix_torch_group_norm,
fix_torch_nn_layer_norm,
get_device,
platform_appropriate_autocast,
randn_seeded,
)
from imaginairy.schema import SafetyMode
logger = logging.getLogger(__name__)
@ -64,6 +30,11 @@ def imagine_image_files(
make_compare_gif=False,
return_filename_type="generated",
):
from PIL import ImageDraw
from imaginairy.animations import make_bounce_animation
from imaginairy.img_utils import pillow_fit_image_within
generated_imgs_path = os.path.join(outdir, "generated")
os.makedirs(generated_imgs_path, exist_ok=True)
@ -167,6 +138,16 @@ def imagine(
add_caption=False,
unsafe_retry_count=1,
):
import torch.nn
from imaginairy.schema import ImaginePrompt
from imaginairy.utils import (
fix_torch_group_norm,
fix_torch_nn_layer_norm,
get_device,
platform_appropriate_autocast,
)
prompts = [ImaginePrompt(prompts)] if isinstance(prompts, str) else prompts
prompts = [prompts] if isinstance(prompts, ImaginePrompt) else prompts
@ -214,6 +195,34 @@ def _generate_single_image(
half_mode=None,
add_caption=False,
):
import numpy as np
import torch.nn
from einops import rearrange, repeat
from PIL import Image, ImageOps
from pytorch_lightning import seed_everything
from torch.cuda import OutOfMemoryError
from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
from imaginairy.enhancers.face_restoration_codeformer import enhance_faces
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.img_utils import pillow_fit_image_within, pillow_img_to_torch_image
from imaginairy.log_utils import (
ImageLoggingContext,
log_conditioning,
log_img,
log_latent,
)
from imaginairy.model_manager import get_diffusion_model
from imaginairy.modules.midas.utils import AddMiDaS
from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint
from imaginairy.safety import create_safety_score
from imaginairy.samplers import SAMPLER_LOOKUP
from imaginairy.samplers.base import NoiseSchedule, noise_an_image
from imaginairy.samplers.editing import CFGEditingDenoiser
from imaginairy.schema import ImaginePrompt, ImagineResult
from imaginairy.utils import get_device, randn_seeded
latent_channels = 4
downsampling_factor = 8
batch_size = 1

@ -0,0 +1,6 @@
#!/bin/env python
# -*- coding: utf-8 -*-
if __name__ == "__main__":
from imaginairy import cmds
cmds.aimg()

@ -0,0 +1,6 @@
#!/bin/env python
# -*- coding: utf-8 -*-
if __name__ == "__main__":
from imaginairy import cmds
cmds.imagine_cmd()

@ -1,29 +1,10 @@
import logging
import math
import os.path
import click
from click_shell import shell
from tqdm import tqdm
from imaginairy import LazyLoadingImage, __version__, config, generate_caption
from imaginairy.animations import make_bounce_animation
from imaginairy.api import imagine_image_files
from imaginairy.debug_info import get_debug_info
from imaginairy.enhancers.prompt_expansion import expand_prompts
from imaginairy.enhancers.upscale_realesrgan import upscale_image
from imaginairy.log_utils import configure_logging
from imaginairy.prompt_schedules import parse_schedule_strs, prompt_mutator
from imaginairy.samplers import SAMPLER_TYPE_OPTIONS
from imaginairy.schema import ImaginePrompt
from imaginairy.surprise_me import create_surprise_me_images
from imaginairy.train import train_diffusion_model
from imaginairy.training_tools.image_prep import (
create_class_images,
get_image_filenames,
prep_images,
)
from imaginairy.training_tools.prune_model import prune_diffusion_ckpt
from imaginairy import config
logger = logging.getLogger(__name__)
@ -111,7 +92,7 @@ logger = logging.getLogger(__name__)
"--sampler",
default=config.DEFAULT_SAMPLER,
show_default=True,
type=click.Choice(SAMPLER_TYPE_OPTIONS),
type=click.Choice(config.SAMPLER_TYPE_OPTIONS),
help="What sampling strategy to use.",
)
@click.option(
@ -413,7 +394,7 @@ def imagine_cmd(
"--sampler",
default=config.DEFAULT_SAMPLER,
show_default=True,
type=click.Choice(SAMPLER_TYPE_OPTIONS),
type=click.Choice(config.SAMPLER_TYPE_OPTIONS),
help="What sampling strategy to use.",
)
@click.option(
@ -603,6 +584,9 @@ def edit_image( # noqa
arg_schedules,
make_compilation_animation,
):
from imaginairy.log_utils import configure_logging
from imaginairy.surprise_me import create_surprise_me_images
init_image_strength = 1
if surprise_me and prompt_texts:
raise ValueError("Cannot use surprise_me and prompt_texts together")
@ -698,11 +682,23 @@ def _imagine_cmd(
make_compilation_animation=False,
):
"""Have the AI generate images. alias:imagine."""
import os.path
from imaginairy import LazyLoadingImage
from imaginairy.animations import make_bounce_animation
from imaginairy.api import imagine_image_files
from imaginairy.enhancers.prompt_expansion import expand_prompts
from imaginairy.log_utils import configure_logging
from imaginairy.prompt_schedules import parse_schedule_strs, prompt_mutator
from imaginairy.schema import ImaginePrompt
if ctx.invoked_subcommand is not None:
return
if version:
print(__version__)
from imaginairy.version import get_version
print(get_version())
return
if quiet:
@ -816,13 +812,14 @@ def aimg():
Run `aimg` to start a persistent shell session. This makes generation and editing much
quicker since the model can stay loaded in memory.
"""
configure_logging()
@aimg.command()
def version():
"""Print the version."""
print(__version__)
from imaginairy.version import get_version
print(get_version())
@click.argument("image_filepaths", nargs=-1)
@ -838,6 +835,13 @@ def upscale_cmd(image_filepaths, outdir):
"""
Upscale an image 4x using AI.
"""
import os.path
from tqdm import tqdm
from imaginairy import LazyLoadingImage
from imaginairy.enhancers.upscale_realesrgan import upscale_image
os.makedirs(outdir, exist_ok=True)
for p in tqdm(image_filepaths):
@ -847,7 +851,7 @@ def upscale_cmd(image_filepaths, outdir):
else:
img = LazyLoadingImage(filepath=p)
logger.info(
f"Upscaling {p} from {img.width}x{img.height } to {img.width * 4}x{img.height*4} and saving it to {savepath}"
f"Upscaling {p} from {img.width}x{img.height} to {img.width * 4}x{img.height * 4} and saving it to {savepath}"
)
img = upscale_image(img)
@ -859,6 +863,10 @@ def upscale_cmd(image_filepaths, outdir):
@aimg.command()
def describe(image_filepaths):
"""Generate text descriptions of images."""
from imaginairy import LazyLoadingImage
from imaginairy.enhancers.describe_image_blip import generate_caption
imgs = []
for p in image_filepaths:
if p.startswith("http"):
@ -991,6 +999,15 @@ def train_concept(
You can find a lot of relevant instructions here: https://github.com/JoePenna/Dreambooth-Stable-Diffusion
"""
import os.path
from imaginairy.train import train_diffusion_model
from imaginairy.training_tools.image_prep import (
create_class_images,
get_image_filenames,
prep_images,
)
target_size = 512
# Step 1. Crop and enhance the training images
prepped_images_path = os.path.join(concept_images_dir, "prepped-images")
@ -1081,7 +1098,9 @@ def prepare_images(images_dir, is_person, target_size):
aimg prep-images --person ./images/selfies
aimg prep-images ./images/toy-train
"""
configure_logging()
from imaginairy.training_tools.image_prep import prep_images
prep_images(images_dir=images_dir, is_person=is_person, target_size=target_size)
@ -1097,8 +1116,9 @@ def prune_ckpt(ckpt_paths):
Example:
aimg prune-ckpt ./path/to/checkpoint.ckpt
"""
from imaginairy.training_tools.prune_model import prune_diffusion_ckpt
click.secho("Pruning checkpoint files...")
configure_logging()
for p in ckpt_paths:
prune_diffusion_ckpt(p)
@ -1108,6 +1128,8 @@ def system_info():
"""
Display system information. Submit this when reporting bugs.
"""
from imaginairy.debug_info import get_debug_info
for k, v in get_debug_info().items():
k += ":"
click.secho(f"{k: <30} {v}")

@ -142,3 +142,18 @@ MODEL_CONFIG_SHORTCUTS["openjourney"] = MODEL_CONFIG_SHORTCUTS["openjourney-v2"]
MODEL_CONFIG_SHORTCUTS["oj"] = MODEL_CONFIG_SHORTCUTS["openjourney-v2"]
MODEL_SHORT_NAMES = sorted(MODEL_CONFIG_SHORTCUTS.keys())
SAMPLER_TYPE_OPTIONS = [
"plms",
"ddim",
"k_dpm_fast",
"k_dpm_adaptive",
"k_lms",
"k_dpm_2",
"k_dpm_2_a",
"k_dpmpp_2m",
"k_dpmpp_2s_a",
"k_euler",
"k_euler_a",
"k_heun",
]

@ -3,13 +3,14 @@ import sys
import torch
from imaginairy import __version__
from imaginairy.utils import get_device, get_hardware_description
from imaginairy.version import get_version
def get_debug_info():
data = {
"imaginairy_version": __version__,
"imaginairy_version": get_version(),
"imaginairy_path": os.path.dirname(__file__),
"python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
"python_installation_path": sys.executable,

@ -4,12 +4,6 @@ import re
import time
import warnings
import torch
from pytorch_lightning import _logger as pytorch_logger
from torchvision.transforms import ToPILImage
from transformers.modeling_utils import logger as modeling_logger
from transformers.utils.logging import _configure_library_root_logger
_CURRENT_LOGGING_CONTEXT = None
logger = logging.getLogger(__name__)
@ -159,6 +153,9 @@ class ImageLoggingContext:
def log_img(self, img, description):
if not self.debug_img_callback:
return
import torch
from torchvision.transforms import ToPILImage
self.image_count += 1
if isinstance(img, torch.Tensor):
img = ToPILImage()(img.squeeze().cpu().detach())
@ -200,6 +197,8 @@ def filesafe_text(t):
def conditioning_to_img(conditioning):
from torchvision.transforms import ToPILImage
return ToPILImage()(conditioning)
@ -252,6 +251,9 @@ def configure_logging(level="INFO"):
def disable_transformers_custom_logging():
from transformers.modeling_utils import logger as modeling_logger
from transformers.utils.logging import _configure_library_root_logger
_configure_library_root_logger()
_logger = modeling_logger.parent
_logger.handlers = []
@ -263,6 +265,8 @@ def disable_transformers_custom_logging():
def disable_pytorch_lighting_custom_logging():
from pytorch_lightning import _logger as pytorch_logger
try:
from pytorch_lightning.utilities.seed import log # noqa

@ -6,15 +6,11 @@ from diffusers.pipelines.stable_diffusion import safety_checker as safety_checke
from transformers import AutoFeatureExtractor
from imaginairy.enhancers.blur_detect import is_blurry
from imaginairy.schema import SafetyMode
logger = logging.getLogger(__name__)
class SafetyMode:
STRICT = "strict"
RELAXED = "relaxed"
class SafetyResult:
# increase this value to create a stronger `nfsw` filter
# at the cost of increasing the possibility of filtering benign images

@ -4,17 +4,8 @@ import logging
import os.path
import random
from datetime import datetime, timezone
from functools import lru_cache
import requests
from PIL import Image, ImageOps
from urllib3.exceptions import LocationParseError
from urllib3.util import parse_url
from imaginairy import config
from imaginairy.model_manager import get_model_default_image_size
from imaginairy.samplers import SAMPLER_LOOKUP, SamplerName
from imaginairy.utils import get_device, get_hardware_description
logger = logging.getLogger(__name__)
@ -36,6 +27,9 @@ class LazyLoadingImage:
# validate url is valid url
if url:
from urllib3.exceptions import LocationParseError
from urllib3.util import parse_url
try:
parsed_url = parse_url(url)
except LocationParseError:
@ -53,6 +47,7 @@ class LazyLoadingImage:
raise AttributeError()
if self._img:
return getattr(self._img, key)
from PIL import Image, ImageOps
if self._lazy_filepath:
self._img = Image.open(self._lazy_filepath)
@ -60,6 +55,8 @@ class LazyLoadingImage:
f"Loaded input 🖼 of size {self._img.size} from {self._lazy_filepath}"
)
elif self._lazy_url:
import requests
self._img = Image.open(
requests.get(self._lazy_url, stream=True, timeout=60).raw
)
@ -148,6 +145,8 @@ class ImaginePrompt:
self.validate()
def validate(self):
from imaginairy.samplers import SAMPLER_LOOKUP, SamplerName
self.prompts = self.process_prompt_input(self.prompts)
if self.tile_mode is True:
@ -182,6 +181,8 @@ class ImaginePrompt:
)
if self.height is None or self.width is None or self.steps is None:
from imaginairy.model_manager import get_model_default_image_size
SamplerCls = SAMPLER_LOOKUP[self.sampler_type]
self.steps = self.steps or SamplerCls.default_steps
self.width = self.width or get_model_default_image_size(self.model)
@ -281,6 +282,8 @@ class ImagineResult:
timings=None,
progress_latents=None,
):
from imaginairy.utils import get_device, get_hardware_description
self.prompt = prompt
self.images = {"generated": img}
@ -327,6 +330,8 @@ class ImagineResult:
return " ".join(f"{k}:{v:.2f}s" for k, v in self.timings.items())
def _exif(self):
from PIL import Image
exif = Image.Exif()
exif[ExifCodes.ImageDescription] = self.prompt.prompt_description()
exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict())
@ -349,6 +354,6 @@ class ImagineResult:
img.convert("RGB").save(save_path, exif=self._exif())
@lru_cache(maxsize=2)
def _get_briefly_cached_url(url):
return requests.get(url, timeout=60)
class SafetyMode:
STRICT = "strict"
RELAXED = "relaxed"

@ -1,6 +1,7 @@
from importlib.metadata import PackageNotFoundError, version
def get_version():
from importlib.metadata import PackageNotFoundError, version
try:
__version__ = version("imaginairy")
except PackageNotFoundError:
__version__ = None
try:
return version("imaginairy")
except PackageNotFoundError:
return None

@ -16,16 +16,12 @@ setup(
"Source": "https://github.com/brycedrennan/imaginAIry",
},
packages=find_packages(include=("imaginairy", "imaginairy.*")),
entry_points={
"console_scripts": [
"imagine=imaginairy.cmds:imagine_cmd",
"aimg=imaginairy.cmds:aimg",
],
},
scripts=["imaginairy/bin/aimg", "imaginairy/bin/imagine"],
package_data={
"imaginairy": [
"configs/*.yaml",
"data/*.*",
"bin/*.*",
"enhancers/phraselists/*.txt",
"vendored/clip/*.txt.gz",
"vendored/clipseg/*.pth",

@ -0,0 +1,6 @@
from imaginairy import config
from imaginairy.samplers import SAMPLER_TYPE_OPTIONS
def test_sampler_options():
assert set(config.SAMPLER_TYPE_OPTIONS) == set(SAMPLER_TYPE_OPTIONS)

@ -11,7 +11,7 @@ format = pylint
skip = */.tox/*,*/.env/*,build/*,*/downloads/*,other/*,prolly_delete/*,downloads/*,imaginairy/vendored/*,testing_support/vastai_cli_official.py,.eggs/*
linters = pylint,pycodestyle,pyflakes,mypy
ignore =
Z999,C0103,C0301,C0302,C0114,C0115,C0116,
Z999,C0103,C0301,C0302,C0114,C0115,C0116,C0415,
Z999,D100,D101,D102,D103,D105,D106,D107,D200,D202,D203,D205,D212,D400,D401,D406,D407,D413,D415,D417,
Z999,E203,E501,E1101,E1131,E1135,E1136,
Z999,R0901,R0902,R0903,R0904,R0193,R0912,R0913,R0914,R0915,R1702,

Loading…
Cancel
Save