2023-12-15 20:31:28 +00:00
""" Classes and functions for managing AI models """
2022-10-24 05:42:17 +00:00
import logging
import os
2023-01-26 06:05:07 +00:00
import re
2022-12-09 09:14:47 +00:00
import sys
2023-01-09 04:45:58 +00:00
import urllib . parse
2023-11-16 03:46:56 +00:00
from functools import lru_cache , wraps
2022-10-24 05:42:17 +00:00
2022-10-23 21:46:45 +00:00
import requests
2022-10-24 05:42:17 +00:00
import torch
2023-09-29 08:13:50 +00:00
from huggingface_hub import (
2023-12-21 22:24:35 +00:00
HfFileSystem ,
2023-09-29 08:13:50 +00:00
HfFolder ,
hf_hub_download as _hf_hub_download ,
try_to_load_from_cache ,
)
2022-10-24 05:42:17 +00:00
from omegaconf import OmegaConf
2023-01-18 06:43:23 +00:00
from safetensors . torch import load_file
2022-10-24 05:42:17 +00:00
2022-11-26 22:52:28 +00:00
from imaginairy import config as iconfig
2023-12-08 04:57:55 +00:00
from imaginairy . config import IMAGE_WEIGHTS_SHORT_NAMES , ModelArchitecture
2022-12-07 18:16:38 +00:00
from imaginairy . modules import attention
2023-12-28 05:52:37 +00:00
from imaginairy . modules . refiners_sd import SDXLAutoencoderSliced , StableDiffusion_XL
from imaginairy . utils import clear_gpu_cache , get_device , instantiate_from_config
2023-05-08 03:37:15 +00:00
from imaginairy . utils . model_cache import memory_managed_model
2023-12-11 16:39:09 +00:00
from imaginairy . utils . named_resolutions import normalize_image_size
2023-12-15 21:40:10 +00:00
from imaginairy . utils . paths import PKG_ROOT
2024-01-03 05:06:39 +00:00
from imaginairy . vendored . refiners . foundationals . clip . text_encoder import (
CLIPTextEncoderL ,
)
from imaginairy . vendored . refiners . foundationals . latent_diffusion import (
DoubleTextEncoder ,
SD1UNet ,
SDXLUNet ,
)
from imaginairy . vendored . refiners . foundationals . latent_diffusion . model import (
LatentDiffusionModel ,
)
2023-12-28 05:52:37 +00:00
from imaginairy . weight_management import translators
2024-01-13 21:43:15 +00:00
from imaginairy . weight_management . translators import (
DoubleTextEncoderTranslator ,
diffusers_autoencoder_kl_to_refiners_translator ,
diffusers_unet_sdxl_to_refiners_translator ,
load_weight_map ,
)
2022-10-24 05:42:17 +00:00
logger = logging.getLogger(__name__)

# Updated by the model loaders in this module; read via get_current_diffusion_model().
MOST_RECENTLY_LOADED_MODEL = None


class HuggingFaceAuthorizationError(RuntimeError):
    """Raised when HuggingFace answers 401 Unauthorized for a model URL."""
2022-10-24 05:42:17 +00:00
2023-05-08 03:37:15 +00:00
def load_state_dict(weights_location, half_mode=False, device=None):
    """
    Load a checkpoint's tensors and place them on ``device``.

    ``weights_location`` may be a local path or an ``http`` URL; URLs are fetched
    through ``get_cached_url_path`` into the "weights" cache category. A
    checkpoint that nests its tensors under a top-level ``"state_dict"`` key is
    unwrapped.

    Args:
        weights_location: local path or URL of the checkpoint.
        half_mode: if true, every tensor is converted with ``.half()``.
        device: target device; defaults to ``get_device()``.

    Returns:
        dict mapping parameter names to tensors on ``device``.

    Exits the process with status 1 when the local path does not exist, after
    logging the preconfigured model short names.
    """
    import errno

    if device is None:
        device = get_device()

    if weights_location.startswith("http"):
        ckpt_path = get_cached_url_path(weights_location, category="weights")
    else:
        ckpt_path = weights_location
    # Log the device the tensors are actually moved to (callers may override it).
    logger.info(f"Loading model {ckpt_path} onto {device} backend...")
    state_dict = None

    try:
        # Deserialize on the CPU first; tensors are moved to `device` below.
        state_dict = load_tensors(ckpt_path, map_location="cpu")
    except FileNotFoundError as e:
        if e.errno == errno.ENOENT:
            logger.error(
                f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {IMAGE_WEIGHTS_SHORT_NAMES}.'
            )
            sys.exit(1)
        raise
    except RuntimeError as e:
        err_str = str(e)
        # A corrupt/truncated download shows up as a zip read failure. Because
        # URL checkpoints come from our own cache, delete it and fetch it again.
        if (
            "PytorchStreamReader failed reading zip archive" in err_str
            and weights_location.startswith("http")
        ):
            logger.warning("Corrupt checkpoint. deleting and re-downloading...")
            os.remove(ckpt_path)
            ckpt_path = get_cached_url_path(weights_location, category="weights")
            state_dict = load_tensors(ckpt_path, map_location="cpu")
        if state_dict is None:
            raise

    # Some checkpoints wrap the weights in a "state_dict" entry.
    state_dict = state_dict.get("state_dict", state_dict)
    if half_mode:
        state_dict = {k: v.half() for k, v in state_dict.items()}
    state_dict = {k: v.to(device) for k, v in state_dict.items()}
    return state_dict
2023-05-08 03:37:15 +00:00
def load_model_from_config(config, weights_location, half_mode=False):
    """
    Instantiate ``config.model`` and initialise it from a checkpoint.

    The weights are read with ``load_state_dict`` (optionally half precision).
    The model is then optionally converted to half precision, moved to the
    default device and switched to eval mode.
    """
    model = instantiate_from_config(config.model)
    model.init_from_state_dict(
        load_state_dict(weights_location, half_mode=half_mode)
    )
    if half_mode:
        model = model.half()
    model.to(get_device())
    model.eval()
    return model
2023-02-12 07:42:19 +00:00
def add_controlnet(base_state_dict, controlnet_state_dict):
    """
    Merge a controlnet state dict into a base sd15 state dict.

    Entries from ``controlnet_state_dict`` are written into ``base_state_dict``
    in place (overwriting existing keys); the same, mutated dict is returned.
    """
    base_state_dict.update(controlnet_state_dict)
    return base_state_dict
