feature: improvements to memory management

not thoroughly tested on low-memory devices
pull/333/head
Bryce authored 1 year ago · committed by Bryce Drennan
parent 6db296aa37
commit 4c77fd376b

@ -468,6 +468,11 @@ A: The AI models are cached in `~/.cache/` (or `HUGGINGFACE_HUB_CACHE`). To dele
## ChangeLog
- feature: multi-controlnet support
- feature: "better" memory management. If GPU is full, least-recently-used model is moved to RAM.
- fix: hide the "triton" error messages
- feature: show full stack trace on error in cli
**12.0.3**
- fix: exclude broken versions of timm as dependencies
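For context, a minimal sketch (not part of this commit) of the LRU behavior described above, using the `GPUModelCache` added later in this diff. The tiny `nn.Linear` models and the 0.01 GB budget are made-up stand-ins; on a CUDA device the evicted model would be moved to the CPU cache (RAM), while with `device="cpu"` it is simply dropped:

from torch import nn
from imaginairy.utils.model_cache import GPUModelCache

cache = GPUModelCache(max_cpu_memory_gb=0.01, max_gpu_memory_gb=0.01, device="cpu")
cache.set("model-a", nn.Linear(1000, 1000))  # ~4 MB of float32 parameters
cache.set("model-b", nn.Linear(1000, 1000))  # both fit under the 0.01 GB budget
cache.set("model-c", nn.Linear(1000, 1000))  # evicts "model-a", the least recently used
print(cache.stats())  # per-tier item counts and tracked memory usage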

@ -165,7 +165,7 @@ def imagine(
), fix_torch_nn_layer_norm(), fix_torch_group_norm():
for i, prompt in enumerate(prompts):
logger.info(
f"Generating 🖼 {i + 1}/{num_prompts}: {prompt.prompt_description()}"
f"🖼 Generating {i + 1}/{num_prompts}: {prompt.prompt_description()}"
)
for attempt in range(0, unsafe_retry_count + 1):
if attempt > 0:
@ -246,7 +246,9 @@ def _generate_single_image(
model = get_diffusion_model(
weights_location=prompt.model,
config_path=prompt.model_config_path,
control_weights_location=prompt.control_mode,
control_weights_locations=[prompt.control_mode]
if prompt.control_mode
else None,
half_mode=half_mode,
for_inpainting=(prompt.mask_image or prompt.mask_prompt or prompt.outpaint)
and not suppress_inpaint,
@ -462,7 +464,7 @@ def _generate_single_image(
c_cat_neutral = [torch.zeros_like(init_latent)]
denoiser_cls = CFGEditingDenoiser
if c_cat:
c_cat = [torch.cat(c_cat, dim=1)]
c_cat = [torch.cat([c], dim=1) for c in c_cat]
if c_cat_neutral is None:
c_cat_neutral = c_cat
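A hedged illustration (shapes are made up) of the change above: control hints are now kept as one tensor per ControlNet rather than concatenated into a single tensor.

import torch

hints = [torch.zeros(1, 3, 64, 64), torch.ones(1, 3, 64, 64)]  # two hypothetical control hints
old_c_cat = [torch.cat(hints, dim=1)]               # one (1, 6, 64, 64) tensor, single-controlnet style
new_c_cat = [torch.cat([h], dim=1) for h in hints]  # two (1, 3, 64, 64) tensors, one per ControlNet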

@ -0,0 +1,91 @@
"""Most of these modifications are just so we get full stack traces in the shell."""
import logging
import shlex
import traceback
from functools import update_wrapper
import click
from click_help_colors import HelpColorsCommand, HelpColorsMixin
from click_shell import Shell
from click_shell._compat import get_method_type
from click_shell.core import ClickShell, get_complete, get_help
logger = logging.getLogger(__name__)
def mod_get_invoke(command):
"""
Get the Cmd main method from the click command
:param command: The click Command object
:return: the do_* method for Cmd
:rtype: function.
"""
assert isinstance(command, click.Command)
def invoke_(self, arg): # pylint: disable=unused-argument
try:
command.main(
args=shlex.split(arg),
prog_name=command.name,
standalone_mode=False,
parent=self.ctx,
)
except click.ClickException as e:
# Show the error message
e.show()
except click.Abort:
# We got an EOF or Keyboard interrupt. Just silence it
pass
except SystemExit:
# Catch this and return the code instead. All of click's help commands do a sys.exit(),
# and that's not ideal when running in a shell.
pass
except Exception as e: # noqa
traceback.print_exception(e)
# logger.warning(traceback.format_exc())
# Always return False so the shell doesn't exit
return False
invoke_ = update_wrapper(invoke_, command.callback)
invoke_.__name__ = "do_%s" % command.name # noqa
return invoke_
class ModClickShell(ClickShell):
def add_command(self, cmd, name):
# Use the MethodType to add these as bound methods to our current instance
setattr(
self, "do_%s" % name, get_method_type(mod_get_invoke(cmd), self) # noqa
)
setattr(self, "help_%s" % name, get_method_type(get_help(cmd), self)) # noqa
setattr(
self, "complete_%s" % name, get_method_type(get_complete(cmd), self) # noqa
)
class ModShell(Shell):
def __init__(
self, prompt=None, intro=None, hist_file=None, on_finished=None, **attrs
):
attrs["invoke_without_command"] = True
super(Shell, self).__init__(**attrs)
# Make our shell
self.shell = ModClickShell(hist_file=hist_file, on_finished=on_finished)
if prompt:
self.shell.prompt = prompt
self.shell.intro = intro
class ColorShell(HelpColorsMixin, ModShell):
pass
class ImagineColorsCommand(HelpColorsCommand):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.help_headers_color = "yellow"
self.help_options_color = "green"

@ -1,11 +1,7 @@
import click
from imaginairy.cli.shared import (
ImagineColorsCommand,
_imagine_cmd,
add_options,
common_options,
)
from imaginairy.cli.clickshell_mod import ImagineColorsCommand
from imaginairy.cli.shared import _imagine_cmd, add_options, common_options
@click.command(

@ -2,12 +2,12 @@ import logging
import click
from imaginairy.cli.clickshell_mod import ColorShell, ImagineColorsCommand
from imaginairy.cli.colorize import colorize_cmd
from imaginairy.cli.describe import describe_cmd
from imaginairy.cli.edit import edit_cmd
from imaginairy.cli.edit_demo import edit_demo_cmd
from imaginairy.cli.imagine import imagine_cmd
from imaginairy.cli.shared import ColorShell, ImagineColorsCommand
from imaginairy.cli.train import prep_images_cmd, prune_ckpt_cmd, train_concept_cmd
from imaginairy.cli.upscale import upscale_cmd

@ -2,8 +2,6 @@ import logging
import math
import click
from click_help_colors import HelpColorsCommand, HelpColorsMixin
from click_shell import Shell
from imaginairy import config
@ -497,14 +495,3 @@ common_options = [
type=str,
),
]
class ColorShell(HelpColorsMixin, Shell):
pass
class ImagineColorsCommand(HelpColorsCommand):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.help_headers_color = "yellow"
self.help_options_color = "green"

@ -1,16 +1,15 @@
from functools import lru_cache
import numpy as np
import torch
from PIL import Image
from imaginairy.model_manager import get_cached_url_path
from imaginairy.utils import get_device
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.vendored.basicsr.rrdbnet_arch import RRDBNet
from imaginairy.vendored.realesrgan import RealESRGANer
@lru_cache()
@memory_managed_model("realesrgan_upsampler")
def realesrgan_upsampler():
model = RRDBNet(
num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4

@ -247,6 +247,9 @@ def configure_logging(level="INFO"):
"level": "ERROR",
"propagate": False,
},
# disable the noisy "triton is not available" messages from xformers
# https://github.com/facebookresearch/xformers/blob/6425fd0cacb1a6579aa2f0c4a570b737cb10e9c3/xformers/__init__.py#L52
"xformers": {"handlers": ["default"], "level": "ERROR", "propagate": False},
},
}
suppress_annoying_logs_and_warnings()

@ -1,4 +1,3 @@
import gc
import logging
import os
import re
@ -19,11 +18,10 @@ from imaginairy.config import MODEL_SHORT_NAMES
from imaginairy.modules import attention
from imaginairy.paths import PKG_ROOT
from imaginairy.utils import get_device, instantiate_from_config
from imaginairy.utils.model_cache import memory_managed_model
logger = logging.getLogger(__name__)
LOADED_MODELS = {}
MOST_RECENTLY_LOADED_MODEL = None
@ -31,82 +29,6 @@ class HuggingFaceAuthorizationError(RuntimeError):
pass
class MemoryAwareModel:
"""Wraps a model to allow dynamic loading/unloading as needed."""
def __init__(
self,
config_path,
weights_path,
control_weights_path=None,
half_mode=None,
for_training=False,
):
self._config_path = config_path
self._weights_path = weights_path
self._control_weights_path = control_weights_path
self._half_mode = half_mode
self._model = None
self._for_training = for_training
LOADED_MODELS[
(self._config_path, self._weights_path, self._control_weights_path)
] = self
def __getattr__(self, key):
if key == "_model":
# http://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html
raise AttributeError()
if self._model is None:
# unload all models in LOADED_MODELS
for model in LOADED_MODELS.values():
model.unload_model()
model_config = OmegaConf.load(f"{PKG_ROOT}/{self._config_path}")
if self._for_training:
model_config.use_ema = True
# model_config.use_scheduler = True
# only run half-mode on cuda. run it by default
half_mode = self._half_mode is None and get_device() == "cuda"
model = load_model_from_config(
config=model_config,
weights_location=self._weights_path,
control_weights_location=self._control_weights_path,
half_mode=half_mode,
)
ks = 128
stride = 64
vqf = 8
model.split_input_params = {
"ks": (ks, ks),
"stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01,
}
self._model = model
return getattr(self._model, key)
def unload_model(self):
if self._model is not None:
del self._model.cond_stage_model
del self._model.first_stage_model
del self._model.model
del self._model
self._model = None
if get_device() == "cuda":
torch.cuda.empty_cache()
gc.collect()
def load_tensors(tensorfile, map_location=None):
if tensorfile == "empty":
# used for testing
@ -118,13 +40,20 @@ def load_tensors(tensorfile, map_location=None):
raise ValueError(f"Unknown tensorfile type: {tensorfile}")
def load_state_dict(weights_location):
def load_state_dict(weights_location, half_mode=False, device=None):
if device is None:
device = get_device()
if weights_location.startswith("http"):
ckpt_path = get_cached_url_path(weights_location, category="weights")
else:
ckpt_path = weights_location
logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
state_dict = None
# weights_cache_key = (ckpt_path, half_mode)
# if weights_cache_key in GLOBAL_WEIGHTS_CACHE:
# return GLOBAL_WEIGHTS_CACHE.get(weights_cache_key)
try:
state_dict = load_tensors(ckpt_path, map_location="cpu")
except FileNotFoundError as e:
@ -145,24 +74,60 @@ def load_state_dict(weights_location):
raise e
state_dict = state_dict.get("state_dict", state_dict)
if half_mode:
state_dict = {k: v.half() for k, v in state_dict.items()}
# change device
state_dict = {k: v.to(device) for k, v in state_dict.items()}
# GLOBAL_WEIGHTS_CACHE.set(weights_cache_key, state_dict)
return state_dict
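A tiny standalone illustration (the tensor is a made-up stand-in for real weights) of what the two new half_mode/device steps do to a loaded state dict:

import torch

state_dict = {"weight": torch.randn(1000, 1000)}              # ~4 MB as float32
state_dict = {k: v.half() for k, v in state_dict.items()}     # ~2 MB as float16
state_dict = {k: v.to("cpu") for k, v in state_dict.items()}  # moved to the requested device
assert state_dict["weight"].dtype == torch.float16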
def load_model_from_config(
config, weights_location, control_weights_location=None, half_mode=False
):
state_dict = load_state_dict(weights_location)
if control_weights_location:
controlnet_state_dict = load_state_dict(control_weights_location)
state_dict = add_controlnet(state_dict, controlnet_state_dict)
def load_model_from_config(config, weights_location, half_mode=False):
model = instantiate_from_config(config.model)
base_model_dict = load_state_dict(weights_location, half_mode=half_mode)
model.init_from_state_dict(base_model_dict)
if half_mode:
model = model.half()
model.to(get_device())
model.eval()
return model
def load_model_from_config_old(
config, weights_location, control_weights_locations=None, half_mode=False
):
model = instantiate_from_config(config.model)
model.init_from_state_dict(state_dict)
print("instantiated")
base_model_dict = load_state_dict(weights_location, half_mode=half_mode)
model.init_from_state_dict(base_model_dict)
control_weights_locations = control_weights_locations or []
controlnets = []
for control_weights_location in control_weights_locations:
controlnet_state_dict = load_state_dict(
control_weights_location, half_mode=half_mode
)
controlnet_state_dict = {
k.replace("control_model.", ""): v for k, v in controlnet_state_dict.items()
}
controlnet = instantiate_from_config(model.control_stage_config)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.to(get_device())
controlnets.append(controlnet)
model.set_control_models(controlnets)
if half_mode:
model = model.half()
print("halved")
model.to(get_device())
print("moved to device")
model.eval()
print("set to eval mode")
return model
@ -176,7 +141,7 @@ def add_controlnet(base_state_dict, controlnet_state_dict):
def get_diffusion_model(
weights_location=iconfig.DEFAULT_MODEL,
config_path="configs/stable-diffusion-v1.yaml",
control_weights_location=None,
control_weights_locations=None,
half_mode=None,
for_inpainting=False,
for_training=False,
@ -192,7 +157,7 @@ def get_diffusion_model(
config_path,
half_mode,
for_inpainting,
control_weights_location=control_weights_location,
control_weights_locations=control_weights_locations,
for_training=for_training,
)
except HuggingFaceAuthorizationError as e:
@ -206,7 +171,7 @@ def get_diffusion_model(
half_mode,
for_inpainting=False,
for_training=for_training,
control_weights_location=control_weights_location,
control_weights_locations=control_weights_locations,
)
raise e
@ -217,23 +182,22 @@ def _get_diffusion_model(
half_mode=None,
for_inpainting=False,
for_training=False,
control_weights_location=None,
control_weights_locations=None,
):
"""
Load a diffusion model.
Weights location may also be a shortcut name, e.g. "SD-1.5"
"""
global MOST_RECENTLY_LOADED_MODEL # noqa
(
model_config,
weights_location,
config_path,
control_weights_location,
control_weights_locations,
) = resolve_model_paths(
weights_path=weights_location,
config_path=config_path,
control_weights_path=control_weights_location,
control_weights_paths=control_weights_locations,
for_inpainting=for_inpainting,
for_training=for_training,
)
@ -242,39 +206,74 @@ def _get_diffusion_model(
attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision
else:
attention.ATTENTION_PRECISION_OVERRIDE = "default"
diffusion_model = _load_diffusion_model(
config_path=config_path,
weights_location=weights_location,
half_mode=half_mode,
for_training=for_training,
)
if control_weights_locations:
controlnets = []
for control_weights_location in control_weights_locations:
controlnets.append(load_controlnet(control_weights_location, half_mode))
diffusion_model.set_control_models(controlnets)
key = (config_path, weights_location, control_weights_location)
if key not in LOADED_MODELS:
MemoryAwareModel(
config_path=config_path,
weights_path=weights_location,
control_weights_path=control_weights_location,
half_mode=half_mode,
for_training=for_training,
)
return diffusion_model
model = LOADED_MODELS[key]
# calling model attribute forces it to load
model.num_timesteps_cond # noqa
MOST_RECENTLY_LOADED_MODEL = model
@memory_managed_model("stable-diffusion", memory_usage_mb=1951)
def _load_diffusion_model(config_path, weights_location, half_mode, for_training):
model_config = OmegaConf.load(f"{PKG_ROOT}/{config_path}")
if for_training:
model_config.use_ema = True
# model_config.use_scheduler = True
# Only run half-mode on CUDA; it is enabled there by default.
half_mode = half_mode is None and get_device() == "cuda"
model = load_model_from_config(
config=model_config,
weights_location=weights_location,
half_mode=half_mode,
)
return model
@memory_managed_model("controlnet")
def load_controlnet(control_weights_location, half_mode):
controlnet_state_dict = load_state_dict(
control_weights_location, half_mode=half_mode
)
controlnet_state_dict = {
k.replace("control_model.", ""): v for k, v in controlnet_state_dict.items()
}
control_stage_config = OmegaConf.load(f"{PKG_ROOT}/configs/control-net-v15.yaml")[
"model"
]["params"]["control_stage_config"]
controlnet = instantiate_from_config(control_stage_config)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.to(get_device())
return controlnet
def resolve_model_paths(
weights_path=iconfig.DEFAULT_MODEL,
config_path=None,
control_weights_path=None,
control_weights_paths=None,
for_inpainting=False,
for_training=False,
):
"""Resolve weight and config path if they happen to be shortcuts."""
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_path, None)
model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(config_path, None)
control_net_metadata = iconfig.CONTROLNET_CONFIG_SHORTCUTS.get(
control_weights_path, None
)
if not control_net_metadata and for_inpainting:
control_weights_paths = control_weights_paths or []
control_net_metadatas = [
iconfig.CONTROLNET_CONFIG_SHORTCUTS.get(control_weights_path, None)
for control_weights_path in control_weights_paths
]
if not control_net_metadatas and for_inpainting:
model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(
f"{weights_path}-inpaint", model_metadata_w
)
@ -299,17 +298,17 @@ def resolve_model_paths(
if config_path is None:
config_path = iconfig.MODEL_CONFIG_SHORTCUTS[iconfig.DEFAULT_MODEL].config_path
if control_net_metadata:
if control_net_metadatas:
if "stable-diffusion-v1" not in config_path:
raise ValueError(
"Control net is only supported for stable diffusion v1. Please use a different model."
)
control_weights_path = control_net_metadata.weights_url
config_path = control_net_metadata.config_path
control_weights_paths = [cnm.weights_url for cnm in control_net_metadatas]
config_path = control_net_metadatas[0].config_path
model_metadata = model_metadata_w or model_metadata_c
logger.debug(f"Loading model weights from: {weights_path}")
logger.debug(f"Loading model config from: {config_path}")
return model_metadata, weights_path, config_path, control_weights_path
return model_metadata, weights_path, config_path, control_weights_paths
def get_model_default_image_size(weights_location):

@ -1,11 +1,8 @@
import einops
import torch
from einops import rearrange, repeat
from torch import nn
from torchvision.utils import make_grid
from imaginairy.modules.attention import SpatialTransformer
from imaginairy.modules.diffusion.ddpm import LatentDiffusion, log_txt_as_img
from imaginairy.modules.diffusion.ddpm import LatentDiffusion
from imaginairy.modules.diffusion.openaimodel import (
AttentionBlock,
Downsample,
@ -19,8 +16,6 @@ from imaginairy.modules.diffusion.util import (
timestep_embedding,
zero_module,
)
from imaginairy.samplers import DDIMSampler
from imaginairy.utils import instantiate_from_config
class ControlledUnetModel(UNetModel):
@ -382,164 +377,56 @@ class ControlLDM(LatentDiffusion):
**kwargs,
):
super().__init__(*args, **kwargs)
self.control_model = instantiate_from_config(control_stage_config)
self.control_stage_config = control_stage_config
# self.control_model = instantiate_from_config(control_stage_config)
self.control_models = []
self.control_key = control_key
self.only_mid_control = only_mid_control
self.control_scales = [1.0] * 13
self.global_average_pooling = global_average_pooling
@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs): # noqa
x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
control = batch[self.control_key]
if bs is not None:
control = control[:bs]
control = control.to(self.device)
control = einops.rearrange(control, "b h w c -> b c h w")
control = control.to(memory_format=torch.contiguous_format).float()
return x, {"c_crossattn": [c], "c_concat": [control]}
def set_control_models(self, control_models):
self.control_models = control_models
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
assert isinstance(cond, dict)
diffusion_model = self.model.diffusion_model
merged_control = None
cond_txt = torch.cat(cond["c_crossattn"], 1)
cond_hint = torch.cat(cond["c_concat"], 1)
control = None
if cond["c_concat"] is not None:
control = self.control_model(
for control_model, c_concat in zip(self.control_models, cond["c_concat"]):
cond_hint = torch.cat([c_concat], 1)
control = control_model(
x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt
)
control = [c * scale for c, scale in zip(control, self.control_scales)]
if self.global_average_pooling:
control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]
if merged_control is None:
merged_control = control
else:
merged_control = [mc + c for mc, c in zip(merged_control, control)]
eps = diffusion_model(
x=x_noisy,
timesteps=t,
context=cond_txt,
control=control,
control=merged_control,
only_mid_control=self.only_mid_control,
)
return eps
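The loop above sums the scaled residuals from each ControlNet layer by layer; a standalone sketch of that merging arithmetic (list length and tensor shapes are made up):

import torch

# two hypothetical ControlNet outputs, each a list of per-layer residual tensors
controls_a = [torch.ones(1, 4, 8, 8) for _ in range(13)]
controls_b = [torch.full((1, 4, 8, 8), 2.0) for _ in range(13)]

merged = None
for control in (controls_a, controls_b):
    merged = control if merged is None else [m + c for m, c in zip(merged, control)]

assert torch.equal(merged[0], torch.full((1, 4, 8, 8), 3.0))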
@torch.no_grad()
def get_unconditional_conditioning(self, N):
return self.get_learned_conditioning([""] * N)
@torch.no_grad()
def log_images(
self,
batch,
N=4,
n_row=2,
sample=False,
ddim_steps=50,
ddim_eta=0.0,
return_keys=None,
quantize_denoised=True,
inpaint=True,
plot_denoise_rows=False,
plot_progressive_rows=True,
plot_diffusion_rows=False,
unconditional_guidance_scale=9.0,
unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs,
):
use_ddim = ddim_steps is not None
log = {}
z, c = self.get_input(batch, self.first_stage_key, bs=N)
c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
N = min(z.shape[0], N)
n_row = min(z.shape[0], n_row)
log["reconstruction"] = self.decode_first_stage(z)
log["control"] = c_cat * 2.0 - 1.0
log["conditioning"] = log_txt_as_img(
(512, 512), batch[self.cond_stage_key], size=16
)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = []
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), "1 -> b", b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, "n b c h w -> b n c h w")
diffusion_grid = rearrange(diffusion_grid, "b n c h w -> (b n) c h w")
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
samples, z_denoise_row = self.sample_log(
cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_cross = self.get_unconditional_conditioning(N)
uc_cat = c_cat # torch.zeros_like(c_cat)
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
samples_cfg, _ = self.sample_log(
cond={"c_concat": [c_cat], "c_crossattn": [c]},
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_full,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
return log
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
ddim_sampler = DDIMSampler(self)
b, c, h, w = cond["c_concat"][0].shape
shape = (self.channels, h // 8, w // 8)
samples, intermediates = ddim_sampler.sample(
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs
)
return samples, intermediates
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.control_model.parameters())
if not self.sd_locked:
params += list(self.model.diffusion_model.output_blocks.parameters())
params += list(self.model.diffusion_model.out.parameters())
opt = torch.optim.AdamW(params, lr=lr)
return opt
def low_vram_shift(self, is_diffusing):
if is_diffusing:
self.model = self.model.cuda()
self.control_model = self.control_model.cuda()
self.control_models = [cm.cuda() for cm in self.control_models]
self.first_stage_model = self.first_stage_model.cpu() # noqa
self.cond_stage_model = self.cond_stage_model.cpu()
else:
self.model = self.model.cpu()
self.control_model = self.control_model.cpu()
self.control_models = [cm.cpu() for cm in self.control_models]
self.first_stage_model = self.first_stage_model.cuda() # noqa
self.cond_stage_model = self.cond_stage_model.cuda()

@ -1,6 +1,7 @@
import importlib
import logging
import platform
import time
from contextlib import contextmanager, nullcontext
from functools import lru_cache
from typing import Any, List, Optional, Union
@ -59,7 +60,11 @@ def instantiate_from_config(config: Union[dict, str]) -> Any:
raise KeyError("Expected key `target` to instantiate.")
params = config.get("params", {})
_cls = get_obj_from_str(config["target"])
return _cls(**params)
start = time.perf_counter()
c = _cls(**params)
end = time.perf_counter()
logger.debug(f"Instantiation of {_cls} took {end-start} seconds")
return c
@contextmanager

@ -0,0 +1,363 @@
# pylama: ignore=W0212
import logging
from collections import OrderedDict
from functools import cached_property
from imaginairy.utils import get_device
logger = logging.getLogger(__name__)
def get_model_size(model):
from torch import nn
if not isinstance(model, nn.Module) and hasattr(model, "model"):
model = model.model
return sum(v.numel() * v.element_size() for v in model.parameters())
def move_model_device(model, device):
from torch import nn
if not isinstance(model, nn.Module) and hasattr(model, "model"):
model = model.model
return model.to(device)
class MemoryTrackingCache:
def __init__(self, *args, **kwargs):
self.memory_usage = 0
self._item_memory_usage = {}
self._cache = OrderedDict()
super().__init__(*args, **kwargs)
def first_key(self):
if self._cache:
return next(iter(self._cache))
raise KeyError("Empty dictionary")
def last_key(self):
if self._cache:
return next(reversed(self._cache))
raise KeyError("Empty dictionary")
def set(self, key, value, memory_usage=0):
if key in self._cache:
# Subtract old item memory usage if key already exists
self.memory_usage -= self._item_memory_usage[key]
self._cache[key] = value
# Calculate and store new item memory usage
item_memory_usage = max(get_model_size(value), memory_usage)
self._item_memory_usage[key] = item_memory_usage
self.memory_usage += item_memory_usage
def pop(self, key):
# Subtract item memory usage before deletion
self.memory_usage -= self._item_memory_usage[key]
del self._item_memory_usage[key]
return self._cache.pop(key)
def move_to_end(self, key, last=True):
self._cache.move_to_end(key, last=last)
def __contains__(self, item):
return item in self._cache
def __delitem__(self, key):
self.pop(key)
def __getitem__(self, item):
return self._cache[item]
def get(self, item):
return self._cache.get(item)
def __len__(self):
return len(self._cache)
def __bool__(self):
return bool(self._cache)
def keys(self):
return self._cache.keys()
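A brief sketch (the models are made-up stand-ins) of the bookkeeping MemoryTrackingCache does: each item's size is the larger of its measured parameter size and a caller-supplied estimate, and insertion order is preserved for LRU decisions.

from torch import nn
from imaginairy.utils.model_cache import MemoryTrackingCache

cache = MemoryTrackingCache()
cache.set("a", nn.Linear(10, 10), memory_usage=0)          # size inferred: 110 float32 params = 440 bytes
cache.set("b", nn.Linear(10, 10), memory_usage=1_000_000)  # caller estimate wins when it is larger
print(cache.first_key(), cache.last_key())   # "a" "b" (insertion order)
print(cache.memory_usage)                    # 440 + 1_000_000 bytes tracked in total
cache.pop("a")                               # popping an item also subtracts its tracked size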
def get_mem_free_total(device):
import psutil
import torch
if device.type == "cuda":
if not torch.cuda.is_initialized():
torch.cuda.init()
stats = torch.cuda.memory_stats(device)
mem_active = stats["active_bytes.all.current"]
mem_reserved = stats["reserved_bytes.all.current"]
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
mem_free_total *= 0.9
else:
# if we don't add a buffer, larger images come out as noise
mem_free_total = psutil.virtual_memory().available * 0.6
return mem_free_total
class GPUModelCache:
def __init__(self, max_cpu_memory_gb="80%", max_gpu_memory_gb="95%", device=None):
self._device = device
if device in ("cpu", "mps"):
# the "gpu" cache will be the only thing we use since there aren't two different memory stores in this case
max_cpu_memory_gb = 0
self._max_cpu_memory_gb = max_cpu_memory_gb
self._max_gpu_memory_gb = max_gpu_memory_gb
self.gpu_cache = MemoryTrackingCache()
self.cpu_cache = MemoryTrackingCache()
@cached_property
def device(self):
import torch
if self._device is None:
self._device = get_device()
if self._device in ("cpu", "mps"):
# the "gpu" cache will be the only thing we use since there aren't two different memory stores in this case
self._max_cpu_memory_gb = 0
return torch.device(self._device)
def make_gpu_space(self, bytes_to_free):
import gc
import torch.cuda
mem_free = get_mem_free_total(self.device)
logger.debug(
f"Making {bytes_to_free / (1024 ** 2):.1f} MB of GPU space. current usage: {self.gpu_cache.memory_usage / (1024 ** 2):.1f} MB; free mem: {mem_free / (1024 ** 2):.1f} MB; Max mem: {self.max_gpu_memory / (1024 ** 2):.1f} MB"
)
while self.gpu_cache and (
self.gpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
or self.gpu_cache.memory_usage + bytes_to_free
> get_mem_free_total(self.device)
):
oldest_gpu_key = self.gpu_cache.first_key()
self._move_to_cpu(oldest_gpu_key)
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
if (
self.gpu_cache.memory_usage + bytes_to_free > self.max_gpu_memory
or self.gpu_cache.memory_usage + bytes_to_free
> get_mem_free_total(self.device)
):
raise RuntimeError("Unable to make space on GPU")
def make_cpu_space(self, bytes_to_free):
import gc
import psutil
mem_free = psutil.virtual_memory().available * 0.8
logger.debug(
f"Making {bytes_to_free / (1024 ** 2):.1f} MB of RAM space. current usage: {self.cpu_cache.memory_usage / (1024 ** 2):.2f} MB; free mem: {mem_free / (1024 ** 2):.1f} MB; max mem: {self.max_cpu_memory / (1024 ** 2):.1f} MB"
)
while self.cpu_cache and (
self.cpu_cache.memory_usage + bytes_to_free > self.max_cpu_memory
or self.cpu_cache.memory_usage + bytes_to_free
> psutil.virtual_memory().available * 0.8
):
oldest_cpu_key = self.cpu_cache.first_key()
logger.debug(f"dropping {oldest_cpu_key} from memory")
self.cpu_cache.pop(oldest_cpu_key)
gc.collect()
@cached_property
def max_cpu_memory(self):
_ = self.device
if isinstance(self._max_cpu_memory_gb, str):
if self._max_cpu_memory_gb.endswith("%"):
import psutil
total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
pct_to_use = float(self._max_cpu_memory_gb[:-1]) / 100.0
return total_ram_gb * pct_to_use * (1024**3)
raise ValueError(
f"Invalid value for max_cpu_memory_gb: {self._max_cpu_memory_gb}"
)
return self._max_cpu_memory_gb * (1024**3)
@cached_property
def max_gpu_memory(self):
_ = self.device
if isinstance(self._max_gpu_memory_gb, str):
if self._max_gpu_memory_gb.endswith("%"):
import torch
if self.device.type == "cuda":
device_props = torch.cuda.get_device_properties(0)
total_ram_gb = device_props.total_memory / (1024**3)
else:
import psutil
total_ram_gb = round(psutil.virtual_memory().total / (1024**3), 2)
pct_to_use = float(self._max_gpu_memory_gb[:-1]) / 100.0
return total_ram_gb * pct_to_use * (1024**3)
raise ValueError(
f"Invalid value for max_gpu_memory_gb: {self._max_gpu_memory_gb}"
)
return self._max_gpu_memory_gb * (1024**3)
def _move_to_gpu(self, key, model):
model_size = get_model_size(model)
if self.gpu_cache.memory_usage + model_size > self.max_gpu_memory:
if len(self.gpu_cache) == 0:
msg = f"GPU cache maximum ({self.max_gpu_memory / (1024 ** 2)} MB) is smaller than the item being cached ({model_size / 1024 ** 2} MB)."
raise RuntimeError(msg)
self.make_gpu_space(model_size)
try:
model_size = max(self.cpu_cache._item_memory_usage[key], model_size)
self.cpu_cache.pop(key)
except KeyError:
pass
logger.debug(f"moving {key} to gpu")
move_model_device(model, self.device)
self.gpu_cache.set(key, value=model, memory_usage=model_size)
def _move_to_cpu(self, key):
import torch
import psutil
memory_usage = self.gpu_cache._item_memory_usage[key]
model = self.gpu_cache.pop(key)
model_size = max(get_model_size(model), memory_usage)
self.make_cpu_space(model_size)
if (
self.cpu_cache.memory_usage + model_size < self.max_cpu_memory
and self.cpu_cache.memory_usage + model_size
< psutil.virtual_memory().available * 0.8
):
logger.debug(f"moving {key} to cpu")
move_model_device(model, torch.device("cpu"))
self.cpu_cache.set(key, model, memory_usage=model_size)
else:
logger.debug(f"dropping {key} from memory")
def get(self, key):
import torch
if key not in self:
raise KeyError(f"The key {key} does not exist in the cache")
if key in self.cpu_cache:
if self.device != torch.device("cpu"):
self.cpu_cache.move_to_end(key)
self._move_to_gpu(key, self.cpu_cache[key])
if key in self.gpu_cache:
self.gpu_cache.move_to_end(key)
model = self.gpu_cache.get(key)
return model
def __getitem__(self, key):
return self.get(key)
def set(self, key, model, memory_usage=0):
from torch import nn
if (
hasattr(model, "model") and isinstance(model.model, nn.Module)
) or isinstance(model, nn.Module):
pass
else:
raise ValueError("Only nn.Module objects can be cached")
model_size = max(get_model_size(model), memory_usage)
self.make_gpu_space(model_size)
self._move_to_gpu(key, model)
def __contains__(self, key):
return key in self.gpu_cache or key in self.cpu_cache
def keys(self):
return list(self.cpu_cache.keys()) + list(self.gpu_cache.keys())
def stats(self):
return {
"cpu_cache_count": len(self.cpu_cache),
"cpu_cache_memory_usage": self.cpu_cache.memory_usage,
"cpu_cache_max_memory": self.max_cpu_memory,
"gpu_cache_count": len(self.gpu_cache),
"gpu_cache_memory_usage": self.gpu_cache.memory_usage,
"gpu_cache_max_memory": self.max_gpu_memory,
}
class MemoryManagedModelWrapper:
_mmmw_cache = GPUModelCache()
def __init__(self, fn, namespace, estimated_ram_size_mb, *args, **kwargs):
self._mmmw_fn = fn
self._mmmw_args = args
self._mmmw_kwargs = kwargs
self._mmmw_namespace = namespace
self._mmmw_estimated_ram_size_mb = estimated_ram_size_mb
self._mmmw_cache_key = (namespace,) + args + tuple(kwargs.items())
def _mmmw_load_model(self):
if self._mmmw_cache_key not in self.__class__._mmmw_cache:
logger.debug(f"Loading model: {self._mmmw_cache_key}")
self.__class__._mmmw_cache.make_gpu_space(
self._mmmw_estimated_ram_size_mb * 1024**2
)
free_before = get_mem_free_total(self.__class__._mmmw_cache.device)
model = self._mmmw_fn(*self._mmmw_args, **self._mmmw_kwargs)
move_model_device(model, self.__class__._mmmw_cache.device)
free_after = get_mem_free_total(self.__class__._mmmw_cache.device)
logger.debug(
f"Model loaded: {self._mmmw_cache_key} Used {free_before - free_after}"
)
self.__class__._mmmw_cache.set(
self._mmmw_cache_key,
model,
memory_usage=self._mmmw_estimated_ram_size_mb * 1024**2,
)
model = self.__class__._mmmw_cache[self._mmmw_cache_key]
return model
def __getattr__(self, name):
model = self._mmmw_load_model()
return getattr(model, name)
def __call__(self, *args, **kwargs):
model = self._mmmw_load_model()
return model(*args, **kwargs)
def memory_managed_model(namespace, memory_usage_mb=0):
def decorator(fn):
def wrapper(*args, **kwargs):
return MemoryManagedModelWrapper(
fn, namespace, memory_usage_mb, *args, **kwargs
)
return wrapper
return decorator
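A minimal sketch of how the decorator is meant to be used (the dummy loader below is hypothetical; the real call sites in this commit are realesrgan_upsampler and _load_diffusion_model): the wrapped loader returns a MemoryManagedModelWrapper, and the underlying model is only built and cached on first attribute access or call.

import torch
from torch import nn

from imaginairy.utils.model_cache import memory_managed_model

@memory_managed_model("dummy-model", memory_usage_mb=1)
def load_dummy_model():
    # hypothetical loader; real call sites build RealESRGAN and Stable Diffusion models
    return nn.Linear(8, 8)

model = load_dummy_model()     # returns a wrapper; nothing is loaded yet
out = model(torch.ones(1, 8))  # first use builds the model and places it in the shared cache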

@ -0,0 +1,156 @@
import pytest
import torch
from torch import nn
from imaginairy import ImaginePrompt, imagine
from imaginairy.utils import get_device
from imaginairy.utils.model_cache import GPUModelCache
class DummyMemoryModule(nn.Module):
def __init__(self, in_features):
super().__init__()
self.large_layer = nn.Linear(in_features - 1, 1)
def forward(self, x):
return self.large_layer(x)
def create_model_of_n_bytes(n):
import math
# float32 parameters take 4 bytes each, so a model of ~n bytes needs n/4 parameters
n = int(math.floor(n / 4))
return DummyMemoryModule(n)
@pytest.mark.skip()
@pytest.mark.parametrize(
"model_version",
[
"SD-1.4",
"SD-1.5",
"SD-2.0",
"SD-2.0-v",
"SD-2.1",
"SD-2.1-v",
"openjourney-v1",
"openjourney-v2",
"openjourney-v4",
],
)
def test_memory_usage(filename_base_for_orig_outputs, model_version):
"""Test that we can switch between model versions."""
prompt_text = "valley, fairytale treehouse village covered, , matte painting, highly detailed, dynamic lighting, cinematic, realism, realistic, photo real, sunset, detailed, high contrast, denoised, centered, michael whelan"
prompts = [ImaginePrompt(prompt_text, model=model_version, seed=1, steps=30)]
for i, result in enumerate(imagine(prompts)):
img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png"
result.img.save(img_path)
def test_get_nonexistent():
cache = GPUModelCache(max_cpu_memory_gb=1, max_gpu_memory_gb=1)
with pytest.raises(KeyError):
cache.get("nonexistent_key")
@pytest.mark.skipif(get_device() == "cpu", reason="GPU not available")
def test_set_cpu_full():
cache = GPUModelCache(
max_cpu_memory_gb=0.000000001, max_gpu_memory_gb=0.01, device=get_device()
)
for i in range(4):
cache.set(f"key{i}", create_model_of_n_bytes(4_000_000))
assert len(cache.cpu_cache) == 0
assert len(cache.gpu_cache) == 2
@pytest.mark.skipif(get_device() == "cpu", reason="GPU not available")
def test_set_gpu_full():
cache = GPUModelCache(
max_cpu_memory_gb=1, max_gpu_memory_gb=0.0000001, device=get_device()
)
assert cache.max_cpu_memory == 1073741824
model = create_model_of_n_bytes(100_000)
with pytest.raises(RuntimeError):
cache.set("key1", model)
@pytest.mark.skipif(get_device() == "cpu", reason="GPU not available")
def test_get_existing_cpu():
cache = GPUModelCache(max_cpu_memory_gb=0.1, max_gpu_memory_gb=0.1, device="cpu")
model = create_model_of_n_bytes(10_000)
cache.set("key", model)
retrieved_data = cache.get("key")
assert retrieved_data == model
# assert 'key' in cache.cpu_cache
assert "key" in cache.gpu_cache
@pytest.mark.skipif(get_device() == "cpu", reason="GPU not available")
def test_get_existing_move_to_gpu():
cache = GPUModelCache(
max_cpu_memory_gb=0.1, max_gpu_memory_gb=0.1, device=get_device()
)
model = create_model_of_n_bytes(10_000)
cache.set("key", model)
retrieved_data = cache.get("key")
assert retrieved_data == model
assert "key" not in cache.cpu_cache
assert "key" in cache.gpu_cache
@pytest.mark.skipif(get_device() == "cpu", reason="GPU not available")
def test_cache_ordering():
cache = GPUModelCache(
max_cpu_memory_gb=0.01, max_gpu_memory_gb=0.01, device=get_device()
)
cache.set("key-0", create_model_of_n_bytes(4_000_000))
assert list(cache.cpu_cache.keys()) == [] # noqa
assert list(cache.gpu_cache.keys()) == ["key-0"]
assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
0,
4_000_000,
)
cache.set("key-1", create_model_of_n_bytes(4_000_000))
assert list(cache.cpu_cache.keys()) == [] # noqa
assert list(cache.gpu_cache.keys()) == ["key-0", "key-1"]
assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
0,
8_000_000,
)
cache.set("key-2", create_model_of_n_bytes(4_000_000))
assert list(cache.cpu_cache.keys()) == ["key-0"]
assert list(cache.gpu_cache.keys()) == ["key-1", "key-2"]
assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
4_000_000,
8_000_000,
)
cache.set("key-3", create_model_of_n_bytes(4_000_000))
assert list(cache.cpu_cache.keys()) == ["key-0", "key-1"]
assert list(cache.gpu_cache.keys()) == ["key-2", "key-3"]
assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
8_000_000,
8_000_000,
)
cache.set("key-4", create_model_of_n_bytes(4_000_000))
assert list(cache.cpu_cache.keys()) == ["key-1", "key-2"]
assert list(cache.gpu_cache.keys()) == ["key-3", "key-4"]
assert list(cache.keys()) == ["key-1", "key-2", "key-3", "key-4"]
assert (cache.cpu_cache.memory_usage, cache.gpu_cache.memory_usage) == (
8_000_000,
8_000_000,
)
cache.get("key-2")
assert list(cache.keys()) == ["key-3", "key-4", "key-2"]
cache.set("key-5", create_model_of_n_bytes(9_000_000))
assert list(cache.cpu_cache.keys()) == ["key-4", "key-2"]
assert list(cache.gpu_cache.keys()) == ["key-5"]
assert list(cache.keys()) == ["key-4", "key-2", "key-5"]