From 601a112dc3c76054d9edd6ff0f5e58e721584526 Mon Sep 17 00:00:00 2001 From: Bryce Drennan Date: Sun, 14 Jan 2024 16:50:17 -0800 Subject: [PATCH] refactor: move download related functions to separate module (#453) + renames and typehints --- README.md | 39 +-- docs/todo.md | 12 +- imaginairy/api/video_sample.py | 2 +- imaginairy/enhancers/describe_image_blip.py | 2 +- .../enhancers/face_restoration_codeformer.py | 2 +- imaginairy/enhancers/upscale_realesrgan.py | 2 +- imaginairy/enhancers/upscale_riverwing.py | 2 +- .../video_interpolation/rife/interpolate.py | 2 +- imaginairy/img_processors/hed_boundary.py | 2 +- imaginairy/img_processors/openpose.py | 2 +- imaginairy/modules/midas/midas/base_model.py | 2 +- imaginairy/utils/downloads.py | 211 +++++++++++++++++ imaginairy/utils/model_manager.py | 223 ++---------------- imaginairy/vendored/realesrgan.py | 2 +- .../weight_management/generate_weight_info.py | 2 +- scripts/controlnet_convert.py | 2 +- tests/test_utils/test_model_manager.py | 2 +- 17 files changed, 270 insertions(+), 241 deletions(-) create mode 100644 imaginairy/utils/downloads.py diff --git a/README.md b/README.md index f5a326f..6355c75 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,27 @@ Options: ### Whats New [See full Changelog here](./docs/changelog.md) +**14.1.0** +- 🎉 feature: make video generation smooth by adding frame interpolation +- feature: SDXL weights in the compvis format can now be used +- feature: allow video generation at any size specified by user +- feature: video generations output in "bounce" format +- feature: choose video output format: mp4, webp, or gif +- feature: fix random seed handling in video generation +- docs: auto-publish docs on push to master +- build: remove imageio dependency +- build: vendorize facexlib so we don't install its unneeded dependencies + + +**14.0.4** +- docs: add a documentation website at https://brycedrennan.github.io/imaginAIry/ +- build: remove fairscale dependency +- fix: video generation was broken + +**14.0.3** +- fix: several critical bugs with package +- tests: add a wheel smoketest to detect these issues in the future + **14.0.0** - 🎉 video generation using [Stable Video Diffusion](https://github.com/Stability-AI/generative-models) - add `--videogen` to any image generation to create a short video from the generated image @@ -86,23 +107,7 @@ cutting edge features (SDXL, image prompts, etc) which will be added to imaginai For example `--size 720p --seed 1` and `--size 1080p --seed 1` will produce the same image for SD15 - 🎉 feature: loading diffusers based models now supported. Example `--model https://huggingface.co/ainz/diseny-pixar --model-architecture sd15` - 🎉 feature: qrcode controlnet! -- feature: generate word images automatically. great for use with qrcode controlnet: `imagine "flowers" --gif --size hd --control-mode qrcode --control-image "textimg='JOY' font_color=white background_color=gray" -r 10` -- feature: opendalle 1.1 added. `--model opendalle` to use it -- feature: added `--size` parameter for more intuitive sizing (e.g. 512, 256x256, 4k, uhd, FHD, VGA, etc) -- feature: detect if wrong torch version is installed and provide instructions on how to install proper version -- feature: better logging output: color, error handling -- feature: support for pytorch 2.0 -- feature: command line output significantly cleaned up and easier to read -- feature: adds --composition-strength parameter to cli (#416) -- performance: lower memory usage for upscaling -- performance: lower memory usage at startup -- performance: add sliced attention to several models (lowers memory use) -- fix: simpler memory management that avoids some of the previous bugs -- deprecated: support for python 3.8, 3.9 -- deprecated: support for torch 1.13 -- deprecated: support for Stable Diffusion versions 1.4, 2.0, and 2.1 -- deprecated: image training -- broken: samplers other than ddim + ### Run API server and StableStudio web interface (alpha) Generate images via API or web interface. Much smaller featureset compared to the command line tool. diff --git a/docs/todo.md b/docs/todo.md index 2e32d44..d842ece 100644 --- a/docs/todo.md +++ b/docs/todo.md @@ -13,11 +13,17 @@ - ✅ add type checker - ✅ add interface for loading diffusers weights - ✅ SDXL support - - sdxl inpainting + - ✅ sdxl inpainting - t2i adapters + - image prompts - embedding inputs + - save complete metadata to image + - recreate image from metadata + - auto-incoporate https://huggingface.co/madebyollin/sdxl-vae-fp16-fix - only output the main image unless some flag is set - - allow selection of output video format + - ✅ allow selection of output video format + - test on python 3.11 + - allow specification of filename format - chain multiple operations together imggen => videogen - https://github.com/pallets/click/tree/main/examples/imagepipe @@ -57,7 +63,7 @@ - ✅ set up ci (test/lint/format) - ✅ unified pipeline (txt2img & img2img combined) - ✅ setup parallel testing - - add docs + - ✅ add docs - 🚫 remove yaml config - 🚫 delete more unused code - faster latent logging https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/9 diff --git a/imaginairy/api/video_sample.py b/imaginairy/api/video_sample.py index 77e4607..0d3e17a 100644 --- a/imaginairy/api/video_sample.py +++ b/imaginairy/api/video_sample.py @@ -28,7 +28,7 @@ from imaginairy.utils import ( platform_appropriate_autocast, ) from imaginairy.utils.animations import make_bounce_animation -from imaginairy.utils.model_manager import get_cached_url_path +from imaginairy.utils.downloads import get_cached_url_path from imaginairy.utils.named_resolutions import normalize_image_size from imaginairy.utils.paths import PKG_ROOT diff --git a/imaginairy/enhancers/describe_image_blip.py b/imaginairy/enhancers/describe_image_blip.py index dfff634..433981c 100644 --- a/imaginairy/enhancers/describe_image_blip.py +++ b/imaginairy/enhancers/describe_image_blip.py @@ -9,7 +9,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from imaginairy.utils import get_device -from imaginairy.utils.model_manager import get_cached_url_path +from imaginairy.utils.downloads import get_cached_url_path from imaginairy.vendored.blip.blip import BLIP_Decoder, load_checkpoint device = get_device() diff --git a/imaginairy/enhancers/face_restoration_codeformer.py b/imaginairy/enhancers/face_restoration_codeformer.py index fea0520..9557c8c 100644 --- a/imaginairy/enhancers/face_restoration_codeformer.py +++ b/imaginairy/enhancers/face_restoration_codeformer.py @@ -8,7 +8,7 @@ import torch from PIL import Image from torchvision.transforms.functional import normalize -from imaginairy.utils.model_manager import get_cached_url_path +from imaginairy.utils.downloads import get_cached_url_path from imaginairy.vendored.basicsr.img_util import img2tensor, tensor2img from imaginairy.vendored.codeformer.codeformer_arch import CodeFormer from imaginairy.vendored.facexlib.utils.face_restoration_helper import FaceRestoreHelper diff --git a/imaginairy/enhancers/upscale_realesrgan.py b/imaginairy/enhancers/upscale_realesrgan.py index ee97c66..aa1deaf 100644 --- a/imaginairy/enhancers/upscale_realesrgan.py +++ b/imaginairy/enhancers/upscale_realesrgan.py @@ -5,8 +5,8 @@ import torch from PIL import Image from imaginairy.utils import get_device +from imaginairy.utils.downloads import get_cached_url_path from imaginairy.utils.model_cache import memory_managed_model -from imaginairy.utils.model_manager import get_cached_url_path from imaginairy.vendored.basicsr.rrdbnet_arch import RRDBNet from imaginairy.vendored.realesrgan import RealESRGANer diff --git a/imaginairy/enhancers/upscale_riverwing.py b/imaginairy/enhancers/upscale_riverwing.py index 4f808a5..31a43b1 100644 --- a/imaginairy/enhancers/upscale_riverwing.py +++ b/imaginairy/enhancers/upscale_riverwing.py @@ -8,8 +8,8 @@ import torch.nn.functional as F from torch import nn from imaginairy.utils import get_device, platform_appropriate_autocast +from imaginairy.utils.downloads import hf_hub_download from imaginairy.utils.log_utils import log_latent -from imaginairy.utils.model_manager import hf_hub_download from imaginairy.vendored import k_diffusion as K from imaginairy.vendored.k_diffusion import layers from imaginairy.vendored.k_diffusion.models.image_v1 import ImageDenoiserModelV1 diff --git a/imaginairy/enhancers/video_interpolation/rife/interpolate.py b/imaginairy/enhancers/video_interpolation/rife/interpolate.py index ef95ef4..08dbf8c 100644 --- a/imaginairy/enhancers/video_interpolation/rife/interpolate.py +++ b/imaginairy/enhancers/video_interpolation/rife/interpolate.py @@ -15,7 +15,7 @@ from torch.nn import functional as F from tqdm import tqdm from imaginairy.utils import get_device -from imaginairy.utils.model_manager import get_cached_url_path +from imaginairy.utils.downloads import get_cached_url_path from .msssim import ssim_matlab from .RIFE_HDv3 import Model diff --git a/imaginairy/img_processors/hed_boundary.py b/imaginairy/img_processors/hed_boundary.py index 190b871..a1d0b74 100644 --- a/imaginairy/img_processors/hed_boundary.py +++ b/imaginairy/img_processors/hed_boundary.py @@ -7,7 +7,7 @@ import numpy as np import torch from imaginairy.utils import get_device -from imaginairy.utils.model_manager import get_cached_url_path +from imaginairy.utils.downloads import get_cached_url_path class Network(torch.nn.Module): diff --git a/imaginairy/img_processors/openpose.py b/imaginairy/img_processors/openpose.py index 19febad..e4bdb7c 100644 --- a/imaginairy/img_processors/openpose.py +++ b/imaginairy/img_processors/openpose.py @@ -11,8 +11,8 @@ from scipy.ndimage.filters import gaussian_filter from torch import nn from imaginairy.utils import get_device +from imaginairy.utils.downloads import get_cached_url_path from imaginairy.utils.img_utils import torch_image_to_openvcv_img -from imaginairy.utils.model_manager import get_cached_url_path def pad_right_down_corner(img, stride, padValue): diff --git a/imaginairy/modules/midas/midas/base_model.py b/imaginairy/modules/midas/midas/base_model.py index 6bf1ec3..3ddb252 100644 --- a/imaginairy/modules/midas/midas/base_model.py +++ b/imaginairy/modules/midas/midas/base_model.py @@ -2,7 +2,7 @@ import torch -from imaginairy.utils.model_manager import get_cached_url_path +from imaginairy.utils.downloads import get_cached_url_path class BaseModel(torch.nn.Module): diff --git a/imaginairy/utils/downloads.py b/imaginairy/utils/downloads.py new file mode 100644 index 0000000..d048139 --- /dev/null +++ b/imaginairy/utils/downloads.py @@ -0,0 +1,211 @@ +import logging +import os +import re +import urllib.parse +from functools import lru_cache, wraps + +import requests +from huggingface_hub import ( + HfFileSystem, + HfFolder, + hf_hub_download as _hf_hub_download, + try_to_load_from_cache, +) + +logger = logging.getLogger(__name__) + + +def get_cached_url_path(url: str, category=None) -> str: + """ + Gets the contents of a url, but caches the response indefinitely. + + While we attempt to use the cached_path from huggingface transformers, we fall back + to our own implementation if the url does not provide an etag header, which `cached_path` + requires. We also skip the `head` call that `cached_path` makes on every call if the file + is already cached. + """ + if url.startswith("https://huggingface.co"): + try: + return huggingface_cached_path(url) + except (OSError, ValueError): + pass + filename = url.split("/")[-1] + dest = get_cache_dir() + if category: + dest = os.path.join(dest, category) + os.makedirs(dest, exist_ok=True) + + # Replace possibly illegal destination path characters + safe_filename = re.sub('[*<>:"|?]', "_", filename) + dest_path = os.path.join(dest, safe_filename) + if os.path.exists(dest_path): + return dest_path + + # check if it's saved at previous path and rename it + old_dest_path = os.path.join(dest, filename) + if os.path.exists(old_dest_path): + os.rename(old_dest_path, dest_path) + return dest_path + + r = requests.get(url) + + with open(dest_path, "wb") as f: + f.write(r.content) + return dest_path + + +def check_huggingface_url_authorized(url: str) -> None: + if not url.startswith("https://huggingface.co/"): + return None + token = HfFolder.get_token() + headers = {} + if token is not None: + headers["authorization"] = f"Bearer {token}" + response = requests.head(url, allow_redirects=True, headers=headers, timeout=5) + if response.status_code == 401: + msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information" + raise HuggingFaceAuthorizationError(msg) + return None + + +@wraps(_hf_hub_download) +def hf_hub_download(*args, **kwargs): + """ + backwards compatible wrapper for huggingface's hf_hub_download. + + they changed the argument name from `use_auth_token` to `token` + """ + + try: + return _hf_hub_download(*args, **kwargs) + except TypeError as e: + if "unexpected keyword argument 'token'" in str(e): + kwargs["use_auth_token"] = kwargs.pop("token") + return _hf_hub_download(*args, **kwargs) + raise + + +def huggingface_cached_path(url: str) -> str: + # bypass all the HEAD calls done by the default `cached_path` + repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url) + dest_path = try_to_load_from_cache( + repo_id=repo, revision=commit_hash, filename=filepath + ) + if not dest_path: + check_huggingface_url_authorized(url) + token = HfFolder.get_token() + logger.info(f"Downloading {url} from huggingface") + dest_path = hf_hub_download( + repo_id=repo, revision=commit_hash, filename=filepath, token=token + ) + # make a refs folder so caching works + # work-around for + # https://github.com/huggingface/huggingface_hub/pull/1306 + # https://github.com/brycedrennan/imaginAIry/issues/171 + refs_url = dest_path[: dest_path.index("/snapshots/")] + "/refs/" + os.makedirs(refs_url, exist_ok=True) + return dest_path + + +def extract_huggingface_repo_commit_file_from_url(url): + parsed_url = urllib.parse.urlparse(url) + path_components = parsed_url.path.strip("/").split("/") + + repo = "/".join(path_components[0:2]) + assert path_components[2] == "resolve" + commit_hash = path_components[3] + filepath = "/".join(path_components[4:]) + + return repo, commit_hash, filepath + + +def download_huggingface_weights( + base_url: str, sub: str, filename=None, prefer_fp16=True +) -> str: + """ + Downloads weights from huggingface and returns the path to the downloaded file + + Given a huggingface repo url, folder, and optional filename, download the weights to the cache directory and return the path + """ + if filename is None: + # select which weights to download. prefer fp16 safetensors + data = parse_diffusers_repo_url(base_url) + fs = HfFileSystem() + filepaths = fs.ls( + f"{data['author']}/{data['repo']}/{sub}", revision=data["ref"], detail=False + ) + filepath = choose_huggingface_weights(filepaths, prefer_fp16=prefer_fp16) + if not filepath: + msg = f"Could not find any weights in {base_url}/{sub}" + raise ValueError(msg) + filename = filepath.split("/")[-1] + url = f"{base_url}{sub}/{filename}".replace("/tree/", "/resolve/") + new_path = get_cached_url_path(url, category="weights") + return new_path + + +def choose_huggingface_weights(filenames: list[str], prefer_fp16=True) -> str | None: + """ + Chooses the best weights file from a list of filenames + + Prefers safetensors format and fp16 dtype + """ + extension_priority = (".safetensors", ".bin", ".pth", ".pt") + # filter out any files that don't have a valid extension + filenames = [f for f in filenames if any(f.endswith(e) for e in extension_priority)] + filenames_and_extension = [(f, os.path.splitext(f)[1]) for f in filenames] + # sort by priority + if prefer_fp16: + filenames_and_extension.sort( + key=lambda x: ("fp16" not in x[0], extension_priority.index(x[1])) + ) + else: + filenames_and_extension.sort( + key=lambda x: ("fp16" in x[0], extension_priority.index(x[1])) + ) + if filenames_and_extension: + return filenames_and_extension[0][0] + return None + + +@lru_cache +def get_cache_dir() -> str: + xdg_cache_home = os.getenv("XDG_CACHE_HOME", None) + if xdg_cache_home is None: + user_home = os.getenv("HOME", None) + if user_home: + xdg_cache_home = os.path.join(user_home, ".cache") + + if xdg_cache_home is not None: + return os.path.join(xdg_cache_home, "imaginairy") + + return os.path.join(os.path.dirname(__file__), ".cached-aimg") + + +class HuggingFaceAuthorizationError(RuntimeError): + pass + + +hf_repo_url_pattern = re.compile( + r"https://huggingface\.co/(?P[^/]+)/(?P[^/]+)(/tree/(?P[a-z0-9]+))?/?$" +) + + +def parse_diffusers_repo_url(url: str) -> dict[str, str]: + match = hf_repo_url_pattern.match(url) + return match.groupdict() if match else {} + + +def is_diffusers_repo_url(url: str) -> bool: + result = bool(parse_diffusers_repo_url(url)) + logger.debug(f"{url} is diffusers repo url: {result}") + return result + + +def normalize_diffusers_repo_url(url: str) -> str: + data = parse_diffusers_repo_url(url) + ref = data["ref"] or "main" + normalized_url = ( + f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/" + ) + return normalized_url diff --git a/imaginairy/utils/model_manager.py b/imaginairy/utils/model_manager.py index 4c9dd4a..bc9581a 100644 --- a/imaginairy/utils/model_manager.py +++ b/imaginairy/utils/model_manager.py @@ -2,19 +2,10 @@ import logging import os -import re import sys -import urllib.parse -from functools import lru_cache, wraps +from functools import lru_cache -import requests import torch -from huggingface_hub import ( - HfFileSystem, - HfFolder, - hf_hub_download as _hf_hub_download, - try_to_load_from_cache, -) from omegaconf import OmegaConf from safetensors.torch import load_file @@ -27,6 +18,13 @@ from imaginairy.modules.refiners_sd import ( StableDiffusion_XL_Inpainting, ) from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_config +from imaginairy.utils.downloads import ( + HuggingFaceAuthorizationError, + download_huggingface_weights, + get_cached_url_path, + is_diffusers_repo_url, + normalize_diffusers_repo_url, +) from imaginairy.utils.model_cache import memory_managed_model from imaginairy.utils.named_resolutions import normalize_image_size from imaginairy.utils.paths import PKG_ROOT @@ -54,10 +52,6 @@ logger = logging.getLogger(__name__) MOST_RECENTLY_LOADED_MODEL = None -class HuggingFaceAuthorizationError(RuntimeError): - pass - - def load_state_dict(weights_location, half_mode=False, device=None): if device is None: device = get_device() @@ -118,13 +112,6 @@ def load_model_from_config(config, weights_location, half_mode=False): return model -def add_controlnet(base_state_dict, controlnet_state_dict): - """Merges a base sd15 model with a controlnet model.""" - for key in controlnet_state_dict: - base_state_dict[key] = controlnet_state_dict[key] - return base_state_dict - - def get_diffusion_model( weights_location=iconfig.DEFAULT_MODEL_WEIGHTS, config_path="configs/stable-diffusion-v1.yaml", @@ -222,31 +209,6 @@ def get_diffusion_model_refiners( return sd -hf_repo_url_pattern = re.compile( - r"https://huggingface\.co/(?P[^/]+)/(?P[^/]+)(/tree/(?P[a-z0-9]+))?/?$" -) - - -def parse_diffusers_repo_url(url: str) -> dict[str, str]: - match = hf_repo_url_pattern.match(url) - return match.groupdict() if match else {} - - -def is_diffusers_repo_url(url: str) -> bool: - result = bool(parse_diffusers_repo_url(url)) - logger.debug(f"{url} is diffusers repo url: {result}") - return result - - -def normalize_diffusers_repo_url(url: str) -> str: - data = parse_diffusers_repo_url(url) - ref = data["ref"] or "main" - normalized_url = ( - f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/" - ) - return normalized_url - - @lru_cache(maxsize=1) def _get_diffusion_model_refiners( weights_location: str, @@ -529,161 +491,6 @@ def get_current_diffusion_model(): return MOST_RECENTLY_LOADED_MODEL -def get_cache_dir(): - xdg_cache_home = os.getenv("XDG_CACHE_HOME", None) - if xdg_cache_home is None: - user_home = os.getenv("HOME", None) - if user_home: - xdg_cache_home = os.path.join(user_home, ".cache") - - if xdg_cache_home is not None: - return os.path.join(xdg_cache_home, "imaginairy") - - return os.path.join(os.path.dirname(__file__), ".cached-aimg") - - -def get_cached_url_path(url, category=None): - """ - Gets the contents of a url, but caches the response indefinitely. - - While we attempt to use the cached_path from huggingface transformers, we fall back - to our own implementation if the url does not provide an etag header, which `cached_path` - requires. We also skip the `head` call that `cached_path` makes on every call if the file - is already cached. - """ - - try: - if url.startswith("https://huggingface.co"): - return huggingface_cached_path(url) - except (OSError, ValueError): - pass - filename = url.split("/")[-1] - dest = get_cache_dir() - if category: - dest = os.path.join(dest, category) - os.makedirs(dest, exist_ok=True) - - # Replace possibly illegal destination path characters - safe_filename = re.sub('[*<>:"|?]', "_", filename) - dest_path = os.path.join(dest, safe_filename) - if os.path.exists(dest_path): - return dest_path - - # check if it's saved at previous path and rename it - old_dest_path = os.path.join(dest, filename) - if os.path.exists(old_dest_path): - os.rename(old_dest_path, dest_path) - return dest_path - - r = requests.get(url) - - with open(dest_path, "wb") as f: - f.write(r.content) - return dest_path - - -def check_huggingface_url_authorized(url): - if not url.startswith("https://huggingface.co/"): - return None - token = HfFolder.get_token() - headers = {} - if token is not None: - headers["authorization"] = f"Bearer {token}" - response = requests.head(url, allow_redirects=True, headers=headers, timeout=5) - if response.status_code == 401: - msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information" - raise HuggingFaceAuthorizationError(msg) - return None - - -@wraps(_hf_hub_download) -def hf_hub_download(*args, **kwargs): - """ - backwards compatible wrapper for huggingface's hf_hub_download. - - they changed the argument name from `use_auth_token` to `token` - """ - - try: - return _hf_hub_download(*args, **kwargs) - except TypeError as e: - if "unexpected keyword argument 'token'" in str(e): - kwargs["use_auth_token"] = kwargs.pop("token") - return _hf_hub_download(*args, **kwargs) - raise - - -def huggingface_cached_path(url): - # bypass all the HEAD calls done by the default `cached_path` - repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url) - dest_path = try_to_load_from_cache( - repo_id=repo, revision=commit_hash, filename=filepath - ) - if not dest_path: - check_huggingface_url_authorized(url) - token = HfFolder.get_token() - logger.info(f"Downloading {url} from huggingface") - dest_path = hf_hub_download( - repo_id=repo, revision=commit_hash, filename=filepath, token=token - ) - # make a refs folder so caching works - # work-around for - # https://github.com/huggingface/huggingface_hub/pull/1306 - # https://github.com/brycedrennan/imaginAIry/issues/171 - refs_url = dest_path[: dest_path.index("/snapshots/")] + "/refs/" - os.makedirs(refs_url, exist_ok=True) - return dest_path - - -def extract_huggingface_repo_commit_file_from_url(url): - parsed_url = urllib.parse.urlparse(url) - path_components = parsed_url.path.strip("/").split("/") - - repo = "/".join(path_components[0:2]) - assert path_components[2] == "resolve" - commit_hash = path_components[3] - filepath = "/".join(path_components[4:]) - - return repo, commit_hash, filepath - - -def download_diffusers_weights(base_url, sub, filename=None, prefer_fp16=True): - if filename is None: - # select which weights to download. prefer fp16 safetensors - data = parse_diffusers_repo_url(base_url) - fs = HfFileSystem() - filepaths = fs.ls( - f"{data['author']}/{data['repo']}/{sub}", revision=data["ref"], detail=False - ) - filepath = choose_diffusers_weights(filepaths, prefer_fp16=prefer_fp16) - if not filepath: - msg = f"Could not find any weights in {base_url}/{sub}" - raise ValueError(msg) - filename = filepath.split("/")[-1] - url = f"{base_url}{sub}/{filename}".replace("/tree/", "/resolve/") - new_path = get_cached_url_path(url, category="weights") - return new_path - - -def choose_diffusers_weights(filenames, prefer_fp16=True): - extension_priority = (".safetensors", ".bin", ".pth", ".pt") - # filter out any files that don't have a valid extension - filenames = [f for f in filenames if any(f.endswith(e) for e in extension_priority)] - filenames_and_extension = [(f, os.path.splitext(f)[1]) for f in filenames] - # sort by priority - if prefer_fp16: - filenames_and_extension.sort( - key=lambda x: ("fp16" not in x[0], extension_priority.index(x[1])) - ) - else: - filenames_and_extension.sort( - key=lambda x: ("fp16" in x[0], extension_priority.index(x[1])) - ) - if filenames_and_extension: - return filenames_and_extension[0][0] - return None - - def load_sd15_diffusers_weights(base_url: str, device=None): from imaginairy.utils import get_device from imaginairy.weight_management.conversion import cast_weights @@ -696,7 +503,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None): base_url = normalize_diffusers_repo_url(base_url) if device is None: device = get_device() - vae_weights_path = download_diffusers_weights(base_url=base_url, sub="vae") + vae_weights_path = download_huggingface_weights(base_url=base_url, sub="vae") vae_weights = open_weights(vae_weights_path, device=device) vae_weights = cast_weights( source_weights=vae_weights, @@ -706,7 +513,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None): dest_format=FORMAT_NAMES.REFINERS, ) - unet_weights_path = download_diffusers_weights(base_url=base_url, sub="unet") + unet_weights_path = download_huggingface_weights(base_url=base_url, sub="unet") unet_weights = open_weights(unet_weights_path, device=device) unet_weights = cast_weights( source_weights=unet_weights, @@ -716,7 +523,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None): dest_format=FORMAT_NAMES.REFINERS, ) - text_encoder_weights_path = download_diffusers_weights( + text_encoder_weights_path = download_huggingface_weights( base_url=base_url, sub="text_encoder" ) text_encoder_weights = open_weights(text_encoder_weights_path, device=device) @@ -749,7 +556,7 @@ def load_sdxl_pipeline_from_diffusers_weights( base_url = normalize_diffusers_repo_url(base_url) translator = translators.diffusers_autoencoder_kl_to_refiners_translator() - vae_weights_path = download_diffusers_weights( + vae_weights_path = download_huggingface_weights( base_url=base_url, sub="vae", prefer_fp16=False ) logger.debug(f"vae: {vae_weights_path}") @@ -762,7 +569,7 @@ def load_sdxl_pipeline_from_diffusers_weights( del vae_weights translator = translators.diffusers_unet_sdxl_to_refiners_translator() - unet_weights_path = download_diffusers_weights( + unet_weights_path = download_huggingface_weights( base_url=base_url, sub="unet", prefer_fp16=True ) logger.debug(f"unet: {unet_weights_path}") @@ -777,10 +584,10 @@ def load_sdxl_pipeline_from_diffusers_weights( unet.load_state_dict(unet_weights, assign=True) del unet_weights - text_encoder_1_path = download_diffusers_weights( + text_encoder_1_path = download_huggingface_weights( base_url=base_url, sub="text_encoder" ) - text_encoder_2_path = download_diffusers_weights( + text_encoder_2_path = download_huggingface_weights( base_url=base_url, sub="text_encoder_2" ) logger.debug(f"text encoder 1: {text_encoder_1_path}") diff --git a/imaginairy/vendored/realesrgan.py b/imaginairy/vendored/realesrgan.py index 94b119e..0438260 100644 --- a/imaginairy/vendored/realesrgan.py +++ b/imaginairy/vendored/realesrgan.py @@ -9,7 +9,7 @@ import numpy as np import torch from torch.nn import functional as F -from imaginairy.utils.model_manager import get_cached_url_path +from imaginairy.utils.downloads import get_cached_url_path ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) diff --git a/imaginairy/weight_management/generate_weight_info.py b/imaginairy/weight_management/generate_weight_info.py index 927fdf8..22d9276 100644 --- a/imaginairy/weight_management/generate_weight_info.py +++ b/imaginairy/weight_management/generate_weight_info.py @@ -2,8 +2,8 @@ import safetensors +from imaginairy.utils.downloads import get_cached_url_path from imaginairy.utils.model_manager import ( - get_cached_url_path, open_weights, resolve_model_weights_config, ) diff --git a/scripts/controlnet_convert.py b/scripts/controlnet_convert.py index d3d7a14..e8692c6 100644 --- a/scripts/controlnet_convert.py +++ b/scripts/controlnet_convert.py @@ -3,7 +3,7 @@ import os import torch from safetensors.torch import load_file, save_file -from imaginairy.utils.model_manager import get_cached_url_path +from imaginairy.utils.downloads import get_cached_url_path from imaginairy.utils.paths import PKG_ROOT sd15_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt" diff --git a/tests/test_utils/test_model_manager.py b/tests/test_utils/test_model_manager.py index 18566b0..6052917 100644 --- a/tests/test_utils/test_model_manager.py +++ b/tests/test_utils/test_model_manager.py @@ -1,8 +1,8 @@ import pytest from imaginairy import config +from imaginairy.utils.downloads import parse_diffusers_repo_url from imaginairy.utils.model_manager import ( - parse_diffusers_repo_url, resolve_model_weights_config, )