2022-10-24 05:42:17 +00:00
def get_diffusion_model(
    weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
    config_path="configs/stable-diffusion-v1.yaml",
    control_weights_locations=None,
    half_mode=None,
    for_inpainting=False,
):
    """
    Load a diffusion model.

    Weights location may also be shortcut name, e.g. "SD-1.5".

    If an inpainting model was requested and loading it raises
    ``HuggingFaceAuthorizationError``, a warning is logged and the default
    (non-inpainting) weights are loaded instead. Authorization errors for
    non-inpainting requests are re-raised.
    """
    extra = {"control_weights_locations": control_weights_locations}
    try:
        return _get_diffusion_model(
            weights_location, config_path, half_mode, for_inpainting, **extra
        )
    except HuggingFaceAuthorizationError as e:
        if not for_inpainting:
            raise
        logger.warning(
            f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}"
        )
        return _get_diffusion_model(
            iconfig.DEFAULT_MODEL_WEIGHTS,
            config_path,
            half_mode,
            for_inpainting=False,
            **extra,
        )
2022-10-23 21:46:45 +00:00
def _get_diffusion_model(
    weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
    model_architecture="configs/stable-diffusion-v1.yaml",
    half_mode=None,
    for_inpainting=False,
    control_weights_locations=None,
):
    """
    Load a diffusion model.

    Weights location may also be shortcut name, e.g. "SD-1.5" (or a
    ``ModelWeightsConfig``). It is resolved with ``resolve_model_weights_config``,
    and the *resolved* weights location is what gets loaded.

    Side effects: sets ``attention.ATTENTION_PRECISION_OVERRIDE`` from the
    resolved config and stores the model in ``MOST_RECENTLY_LOADED_MODEL``.
    Any ``control_weights_locations`` are loaded as controlnets and attached
    with ``set_control_models``.
    """
    global MOST_RECENTLY_LOADED_MODEL

    # resolve_model_weights_config either returns a config or raises.
    model_weights_config = resolve_model_weights_config(
        model_weights=weights_location,
        default_model_architecture=model_architecture,
        for_inpainting=for_inpainting,
    )
    # some models need the attention calculated in float32
    attention.ATTENTION_PRECISION_OVERRIDE = model_weights_config.forced_attn_precision

    diffusion_model = _load_diffusion_model(
        config_path=model_weights_config.architecture.config_path,
        # Use the resolved location: a shortcut such as "SD-1.5" is not a path,
        # and for inpainting the resolved config points at the inpaint weights.
        weights_location=model_weights_config.weights_location,
        half_mode=half_mode,
    )
    MOST_RECENTLY_LOADED_MODEL = diffusion_model

    if control_weights_locations:
        controlnets = [
            load_controlnet(location, half_mode)
            for location in control_weights_locations
        ]
        diffusion_model.set_control_models(controlnets)

    return diffusion_model
2022-10-24 05:42:17 +00:00
2023-05-08 03:37:15 +00:00
2023-11-16 03:46:56 +00:00
def get_diffusion_model_refiners(
    weights_config: iconfig.ModelWeightsConfig,
    for_inpainting=False,
    dtype=None,
) -> LatentDiffusionModel:
    """
    Load a diffusion model.

    The pipeline returned by the memoized ``_get_diffusion_model_refiners`` is
    structurally copied, and self-attention guidance is enabled on the copy
    before it is returned.
    """
    cached_pipeline = _get_diffusion_model_refiners(
        weights_location=weights_config.weights_location,
        architecture_alias=weights_config.architecture.primary_alias,
        for_inpainting=for_inpainting,
        dtype=dtype,
    )
    # ensures a "fresh" copy that doesn't have additional injected parts
    fresh_pipeline = cached_pipeline.structural_copy()
    fresh_pipeline.set_self_attention_guidance(enable=True)
    return fresh_pipeline
2023-11-16 03:46:56 +00:00
2023-12-21 22:24:35 +00:00
# Matches a HuggingFace diffusers repo URL, optionally pinned with /tree/<ref>,
# e.g. https://huggingface.co/<author>/<repo>/tree/main/
# Any single path segment is accepted as the ref: a commit hash, or a branch/tag
# name such as "fp16" or "v1.0-Beta".
hf_repo_url_pattern = re.compile(
    r"https://huggingface\.co/(?P<author>[^/]+)/(?P<repo>[^/]+)(/tree/(?P<ref>[^/]+))?/?$"
)


def parse_diffusers_repo_url(url: str) -> dict[str, str]:
    """
    Split a HuggingFace diffusers repo URL into its parts.

    Returns a dict with keys ``author``, ``repo`` and ``ref``. ``ref`` is ``None``
    when the URL has no ``/tree/<ref>`` part. Returns an empty dict when ``url``
    is not a repo URL.
    """
    match = hf_repo_url_pattern.match(url)
    return match.groupdict() if match else {}


def is_diffusers_repo_url(url: str) -> bool:
    """Return True if ``url`` parses as a HuggingFace diffusers repo URL."""
    result = bool(parse_diffusers_repo_url(url))
    logger.debug(f"{url} is diffusers repo url: {result}")
    return result


def normalize_diffusers_repo_url(url: str) -> str:
    """
    Return ``url`` as ``https://huggingface.co/<author>/<repo>/tree/<ref>/``.

    A missing ref defaults to ``main``. Raises ``KeyError`` when ``url`` is not a
    diffusers repo URL.
    """
    data = parse_diffusers_repo_url(url)
    ref = data["ref"] or "main"
    normalized_url = (
        f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/"
    )
    return normalized_url
2023-11-16 03:46:56 +00:00
@lru_cache(maxsize=1)
def _get_diffusion_model_refiners(
    weights_location: str,
    architecture_alias: str,
    for_inpainting: bool = False,
    device=None,
    dtype=torch.float16,
) -> LatentDiffusionModel:
    """
    Load a diffusion model.

    Weights location may also be shortcut name, e.g. "SD-1.5".

    Memoized with ``lru_cache(maxsize=1)``. When the body runs (a cache miss),
    it first clears this function's cache and calls ``clear_gpu_cache()``, then
    builds an sd15/sd15inpaint or sdxl pipeline. The result is stored in
    ``MOST_RECENTLY_LOADED_MODEL``.

    Raises:
        ValueError: if the architecture's primary alias is not supported.
    """
    global MOST_RECENTLY_LOADED_MODEL

    _get_diffusion_model_refiners.cache_clear()
    clear_gpu_cache()

    arch_alias = iconfig.MODEL_ARCHITECTURE_LOOKUP[architecture_alias].primary_alias
    if arch_alias == "sdxl":
        sd = load_sdxl_pipeline(base_url=weights_location, device=device)
    elif arch_alias in ("sd15", "sd15inpaint"):
        sd = load_sd15_pipeline(
            weights_location=weights_location,
            for_inpainting=for_inpainting,
            device=device,
            dtype=dtype,
        )
    else:
        msg = f"Invalid architecture {arch_alias}"
        raise ValueError(msg)
    MOST_RECENTLY_LOADED_MODEL = sd

    described = (
        ("sd", sd),
        ("sd.unet", sd.unet),
        ("sd.lda", sd.lda),
        ("sd.clip_text_encoder", sd.clip_text_encoder),
    )
    msg = "Pipeline loaded " + "".join(
        f"{label}[dtype:{part.dtype} device:{part.device}] "
        for label, part in described
    )
    logger.debug(msg)
    return sd
2024-01-02 02:35:14 +00:00
# new
2024-01-13 21:43:15 +00:00
def load_sd15_pipeline(
    weights_location: str,
    for_inpainting: bool = False,
    device=None,
    dtype=torch.float16,
) -> LatentDiffusionModel:
    """
    Load a diffusion model.

    Weights location may also be shortcut name, e.g. "SD-1.5".

    Diffusers repo URLs are loaded with ``load_sd15_diffusers_weights`` (on the
    CPU); anything else goes through ``load_stable_diffusion_compvis_weights``.
    The UNet is built with 9 input channels for inpainting and 4 otherwise. The
    assembled pipeline is moved to ``device``/``dtype`` before it is returned.
    """
    from imaginairy.modules.refiners_sd import (
        SD1AutoencoderSliced,
        StableDiffusion_1,
        StableDiffusion_1_Inpainting,
    )

    device = device or get_device()

    if is_diffusers_repo_url(weights_location):
        loaded = load_sd15_diffusers_weights(weights_location, device="cpu")
    else:
        loaded = load_stable_diffusion_compvis_weights(weights_location)
    vae_weights, unet_weights, text_encoder_weights = loaded

    StableDiffusionCls: type[LatentDiffusionModel]
    if for_inpainting:
        StableDiffusionCls, in_channels = StableDiffusion_1_Inpainting, 9
    else:
        StableDiffusionCls, in_channels = StableDiffusion_1, 4
    unet = SD1UNet(in_channels=in_channels, device="cpu", dtype=dtype)

    logger.debug(f"Using class {StableDiffusionCls.__name__}")
    sd = StableDiffusionCls(
        device=device, dtype=dtype, lda=SD1AutoencoderSliced(), unet=unet
    )

    # (log label, target module, weights, extra load_state_dict kwargs)
    for label, module, weights, extra in (
        ("VAE", sd.lda, vae_weights, {}),
        ("text encoder", sd.clip_text_encoder, text_encoder_weights, {}),
        ("UNet", sd.unet, unet_weights, {"strict": False}),
    ):
        logger.debug(f"Loading {label}")
        module.load_state_dict(weights, assign=True, **extra)

    logger.debug(f"'{weights_location}' Loaded")
    sd.to(device=device, dtype=dtype)
    return sd
def _get_sd15_diffusion_model_refiners_new(
    weights_location: str,
    for_inpainting: bool = False,
    device=None,
    dtype=torch.float16,
) -> LatentDiffusionModel:
    """
    Load a diffusion model.

    Weights location may also be shortcut name, e.g. "SD-1.5".

    Each component (UNet, VAE, CLIP text encoder) is built, loaded from its
    weights and moved to ``device``/``dtype`` individually. All three are then
    handed to the pipeline class.
    """
    from imaginairy.modules.refiners_sd import (
        SD1AutoencoderSliced,
        StableDiffusion_1,
        StableDiffusion_1_Inpainting,
    )

    device = device or get_device()

    if is_diffusers_repo_url(weights_location):
        (
            vae_weights,
            unet_weights,
            text_encoder_weights,
        ) = load_sd15_diffusers_weights(weights_location, device="cpu")
    else:
        (
            vae_weights,
            unet_weights,
            text_encoder_weights,
        ) = load_stable_diffusion_compvis_weights(weights_location)

    StableDiffusionCls: type[LatentDiffusionModel]
    if for_inpainting:
        unet = SD1UNet(in_channels=9, device="cpu", dtype=dtype)
        StableDiffusionCls = StableDiffusion_1_Inpainting
    else:
        unet = SD1UNet(in_channels=4, device="cpu", dtype=dtype)
        StableDiffusionCls = StableDiffusion_1

    logger.debug("Loading UNet")
    unet.load_state_dict(unet_weights, strict=False, assign=True)
    del unet_weights
    unet.to(device=device, dtype=dtype)

    logger.debug("Loading VAE")
    lda = SD1AutoencoderSliced(device=device, dtype=dtype)
    lda.load_state_dict(vae_weights, assign=True)
    del vae_weights
    lda.to(device=device, dtype=dtype)

    logger.debug("Loading text encoder")
    clip_text_encoder = CLIPTextEncoderL()
    clip_text_encoder.load_state_dict(text_encoder_weights, assign=True)
    del text_encoder_weights
    clip_text_encoder.to(device=device, dtype=dtype)

    logger.debug(f"Using class {StableDiffusionCls.__name__}")
    # Pass the loaded text encoder in too; otherwise the pipeline would build its
    # own encoder and the weights loaded above would be thrown away.
    sd = StableDiffusionCls(  # type: ignore
        device=None,
        dtype=dtype,
        lda=lda,
        unet=unet,
        clip_text_encoder=clip_text_encoder,
    )
    sd.to(device=device, dtype=dtype)

    logger.debug(f"'{weights_location}' Loaded")
    return sd
2023-05-08 03:37:15 +00:00
@memory_managed_model("stable-diffusion", memory_usage_mb=1951)
def _load_diffusion_model(config_path, weights_location, half_mode):
    """
    Build a model from the YAML config ``PKG_ROOT/config_path`` and load weights.

    Args:
        config_path: config path relative to ``PKG_ROOT``.
        weights_location: local path or URL of the checkpoint.
        half_mode: ``None`` means "use half precision by default". ``True``
            requests it and ``False`` disables it. In every case half precision
            is only used when the device is CUDA.
    """
    model_config = OmegaConf.load(f"{PKG_ROOT}/{config_path}")
    # only run half-mode on cuda. run it by default.
    # An explicit half_mode=True must be honoured (on cuda) rather than being
    # collapsed to False, which is what `half_mode is None and ...` did.
    wants_half = True if half_mode is None else bool(half_mode)
    half_mode = wants_half and get_device() == "cuda"
    model = load_model_from_config(
        config=model_config,
        weights_location=weights_location,
        half_mode=half_mode,
    )
    return model
2022-10-23 21:46:45 +00:00
2023-05-08 03:37:15 +00:00
@memory_managed_model("controlnet")
def load_controlnet(control_weights_location, half_mode):
    """
    Build a controlnet from the bundled control-net-v15 config and load weights.

    Every ``"control_model."`` substring is removed from the checkpoint keys
    before they are loaded. The module is then moved to the default device.
    """
    raw_state = load_state_dict(control_weights_location, half_mode=half_mode)
    stripped_state = {
        key.replace("control_model.", ""): tensor for key, tensor in raw_state.items()
    }
    full_config = OmegaConf.load(f"{PKG_ROOT}/configs/control-net-v15.yaml")
    stage_config = full_config["model"]["params"]["control_stage_config"]
    controlnet = instantiate_from_config(stage_config)
    controlnet.load_state_dict(stripped_state, assign=True)
    controlnet.to(get_device())
    return controlnet
2023-12-08 04:57:55 +00:00
def resolve_model_weights_config(
    model_weights: str | iconfig.ModelWeightsConfig,
    default_model_architecture: str | None = None,
    for_inpainting: bool = False,
) -> iconfig.ModelWeightsConfig:
    """
    Resolve weight and config path if they happen to be shortcuts.

    Resolution order:
      1. a ``ModelWeightsConfig`` is returned unchanged;
      2. ``model_weights`` (lower-cased) is looked up in
         ``MODEL_WEIGHT_CONFIG_LOOKUP``, trying the ``-inpaint`` variant first
         when ``for_inpainting`` is set;
      3. otherwise the weights are treated as custom weights. The architecture
         (lower-cased, again ``-inpaint`` first when inpainting) is looked up in
         ``MODEL_ARCHITECTURE_LOOKUP`` and wrapped in a new
         ``ModelWeightsConfig``.

    Raises:
        ValueError: for non-string inputs, a missing architecture for custom
            weights, or an unknown architecture.
    """
    if isinstance(model_weights, iconfig.ModelWeightsConfig):
        return model_weights
    if not isinstance(model_weights, str):
        msg = f"Invalid model weights: {model_weights}"
        raise ValueError(msg)  # noqa
    if default_model_architecture is not None and not isinstance(
        default_model_architecture, str
    ):
        msg = f"Invalid model architecture: {default_model_architecture}"
        raise ValueError(msg)

    def _candidate_keys(name):
        # The inpainting variant takes precedence when inpainting is requested.
        return [f"{name}-inpaint", name] if for_inpainting else [name]

    weights_lookup = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP
    for key in _candidate_keys(model_weights.lower()):
        known_config = weights_lookup.get(key, None)
        if known_config:
            return known_config

    if not default_model_architecture:
        msg = "You must specify the model architecture when loading custom weights."
        raise ValueError(msg)
    default_model_architecture = default_model_architecture.lower()

    architecture_lookup = iconfig.MODEL_ARCHITECTURE_LOOKUP
    model_architecture_config = None
    for key in _candidate_keys(default_model_architecture):
        model_architecture_config = architecture_lookup.get(key, None)
        if model_architecture_config:
            break
    if model_architecture_config is None:
        msg = f"Invalid model architecture: {default_model_architecture}"
        raise ValueError(msg)

    return iconfig.ModelWeightsConfig(
        name="Custom Loaded",
        aliases=[],
        architecture=model_architecture_config,
        weights_location=model_weights,
        defaults={},
    )
2023-12-12 06:29:36 +00:00
def get_model_default_image_size(model_architecture: str | ModelArchitecture | None):
    """
    Return the normalized default image size for a model architecture.

    ``model_architecture`` may be an architecture object, an alias string, or
    ``None``. Alias strings are looked up exactly first, then lower-cased. This
    matches ``resolve_model_weights_config``, which lower-cases architecture
    names before looking them up. Unknown architectures, or ones without a
    ``"size"`` default, fall back to 512. The result goes through
    ``normalize_image_size``.
    """
    if isinstance(model_architecture, str):
        lookup = iconfig.MODEL_ARCHITECTURE_LOOKUP
        resolved = lookup.get(model_architecture, None)
        if resolved is None:
            resolved = lookup.get(model_architecture.lower(), None)
        model_architecture = resolved

    default_size = None
    if model_architecture:
        default_size = model_architecture.defaults.get("size")

    if default_size is None:
        default_size = 512
    return normalize_image_size(default_size)
2022-11-26 22:52:28 +00:00
2022-10-23 21:46:45 +00:00
def get_current_diffusion_model():
    """Return the model most recently recorded by this module's loaders, or None."""
    current_model = MOST_RECENTLY_LOADED_MODEL
    return current_model
def get_cache_dir():
    """
    Return the directory used for imaginAIry's download cache.

    Follows the XDG Base Directory convention: ``$XDG_CACHE_HOME/imaginairy``,
    where an unset *or empty* ``XDG_CACHE_HOME`` falls back to
    ``$HOME/.cache``. If ``HOME`` is also unset or empty, a ``.cached-aimg``
    directory next to this module is used.
    """
    # Per the XDG spec an empty value is treated like an unset one; otherwise
    # os.path.join("", "imaginairy") would yield a cwd-relative path.
    cache_home = os.getenv("XDG_CACHE_HOME")
    if not cache_home:
        user_home = os.getenv("HOME")
        cache_home = os.path.join(user_home, ".cache") if user_home else None
    if cache_home:
        return os.path.join(cache_home, "imaginairy")

    return os.path.join(os.path.dirname(__file__), ".cached-aimg")
2022-10-23 21:46:45 +00:00
2023-01-01 22:54:49 +00:00
def get_cached_url_path(url, category=None):
    """
    Gets the contents of a url, but caches the response indefinitely.

    HuggingFace URLs are first resolved through ``huggingface_cached_path``,
    which checks the local HuggingFace cache before making any request. If that
    fails with ``OSError``/``ValueError``, or for any other URL, the file is
    downloaded into ``get_cache_dir()[/category]`` and that path is returned on
    later calls without any network access.

    The download is streamed in chunks instead of being held in memory. Non-2xx
    responses raise ``requests.HTTPError`` instead of being cached. Data is
    written to a ``.partial`` file that only replaces the final path once
    complete, so an interrupted download is never treated as cached.
    """
    try:
        if url.startswith("https://huggingface.co"):
            return huggingface_cached_path(url)
    except (OSError, ValueError):
        # Fall back to the plain HTTP download below.
        pass

    filename = url.split("/")[-1]
    dest = get_cache_dir()
    if category:
        dest = os.path.join(dest, category)
    os.makedirs(dest, exist_ok=True)

    # Replace possibly illegal destination path characters
    safe_filename = re.sub('[*<>:"|?]', "_", filename)
    dest_path = os.path.join(dest, safe_filename)

    if os.path.exists(dest_path):
        return dest_path

    # check if it's saved at previous path and rename it
    old_dest_path = os.path.join(dest, filename)
    if os.path.exists(old_dest_path):
        os.rename(old_dest_path, dest_path)
        return dest_path

    partial_path = f"{dest_path}.partial"
    # The timeout bounds connect/read stalls, not total download time.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    os.replace(partial_path, dest_path)
    return dest_path
def check_huggingface_url_authorized(url):
    """
    Probe a HuggingFace URL with a HEAD request and fail early on 401.

    URLs that do not start with ``https://huggingface.co/`` are ignored. The
    stored HuggingFace token, if any, is sent as a bearer token. Always returns
    ``None`` when no error is raised.

    Raises:
        HuggingFaceAuthorizationError: if the response status is 401.
    """
    if url.startswith("https://huggingface.co/"):
        token = HfFolder.get_token()
        headers = {} if token is None else {"authorization": f"Bearer {token}"}
        response = requests.head(url, allow_redirects=True, headers=headers, timeout=5)
        if response.status_code == 401:
            msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
            raise HuggingFaceAuthorizationError(msg)
    return None
2023-01-01 22:54:49 +00:00
@wraps(_hf_hub_download)
def hf_hub_download(*args, **kwargs):
    """
    Call huggingface_hub's ``hf_hub_download``, tolerating older releases.

    Newer releases take the credential as ``token``; older ones only accept
    ``use_auth_token``. If the installed version rejects ``token`` with a
    ``TypeError``, the call is retried with that value passed as
    ``use_auth_token``. Any other ``TypeError`` is re-raised.
    """
    try:
        return _hf_hub_download(*args, **kwargs)
    except TypeError as e:
        if "unexpected keyword argument 'token'" not in str(e):
            raise
        kwargs["use_auth_token"] = kwargs.pop("token")
        return _hf_hub_download(*args, **kwargs)
2023-01-01 22:54:49 +00:00
2022-10-23 21:46:45 +00:00
def huggingface_cached_path(url):
    """
    Return a local file path for a HuggingFace ``.../resolve/<revision>/<file>`` URL.

    The local HuggingFace cache is consulted first, so no HEAD request is made
    for files already downloaded. Otherwise access is checked with
    ``check_huggingface_url_authorized`` and the file is fetched with
    ``hf_hub_download`` using the stored token.

    Raises:
        HuggingFaceAuthorizationError: if HuggingFace rejects the request (401).
    """
    # bypass all the HEAD calls done by the default `cached_path`
    repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url)
    dest_path = try_to_load_from_cache(
        repo_id=repo, revision=commit_hash, filename=filepath
    )
    # A cache hit is a str path. None (not cached) and huggingface_hub's
    # "cached as non-existent" sentinel object both mean we must download.
    if isinstance(dest_path, str):
        return dest_path

    check_huggingface_url_authorized(url)
    token = HfFolder.get_token()
    logger.info(f"Downloading {url} from huggingface")
    dest_path = hf_hub_download(
        repo_id=repo, revision=commit_hash, filename=filepath, token=token
    )
    # make a refs folder so caching works
    # work-around for
    # https://github.com/huggingface/huggingface_hub/pull/1306
    # https://github.com/brycedrennan/imaginAIry/issues/171
    # Match the separator-aware marker so Windows paths work too.
    snapshots_marker = f"{os.sep}snapshots{os.sep}"
    normalized_path = os.path.normpath(dest_path)
    if snapshots_marker in normalized_path:
        repo_cache_dir = normalized_path[: normalized_path.index(snapshots_marker)]
        os.makedirs(os.path.join(repo_cache_dir, "refs"), exist_ok=True)
    return dest_path
2023-01-09 04:45:58 +00:00
def extract_huggingface_repo_commit_file_from_url(url):
    """
    Split a HuggingFace file URL into ``(repo_id, revision, filepath)``.

    Expects ``https://huggingface.co/<author>/<repo>/resolve/<revision>/<path>``.
    For example, ``.../runwayml/stable-diffusion-v1-5/resolve/main/model.ckpt``
    gives ``("runwayml/stable-diffusion-v1-5", "main", "model.ckpt")``.

    Raises:
        ValueError: if the URL path does not have that shape. A real exception
            (not ``assert``) is used so the check survives ``python -O``, and
            ``get_cached_url_path`` catches ``ValueError`` to fall back to a
            plain download.
    """
    parsed_url = urllib.parse.urlparse(url)
    path_components = parsed_url.path.strip("/").split("/")
    if len(path_components) < 5 or path_components[2] != "resolve":
        msg = f"Not a HuggingFace file URL (expected .../<repo>/resolve/<revision>/<file>): {url}"
        raise ValueError(msg)
    repo = "/".join(path_components[0:2])
    commit_hash = path_components[3]
    filepath = "/".join(path_components[4:])
    return repo, commit_hash, filepath
2023-11-16 03:46:56 +00:00
2023-12-28 05:52:37 +00:00
def download_diffusers_weights(base_url, sub, filename=None, prefer_fp16=True):
    """
    Download one component's weights file from a diffusers repo.

    ``base_url`` is expected in normalized form
    (``https://huggingface.co/<author>/<repo>/tree/<ref>/``, trailing slash
    included), because the file URL is built as ``{base_url}{sub}/{filename}``
    with ``/tree/`` swapped for ``/resolve/``. When ``filename`` is omitted, the
    files under ``<repo>/<sub>`` are listed on the hub and one is picked with
    ``choose_diffusers_weights``.

    Returns the local cached path of the file.

    Raises:
        ValueError: if no weights file can be found under ``sub``.
    """
    if filename is None:
        repo_info = parse_diffusers_repo_url(base_url)
        listing = HfFileSystem().ls(
            f"{repo_info['author']}/{repo_info['repo']}/{sub}",
            revision=repo_info["ref"],
            detail=False,
        )
        chosen = choose_diffusers_weights(listing, prefer_fp16=prefer_fp16)
        if not chosen:
            msg = f"Could not find any weights in {base_url}/{sub}"
            raise ValueError(msg)
        filename = chosen.rsplit("/", 1)[-1]
    file_url = f"{base_url}{sub}/{filename}".replace("/tree/", "/resolve/")
    return get_cached_url_path(file_url, category="weights")
2023-12-28 05:52:37 +00:00
def choose_diffusers_weights(filenames, prefer_fp16=True):
    """
    Pick the best weights file out of a file listing.

    Only names ending in ``.safetensors``, ``.bin``, ``.pth`` or ``.pt`` are
    considered. Candidates are ranked first by whether "fp16" appears in the
    path (preferred when ``prefer_fp16`` is true, avoided otherwise), then by
    extension in the order above. Ties keep listing order. Returns ``None`` when
    nothing qualifies.
    """
    extension_priority = (".safetensors", ".bin", ".pth", ".pt")

    def _rank(path):
        has_fp16 = "fp16" in path
        fp16_rank = (not has_fp16) if prefer_fp16 else has_fp16
        return fp16_rank, extension_priority.index(os.path.splitext(path)[1])

    candidates = [path for path in filenames if path.endswith(extension_priority)]
    # min() returns the first minimal item, matching a stable sort's first entry.
    return min(candidates, key=_rank) if candidates else None
2023-11-16 03:46:56 +00:00
2023-12-28 05:52:37 +00:00
def load_sd15_diffusers_weights(base_url: str, device=None):
    """
    Download SD1.5 diffusers weights and convert them to refiners format.

    The vae, unet and text_encoder weights are each downloaded from the
    (normalized) repo URL and opened on ``device`` (default ``get_device()``).
    Each is then converted from the diffusers to the refiners format with
    ``cast_weights``.

    Returns ``(vae_weights, unet_weights, text_encoder_weights)``.
    """
    from imaginairy.utils import get_device
    from imaginairy.weight_management.conversion import cast_weights
    from imaginairy.weight_management.utils import (
        COMPONENT_NAMES,
        FORMAT_NAMES,
        MODEL_NAMES,
    )

    base_url = normalize_diffusers_repo_url(base_url)
    if device is None:
        device = get_device()

    def _fetch_component(sub, component_name):
        # download -> open on `device` -> convert diffusers -> refiners
        weights_path = download_diffusers_weights(base_url=base_url, sub=sub)
        return cast_weights(
            source_weights=open_weights(weights_path, device=device),
            source_model_name=MODEL_NAMES.SD15,
            source_component_name=component_name,
            source_format=FORMAT_NAMES.DIFFUSERS,
            dest_format=FORMAT_NAMES.REFINERS,
        )

    vae_weights = _fetch_component("vae", COMPONENT_NAMES.VAE)
    unet_weights = _fetch_component("unet", COMPONENT_NAMES.UNET)
    text_encoder_weights = _fetch_component(
        "text_encoder", COMPONENT_NAMES.TEXT_ENCODER
    )

    # Log dtype/device of the first tensor of each component.
    summary = []
    for label, weights in (
        ("vae", vae_weights),
        ("unet", unet_weights),
        ("text_encoder", text_encoder_weights),
    ):
        sample = next(iter(weights.values()))
        summary.append(
            f"{label} weights. dtype: {sample.dtype} device: {sample.device}\n"
        )
    logger.debug("".join(summary))
    return vae_weights, unet_weights, text_encoder_weights
2024-01-13 21:43:15 +00:00
def load_sdxl_pipeline_from_diffusers_weights(
    base_url: str, device=None, dtype=torch.float16
):
    """
    Build a ``StableDiffusion_XL`` pipeline from a diffusers-format repo.

    Each component is downloaded, translated to refiners naming and loaded on
    the CPU, then moved to ``device`` (default ``get_device()``):

    - the VAE downloads non-fp16 files and is cast to float32;
    - the UNet prefers fp16 files and is built with ``dtype``;
    - the two text encoders are built in float32.
    """
    from imaginairy.utils import get_device

    device = device or get_device()
    base_url = normalize_diffusers_repo_url(base_url)

    # --- VAE ---
    vae_translator = diffusers_autoencoder_kl_to_refiners_translator()
    vae_path = download_diffusers_weights(
        base_url=base_url, sub="vae", prefer_fp16=False
    )
    logger.debug(f"vae: {vae_path}")
    vae_state = vae_translator.load_and_translate_weights(
        source_path=vae_path, device="cpu"
    )
    lda = SDXLAutoencoderSliced(device="cpu", dtype=dtype)
    lda.load_state_dict(vae_state, assign=True)
    del vae_state

    # --- UNet ---
    unet_translator = diffusers_unet_sdxl_to_refiners_translator()
    unet_path = download_diffusers_weights(
        base_url=base_url, sub="unet", prefer_fp16=True
    )
    logger.debug(f"unet: {unet_path}")
    unet_state = unet_translator.load_and_translate_weights(
        source_path=unet_path, device="cpu"
    )
    unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
    unet.load_state_dict(unet_state, assign=True)
    del unet_state

    # --- text encoders (L and G) ---
    encoder_l_path = download_diffusers_weights(base_url=base_url, sub="text_encoder")
    encoder_g_path = download_diffusers_weights(
        base_url=base_url, sub="text_encoder_2"
    )
    logger.debug(f"text encoder 1: {encoder_l_path}")
    logger.debug(f"text encoder 2: {encoder_g_path}")
    encoder_state = DoubleTextEncoderTranslator().load_and_translate_weights(
        text_encoder_l_weights_path=encoder_l_path,
        text_encoder_g_weights_path=encoder_g_path,
        device="cpu",
    )
    text_encoder = DoubleTextEncoder(device="cpu", dtype=torch.float32)
    text_encoder.load_state_dict(encoder_state, assign=True)
    del encoder_state

    lda = lda.to(device=device, dtype=torch.float32)
    unet = unet.to(device=device)
    text_encoder = text_encoder.to(device=device)
    return StableDiffusion_XL(
        device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
    )
2024-01-13 21:43:15 +00:00
def load_sdxl_pipeline_from_compvis_weights(
    base_url: str, device=None, dtype=torch.float16
):
    """Build an SDXL pipeline from a single CompVis-format checkpoint.

    ``load_sdxl_compvis_weights`` splits the checkpoint into refiners-format
    state dicts. Each component is constructed and loaded on the CPU, then
    moved to the target device.

    Args:
        base_url: Location of the CompVis SDXL checkpoint.
        device: Target device; ``None`` (or any falsy value) means
            ``get_device()``.
        dtype: dtype used to construct the autoencoder and the UNet. The text
            encoder is always built in float32, and the autoencoder is cast to
            float32 when it is moved to the target device.

    Returns:
        A ``StableDiffusion_XL`` pipeline whose components live on the target
        device.
    """
    target_device = device or get_device()

    unet_sd, vae_sd, text_encoder_sd = load_sdxl_compvis_weights(base_url)

    # Everything is materialised on the CPU first. assign=True lets the module
    # adopt the translated tensors instead of copying them into its own params.
    autoencoder = SDXLAutoencoderSliced(device="cpu", dtype=dtype)
    autoencoder.load_state_dict(vae_sd, assign=True)
    del vae_sd

    denoiser = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
    denoiser.load_state_dict(unet_sd, assign=True)
    del unet_sd

    # The text encoder is kept in float32 regardless of ``dtype``.
    clip = DoubleTextEncoder(device="cpu", dtype=torch.float32)
    clip.load_state_dict(text_encoder_sd, assign=True)
    del text_encoder_sd

    # Keyword arguments are evaluated left to right, so the components are
    # moved to the device in the order autoencoder, UNet, text encoder.
    return StableDiffusion_XL(
        device=target_device,
        dtype=None,
        lda=autoencoder.to(device=target_device, dtype=torch.float32),
        unet=denoiser.to(device=target_device),
        clip_text_encoder=clip.to(device=target_device),
    )
2024-01-13 21:43:15 +00:00
def load_sdxl_pipeline(base_url, device=None):
    """Load an SDXL pipeline, picking the loader from the weights' layout.

    URLs recognised by ``is_diffusers_repo_url`` are loaded with
    ``load_sdxl_pipeline_from_diffusers_weights``. Anything else is treated as
    a single CompVis checkpoint. The whole load is wrapped in a timed log
    message.

    Args:
        base_url: Diffusers repository URL or CompVis checkpoint location.
        device: Target device; ``None`` (or any falsy value) means
            ``get_device()``.

    Returns:
        The loaded ``StableDiffusion_XL`` pipeline.
    """
    target_device = device or get_device()
    with logger.timed_info(f"Loaded SDXL pipeline from {base_url}"):
        loader = (
            load_sdxl_pipeline_from_diffusers_weights
            if is_diffusers_repo_url(base_url)
            else load_sdxl_pipeline_from_compvis_weights
        )
        pipeline = loader(base_url, device=target_device)
    return pipeline
2023-11-16 03:46:56 +00:00
def open_weights(filepath, device=None):
    """Load a weights file into a ``{name: tensor}`` state dict.

    Files whose path contains ``"safetensor"`` (case-insensitive) are read with
    the vendored ``safe_open``. Any other file is read with ``torch.load``.
    Checkpoints that wrap their weights in one or more nested
    ``{"state_dict": {...}}`` dicts are unwrapped.

    Args:
        filepath: Path to the weights file, as ``str`` or ``os.PathLike``.
        device: Device the tensors are loaded onto; ``None`` means
            ``get_device()``.

    Returns:
        The unwrapped state dict.
    """
    if device is None:
        device = get_device()
    # Accept pathlib.Path and other PathLike objects, not just str.
    path = os.fspath(filepath)
    if "safetensor" in path.lower():
        from imaginairy.vendored.refiners.fluxion.utils import safe_open

        with safe_open(path=path, framework="pytorch", device=device) as tensors:
            state_dict = {
                key: tensors.get_tensor(key)
                for key in tensors.keys()  # noqa
            }
    else:
        # NOTE(review): torch.load unpickles the file, which can execute code.
        # Only open checkpoints from sources you trust.
        state_dict = torch.load(path, map_location=device)
    # Only descend while the wrapped value is itself a dict. That way a tensor
    # stored under the name "state_dict" is not mistaken for a wrapper, which
    # would make the next membership test raise.
    while isinstance(state_dict, dict) and isinstance(
        state_dict.get("state_dict"), dict
    ):
        state_dict = state_dict["state_dict"]
    return state_dict
2023-12-28 05:52:37 +00:00
def load_tensors(tensorfile, map_location=None):
    """Load a tensor checkpoint file.

    ``.ckpt``, ``.pth`` and ``.bin`` files are read with ``torch.load``. Every
    other file, including ``.safetensors``, is read with safetensors'
    ``load_file``. The literal value ``"empty"`` returns ``{}`` without
    touching disk.

    Args:
        tensorfile: Path to the file (``str`` or ``os.PathLike``), or
            ``"empty"``.
        map_location: Forwarded as ``map_location`` to ``torch.load`` and as
            ``device`` to ``load_file``.

    Returns:
        Whatever the selected loader returns.
    """
    if tensorfile == "empty":
        # used for testing
        return {}
    path = os.fspath(tensorfile)
    if path.endswith((".ckpt", ".pth", ".bin")):
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        return torch.load(path, map_location=map_location)
    # Unknown extensions are assumed to be safetensors instead of raising.
    return load_file(path, device=map_location)
2023-11-16 03:46:56 +00:00
def load_stable_diffusion_compvis_weights(weights_url):
    """Split an SD1.5 CompVis checkpoint into refiners-format component weights.

    The checkpoint is resolved with ``get_cached_url_path`` and opened onto
    ``get_device()``. Each component is selected by its CompVis key prefix.
    The prefix is stripped from the keys, and the weights are converted with
    ``cast_weights`` in two hops: CompVis -> Diffusers -> Refiners.

    Args:
        weights_url: Location of the SD1.5 CompVis checkpoint.

    Returns:
        ``(vae_state_dict, unet_state_dict, text_encoder_state_dict)``.
    """
    from imaginairy.weight_management.conversion import cast_weights
    from imaginairy.weight_management.utils import (
        COMPONENT_NAMES,
        FORMAT_NAMES,
        MODEL_NAMES,
    )

    weights_path = get_cached_url_path(weights_url, category="weights")
    logger.info(f"Loading weights from {weights_path}")
    checkpoint = open_weights(weights_path, device=get_device())

    conversion_hops = (
        (FORMAT_NAMES.COMPVIS, FORMAT_NAMES.DIFFUSERS),
        (FORMAT_NAMES.DIFFUSERS, FORMAT_NAMES.REFINERS),
    )

    def to_refiners(prefix, component_name):
        # Keep only this component's entries, with the CompVis prefix removed.
        component = {
            key[len(prefix) :]: tensor
            for key, tensor in checkpoint.items()
            if key.startswith(prefix)
        }
        for src_format, dst_format in conversion_hops:
            component = cast_weights(
                source_weights=component,
                source_model_name=MODEL_NAMES.SD15,
                source_component_name=component_name,
                source_format=src_format,
                dest_format=dst_format,
            )
        return component

    text_encoder_weights = to_refiners(
        "cond_stage_model.", COMPONENT_NAMES.TEXT_ENCODER
    )
    vae_weights = to_refiners("first_stage_model.", COMPONENT_NAMES.VAE)
    unet_weights = to_refiners("model.", COMPONENT_NAMES.UNET)
    return vae_weights, unet_weights, text_encoder_weights
2024-01-13 21:43:15 +00:00
def load_sdxl_compvis_weights(url):
    """Load a single-file SDXL CompVis safetensors checkpoint as refiners weights.

    Keys are routed by CompVis prefix to the UNet, the VAE and the two text
    encoders. Each group is translated CompVis -> Diffusers with the bundled
    weight maps, then Diffusers -> Refiners. Keys that match no known prefix
    are logged and skipped without reading their tensors.

    Args:
        url: Checkpoint location; resolved to a local file with
            ``get_cached_url_path`` and opened with ``safetensors.safe_open``.

    Returns:
        ``(unet_state_dict, vae_state_dict, text_encoder_state_dict)`` in
        refiners format.
    """
    from safetensors import safe_open

    weights_path = get_cached_url_path(url)
    unet_state_dict = {}
    vae_state_dict = {}
    text_encoder_1_state_dict = {}
    text_encoder_2_state_dict = {}
    with safe_open(weights_path, framework="pt") as f:
        for key in f.keys():  # noqa
            if key.startswith("model.diffusion_model."):
                unet_state_dict[key] = f.get_tensor(key)
            elif key.startswith("first_stage_model"):
                vae_state_dict[key] = f.get_tensor(key)
            elif key.startswith("conditioner.embedders.0."):
                text_encoder_1_state_dict[key] = f.get_tensor(key)
            elif key.startswith("conditioner.embedders.1."):
                text_encoder_2_state_dict[key] = f.get_tensor(key)
            else:
                # Nothing downstream uses these keys, so don't read their
                # tensors into memory; just report them.
                logger.warning(f"Unused key {key}")

    unet_weightmap = load_weight_map("Compvis-UNet-SDXL-to-Diffusers")
    vae_weightmap = load_weight_map("Compvis-Autoencoder-SDXL-to-Diffusers")
    text_encoder_1_weightmap = load_weight_map("Compvis-TextEncoder-SDXL-to-Diffusers")
    text_encoder_2_weightmap = load_weight_map(
        "Compvis-OpenClipTextEncoder-SDXL-to-Diffusers"
    )

    diffusers_unet_state_dict = unet_weightmap.translate_weights(unet_state_dict)
    refiners_unet_state_dict = (
        diffusers_unet_sdxl_to_refiners_translator().translate_weights(
            diffusers_unet_state_dict
        )
    )

    diffusers_vae_state_dict = vae_weightmap.translate_weights(vae_state_dict)
    refiners_vae_state_dict = (
        diffusers_autoencoder_kl_to_refiners_translator().translate_weights(
            diffusers_vae_state_dict
        )
    )

    diffusers_text_encoder_1_state_dict = text_encoder_1_weightmap.translate_weights(
        text_encoder_1_state_dict
    )

    # The second (OpenCLIP) text encoder stores fused "in_proj" tensors. Split
    # each one into three equal chunks along dim 0 (q, k, v), stored under the
    # suffixes .0/.1/.2 so the weight map can address them individually.
    for key in list(text_encoder_2_state_dict):
        if key.endswith((".in_proj_bias", ".in_proj_weight")):
            fused = text_encoder_2_state_dict.pop(key)
            for index, part in enumerate(fused.chunk(3, dim=0)):
                text_encoder_2_state_dict[f"{key}.{index}"] = part
    diffusers_text_encoder_2_state_dict = text_encoder_2_weightmap.translate_weights(
        text_encoder_2_state_dict
    )

    refiners_text_encoder_weights = DoubleTextEncoderTranslator().translate_weights(
        diffusers_text_encoder_1_state_dict, diffusers_text_encoder_2_state_dict
    )
    return (
        refiners_unet_state_dict,
        refiners_vae_state_dict,
        refiners_text_encoder_weights,
    )