refactor: move download related functions to separate module (#453)

+ renames and typehints
pull/454/head
Bryce Drennan 5 months ago committed by GitHub
parent 502ffbdc63
commit 601a112dc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -71,6 +71,27 @@ Options:
### Whats New ### Whats New
[See full Changelog here](./docs/changelog.md) [See full Changelog here](./docs/changelog.md)
**14.1.0**
- 🎉 feature: make video generation smooth by adding frame interpolation
- feature: SDXL weights in the compvis format can now be used
- feature: allow video generation at any size specified by user
- feature: video generations output in "bounce" format
- feature: choose video output format: mp4, webp, or gif
- feature: fix random seed handling in video generation
- docs: auto-publish docs on push to master
- build: remove imageio dependency
- build: vendorize facexlib so we don't install its unneeded dependencies
**14.0.4**
- docs: add a documentation website at https://brycedrennan.github.io/imaginAIry/
- build: remove fairscale dependency
- fix: video generation was broken
**14.0.3**
- fix: several critical bugs with package
- tests: add a wheel smoketest to detect these issues in the future
**14.0.0** **14.0.0**
- 🎉 video generation using [Stable Video Diffusion](https://github.com/Stability-AI/generative-models) - 🎉 video generation using [Stable Video Diffusion](https://github.com/Stability-AI/generative-models)
- add `--videogen` to any image generation to create a short video from the generated image - add `--videogen` to any image generation to create a short video from the generated image
@ -86,23 +107,7 @@ cutting edge features (SDXL, image prompts, etc) which will be added to imaginai
For example `--size 720p --seed 1` and `--size 1080p --seed 1` will produce the same image for SD15 For example `--size 720p --seed 1` and `--size 1080p --seed 1` will produce the same image for SD15
- 🎉 feature: loading diffusers based models now supported. Example `--model https://huggingface.co/ainz/diseny-pixar --model-architecture sd15` - 🎉 feature: loading diffusers based models now supported. Example `--model https://huggingface.co/ainz/diseny-pixar --model-architecture sd15`
- 🎉 feature: qrcode controlnet! - 🎉 feature: qrcode controlnet!
- feature: generate word images automatically. great for use with qrcode controlnet: `imagine "flowers" --gif --size hd --control-mode qrcode --control-image "textimg='JOY' font_color=white background_color=gray" -r 10`
- feature: opendalle 1.1 added. `--model opendalle` to use it
- feature: added `--size` parameter for more intuitive sizing (e.g. 512, 256x256, 4k, uhd, FHD, VGA, etc)
- feature: detect if wrong torch version is installed and provide instructions on how to install proper version
- feature: better logging output: color, error handling
- feature: support for pytorch 2.0
- feature: command line output significantly cleaned up and easier to read
- feature: adds --composition-strength parameter to cli (#416)
- performance: lower memory usage for upscaling
- performance: lower memory usage at startup
- performance: add sliced attention to several models (lowers memory use)
- fix: simpler memory management that avoids some of the previous bugs
- deprecated: support for python 3.8, 3.9
- deprecated: support for torch 1.13
- deprecated: support for Stable Diffusion versions 1.4, 2.0, and 2.1
- deprecated: image training
- broken: samplers other than ddim
### Run API server and StableStudio web interface (alpha) ### Run API server and StableStudio web interface (alpha)
Generate images via API or web interface. Much smaller featureset compared to the command line tool. Generate images via API or web interface. Much smaller featureset compared to the command line tool.

@ -13,11 +13,17 @@
- ✅ add type checker - ✅ add type checker
- ✅ add interface for loading diffusers weights - ✅ add interface for loading diffusers weights
- ✅ SDXL support - ✅ SDXL support
- sdxl inpainting - sdxl inpainting
- t2i adapters - t2i adapters
- image prompts
- embedding inputs - embedding inputs
- save complete metadata to image
- recreate image from metadata
- auto-incoporate https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
- only output the main image unless some flag is set - only output the main image unless some flag is set
- allow selection of output video format - ✅ allow selection of output video format
- test on python 3.11
- allow specification of filename format
- chain multiple operations together imggen => videogen - chain multiple operations together imggen => videogen
- https://github.com/pallets/click/tree/main/examples/imagepipe - https://github.com/pallets/click/tree/main/examples/imagepipe
@ -57,7 +63,7 @@
- ✅ set up ci (test/lint/format) - ✅ set up ci (test/lint/format)
- ✅ unified pipeline (txt2img & img2img combined) - ✅ unified pipeline (txt2img & img2img combined)
- ✅ setup parallel testing - ✅ setup parallel testing
- add docs - add docs
- 🚫 remove yaml config - 🚫 remove yaml config
- 🚫 delete more unused code - 🚫 delete more unused code
- faster latent logging https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/9 - faster latent logging https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/9

@ -28,7 +28,7 @@ from imaginairy.utils import (
platform_appropriate_autocast, platform_appropriate_autocast,
) )
from imaginairy.utils.animations import make_bounce_animation from imaginairy.utils.animations import make_bounce_animation
from imaginairy.utils.model_manager import get_cached_url_path from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.named_resolutions import normalize_image_size from imaginairy.utils.named_resolutions import normalize_image_size
from imaginairy.utils.paths import PKG_ROOT from imaginairy.utils.paths import PKG_ROOT

@ -9,7 +9,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from imaginairy.utils import get_device from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.vendored.blip.blip import BLIP_Decoder, load_checkpoint from imaginairy.vendored.blip.blip import BLIP_Decoder, load_checkpoint
device = get_device() device = get_device()

@ -8,7 +8,7 @@ import torch
from PIL import Image from PIL import Image
from torchvision.transforms.functional import normalize from torchvision.transforms.functional import normalize
from imaginairy.utils.model_manager import get_cached_url_path from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.vendored.basicsr.img_util import img2tensor, tensor2img from imaginairy.vendored.basicsr.img_util import img2tensor, tensor2img
from imaginairy.vendored.codeformer.codeformer_arch import CodeFormer from imaginairy.vendored.codeformer.codeformer_arch import CodeFormer
from imaginairy.vendored.facexlib.utils.face_restoration_helper import FaceRestoreHelper from imaginairy.vendored.facexlib.utils.face_restoration_helper import FaceRestoreHelper

@ -5,8 +5,8 @@ import torch
from PIL import Image from PIL import Image
from imaginairy.utils import get_device from imaginairy.utils import get_device
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.model_cache import memory_managed_model from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.vendored.basicsr.rrdbnet_arch import RRDBNet from imaginairy.vendored.basicsr.rrdbnet_arch import RRDBNet
from imaginairy.vendored.realesrgan import RealESRGANer from imaginairy.vendored.realesrgan import RealESRGANer

@ -8,8 +8,8 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from imaginairy.utils import get_device, platform_appropriate_autocast from imaginairy.utils import get_device, platform_appropriate_autocast
from imaginairy.utils.downloads import hf_hub_download
from imaginairy.utils.log_utils import log_latent from imaginairy.utils.log_utils import log_latent
from imaginairy.utils.model_manager import hf_hub_download
from imaginairy.vendored import k_diffusion as K from imaginairy.vendored import k_diffusion as K
from imaginairy.vendored.k_diffusion import layers from imaginairy.vendored.k_diffusion import layers
from imaginairy.vendored.k_diffusion.models.image_v1 import ImageDenoiserModelV1 from imaginairy.vendored.k_diffusion.models.image_v1 import ImageDenoiserModelV1

@ -15,7 +15,7 @@ from torch.nn import functional as F
from tqdm import tqdm from tqdm import tqdm
from imaginairy.utils import get_device from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path from imaginairy.utils.downloads import get_cached_url_path
from .msssim import ssim_matlab from .msssim import ssim_matlab
from .RIFE_HDv3 import Model from .RIFE_HDv3 import Model

@ -7,7 +7,7 @@ import numpy as np
import torch import torch
from imaginairy.utils import get_device from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path from imaginairy.utils.downloads import get_cached_url_path
class Network(torch.nn.Module): class Network(torch.nn.Module):

@ -11,8 +11,8 @@ from scipy.ndimage.filters import gaussian_filter
from torch import nn from torch import nn
from imaginairy.utils import get_device from imaginairy.utils import get_device
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.img_utils import torch_image_to_openvcv_img from imaginairy.utils.img_utils import torch_image_to_openvcv_img
from imaginairy.utils.model_manager import get_cached_url_path
def pad_right_down_corner(img, stride, padValue): def pad_right_down_corner(img, stride, padValue):

@ -2,7 +2,7 @@
import torch import torch
from imaginairy.utils.model_manager import get_cached_url_path from imaginairy.utils.downloads import get_cached_url_path
class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):

@ -0,0 +1,211 @@
import logging
import os
import re
import urllib.parse
from functools import lru_cache, wraps
import requests
from huggingface_hub import (
HfFileSystem,
HfFolder,
hf_hub_download as _hf_hub_download,
try_to_load_from_cache,
)
logger = logging.getLogger(__name__)
def get_cached_url_path(url: str, category=None) -> str:
"""
Gets the contents of a url, but caches the response indefinitely.
While we attempt to use the cached_path from huggingface transformers, we fall back
to our own implementation if the url does not provide an etag header, which `cached_path`
requires. We also skip the `head` call that `cached_path` makes on every call if the file
is already cached.
"""
if url.startswith("https://huggingface.co"):
try:
return huggingface_cached_path(url)
except (OSError, ValueError):
pass
filename = url.split("/")[-1]
dest = get_cache_dir()
if category:
dest = os.path.join(dest, category)
os.makedirs(dest, exist_ok=True)
# Replace possibly illegal destination path characters
safe_filename = re.sub('[*<>:"|?]', "_", filename)
dest_path = os.path.join(dest, safe_filename)
if os.path.exists(dest_path):
return dest_path
# check if it's saved at previous path and rename it
old_dest_path = os.path.join(dest, filename)
if os.path.exists(old_dest_path):
os.rename(old_dest_path, dest_path)
return dest_path
r = requests.get(url)
with open(dest_path, "wb") as f:
f.write(r.content)
return dest_path
def check_huggingface_url_authorized(url: str) -> None:
if not url.startswith("https://huggingface.co/"):
return None
token = HfFolder.get_token()
headers = {}
if token is not None:
headers["authorization"] = f"Bearer {token}"
response = requests.head(url, allow_redirects=True, headers=headers, timeout=5)
if response.status_code == 401:
msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
raise HuggingFaceAuthorizationError(msg)
return None
@wraps(_hf_hub_download)
def hf_hub_download(*args, **kwargs):
"""
backwards compatible wrapper for huggingface's hf_hub_download.
they changed the argument name from `use_auth_token` to `token`
"""
try:
return _hf_hub_download(*args, **kwargs)
except TypeError as e:
if "unexpected keyword argument 'token'" in str(e):
kwargs["use_auth_token"] = kwargs.pop("token")
return _hf_hub_download(*args, **kwargs)
raise
def huggingface_cached_path(url: str) -> str:
# bypass all the HEAD calls done by the default `cached_path`
repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url)
dest_path = try_to_load_from_cache(
repo_id=repo, revision=commit_hash, filename=filepath
)
if not dest_path:
check_huggingface_url_authorized(url)
token = HfFolder.get_token()
logger.info(f"Downloading {url} from huggingface")
dest_path = hf_hub_download(
repo_id=repo, revision=commit_hash, filename=filepath, token=token
)
# make a refs folder so caching works
# work-around for
# https://github.com/huggingface/huggingface_hub/pull/1306
# https://github.com/brycedrennan/imaginAIry/issues/171
refs_url = dest_path[: dest_path.index("/snapshots/")] + "/refs/"
os.makedirs(refs_url, exist_ok=True)
return dest_path
def extract_huggingface_repo_commit_file_from_url(url):
parsed_url = urllib.parse.urlparse(url)
path_components = parsed_url.path.strip("/").split("/")
repo = "/".join(path_components[0:2])
assert path_components[2] == "resolve"
commit_hash = path_components[3]
filepath = "/".join(path_components[4:])
return repo, commit_hash, filepath
def download_huggingface_weights(
base_url: str, sub: str, filename=None, prefer_fp16=True
) -> str:
"""
Downloads weights from huggingface and returns the path to the downloaded file
Given a huggingface repo url, folder, and optional filename, download the weights to the cache directory and return the path
"""
if filename is None:
# select which weights to download. prefer fp16 safetensors
data = parse_diffusers_repo_url(base_url)
fs = HfFileSystem()
filepaths = fs.ls(
f"{data['author']}/{data['repo']}/{sub}", revision=data["ref"], detail=False
)
filepath = choose_huggingface_weights(filepaths, prefer_fp16=prefer_fp16)
if not filepath:
msg = f"Could not find any weights in {base_url}/{sub}"
raise ValueError(msg)
filename = filepath.split("/")[-1]
url = f"{base_url}{sub}/{filename}".replace("/tree/", "/resolve/")
new_path = get_cached_url_path(url, category="weights")
return new_path
def choose_huggingface_weights(filenames: list[str], prefer_fp16=True) -> str | None:
"""
Chooses the best weights file from a list of filenames
Prefers safetensors format and fp16 dtype
"""
extension_priority = (".safetensors", ".bin", ".pth", ".pt")
# filter out any files that don't have a valid extension
filenames = [f for f in filenames if any(f.endswith(e) for e in extension_priority)]
filenames_and_extension = [(f, os.path.splitext(f)[1]) for f in filenames]
# sort by priority
if prefer_fp16:
filenames_and_extension.sort(
key=lambda x: ("fp16" not in x[0], extension_priority.index(x[1]))
)
else:
filenames_and_extension.sort(
key=lambda x: ("fp16" in x[0], extension_priority.index(x[1]))
)
if filenames_and_extension:
return filenames_and_extension[0][0]
return None
@lru_cache
def get_cache_dir() -> str:
xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
if xdg_cache_home is None:
user_home = os.getenv("HOME", None)
if user_home:
xdg_cache_home = os.path.join(user_home, ".cache")
if xdg_cache_home is not None:
return os.path.join(xdg_cache_home, "imaginairy")
return os.path.join(os.path.dirname(__file__), ".cached-aimg")
class HuggingFaceAuthorizationError(RuntimeError):
pass
hf_repo_url_pattern = re.compile(
r"https://huggingface\.co/(?P<author>[^/]+)/(?P<repo>[^/]+)(/tree/(?P<ref>[a-z0-9]+))?/?$"
)
def parse_diffusers_repo_url(url: str) -> dict[str, str]:
match = hf_repo_url_pattern.match(url)
return match.groupdict() if match else {}
def is_diffusers_repo_url(url: str) -> bool:
result = bool(parse_diffusers_repo_url(url))
logger.debug(f"{url} is diffusers repo url: {result}")
return result
def normalize_diffusers_repo_url(url: str) -> str:
data = parse_diffusers_repo_url(url)
ref = data["ref"] or "main"
normalized_url = (
f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/"
)
return normalized_url

@ -2,19 +2,10 @@
import logging import logging
import os import os
import re
import sys import sys
import urllib.parse from functools import lru_cache
from functools import lru_cache, wraps
import requests
import torch import torch
from huggingface_hub import (
HfFileSystem,
HfFolder,
hf_hub_download as _hf_hub_download,
try_to_load_from_cache,
)
from omegaconf import OmegaConf from omegaconf import OmegaConf
from safetensors.torch import load_file from safetensors.torch import load_file
@ -27,6 +18,13 @@ from imaginairy.modules.refiners_sd import (
StableDiffusion_XL_Inpainting, StableDiffusion_XL_Inpainting,
) )
from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_config from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_config
from imaginairy.utils.downloads import (
HuggingFaceAuthorizationError,
download_huggingface_weights,
get_cached_url_path,
is_diffusers_repo_url,
normalize_diffusers_repo_url,
)
from imaginairy.utils.model_cache import memory_managed_model from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.named_resolutions import normalize_image_size from imaginairy.utils.named_resolutions import normalize_image_size
from imaginairy.utils.paths import PKG_ROOT from imaginairy.utils.paths import PKG_ROOT
@ -54,10 +52,6 @@ logger = logging.getLogger(__name__)
MOST_RECENTLY_LOADED_MODEL = None MOST_RECENTLY_LOADED_MODEL = None
class HuggingFaceAuthorizationError(RuntimeError):
pass
def load_state_dict(weights_location, half_mode=False, device=None): def load_state_dict(weights_location, half_mode=False, device=None):
if device is None: if device is None:
device = get_device() device = get_device()
@ -118,13 +112,6 @@ def load_model_from_config(config, weights_location, half_mode=False):
return model return model
def add_controlnet(base_state_dict, controlnet_state_dict):
"""Merges a base sd15 model with a controlnet model."""
for key in controlnet_state_dict:
base_state_dict[key] = controlnet_state_dict[key]
return base_state_dict
def get_diffusion_model( def get_diffusion_model(
weights_location=iconfig.DEFAULT_MODEL_WEIGHTS, weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
config_path="configs/stable-diffusion-v1.yaml", config_path="configs/stable-diffusion-v1.yaml",
@ -222,31 +209,6 @@ def get_diffusion_model_refiners(
return sd return sd
hf_repo_url_pattern = re.compile(
r"https://huggingface\.co/(?P<author>[^/]+)/(?P<repo>[^/]+)(/tree/(?P<ref>[a-z0-9]+))?/?$"
)
def parse_diffusers_repo_url(url: str) -> dict[str, str]:
match = hf_repo_url_pattern.match(url)
return match.groupdict() if match else {}
def is_diffusers_repo_url(url: str) -> bool:
result = bool(parse_diffusers_repo_url(url))
logger.debug(f"{url} is diffusers repo url: {result}")
return result
def normalize_diffusers_repo_url(url: str) -> str:
data = parse_diffusers_repo_url(url)
ref = data["ref"] or "main"
normalized_url = (
f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/"
)
return normalized_url
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
def _get_diffusion_model_refiners( def _get_diffusion_model_refiners(
weights_location: str, weights_location: str,
@ -529,161 +491,6 @@ def get_current_diffusion_model():
return MOST_RECENTLY_LOADED_MODEL return MOST_RECENTLY_LOADED_MODEL
def get_cache_dir():
xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
if xdg_cache_home is None:
user_home = os.getenv("HOME", None)
if user_home:
xdg_cache_home = os.path.join(user_home, ".cache")
if xdg_cache_home is not None:
return os.path.join(xdg_cache_home, "imaginairy")
return os.path.join(os.path.dirname(__file__), ".cached-aimg")
def get_cached_url_path(url, category=None):
"""
Gets the contents of a url, but caches the response indefinitely.
While we attempt to use the cached_path from huggingface transformers, we fall back
to our own implementation if the url does not provide an etag header, which `cached_path`
requires. We also skip the `head` call that `cached_path` makes on every call if the file
is already cached.
"""
try:
if url.startswith("https://huggingface.co"):
return huggingface_cached_path(url)
except (OSError, ValueError):
pass
filename = url.split("/")[-1]
dest = get_cache_dir()
if category:
dest = os.path.join(dest, category)
os.makedirs(dest, exist_ok=True)
# Replace possibly illegal destination path characters
safe_filename = re.sub('[*<>:"|?]', "_", filename)
dest_path = os.path.join(dest, safe_filename)
if os.path.exists(dest_path):
return dest_path
# check if it's saved at previous path and rename it
old_dest_path = os.path.join(dest, filename)
if os.path.exists(old_dest_path):
os.rename(old_dest_path, dest_path)
return dest_path
r = requests.get(url)
with open(dest_path, "wb") as f:
f.write(r.content)
return dest_path
def check_huggingface_url_authorized(url):
if not url.startswith("https://huggingface.co/"):
return None
token = HfFolder.get_token()
headers = {}
if token is not None:
headers["authorization"] = f"Bearer {token}"
response = requests.head(url, allow_redirects=True, headers=headers, timeout=5)
if response.status_code == 401:
msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
raise HuggingFaceAuthorizationError(msg)
return None
@wraps(_hf_hub_download)
def hf_hub_download(*args, **kwargs):
"""
backwards compatible wrapper for huggingface's hf_hub_download.
they changed the argument name from `use_auth_token` to `token`
"""
try:
return _hf_hub_download(*args, **kwargs)
except TypeError as e:
if "unexpected keyword argument 'token'" in str(e):
kwargs["use_auth_token"] = kwargs.pop("token")
return _hf_hub_download(*args, **kwargs)
raise
def huggingface_cached_path(url):
# bypass all the HEAD calls done by the default `cached_path`
repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url)
dest_path = try_to_load_from_cache(
repo_id=repo, revision=commit_hash, filename=filepath
)
if not dest_path:
check_huggingface_url_authorized(url)
token = HfFolder.get_token()
logger.info(f"Downloading {url} from huggingface")
dest_path = hf_hub_download(
repo_id=repo, revision=commit_hash, filename=filepath, token=token
)
# make a refs folder so caching works
# work-around for
# https://github.com/huggingface/huggingface_hub/pull/1306
# https://github.com/brycedrennan/imaginAIry/issues/171
refs_url = dest_path[: dest_path.index("/snapshots/")] + "/refs/"
os.makedirs(refs_url, exist_ok=True)
return dest_path
def extract_huggingface_repo_commit_file_from_url(url):
parsed_url = urllib.parse.urlparse(url)
path_components = parsed_url.path.strip("/").split("/")
repo = "/".join(path_components[0:2])
assert path_components[2] == "resolve"
commit_hash = path_components[3]
filepath = "/".join(path_components[4:])
return repo, commit_hash, filepath
def download_diffusers_weights(base_url, sub, filename=None, prefer_fp16=True):
if filename is None:
# select which weights to download. prefer fp16 safetensors
data = parse_diffusers_repo_url(base_url)
fs = HfFileSystem()
filepaths = fs.ls(
f"{data['author']}/{data['repo']}/{sub}", revision=data["ref"], detail=False
)
filepath = choose_diffusers_weights(filepaths, prefer_fp16=prefer_fp16)
if not filepath:
msg = f"Could not find any weights in {base_url}/{sub}"
raise ValueError(msg)
filename = filepath.split("/")[-1]
url = f"{base_url}{sub}/{filename}".replace("/tree/", "/resolve/")
new_path = get_cached_url_path(url, category="weights")
return new_path
def choose_diffusers_weights(filenames, prefer_fp16=True):
extension_priority = (".safetensors", ".bin", ".pth", ".pt")
# filter out any files that don't have a valid extension
filenames = [f for f in filenames if any(f.endswith(e) for e in extension_priority)]
filenames_and_extension = [(f, os.path.splitext(f)[1]) for f in filenames]
# sort by priority
if prefer_fp16:
filenames_and_extension.sort(
key=lambda x: ("fp16" not in x[0], extension_priority.index(x[1]))
)
else:
filenames_and_extension.sort(
key=lambda x: ("fp16" in x[0], extension_priority.index(x[1]))
)
if filenames_and_extension:
return filenames_and_extension[0][0]
return None
def load_sd15_diffusers_weights(base_url: str, device=None): def load_sd15_diffusers_weights(base_url: str, device=None):
from imaginairy.utils import get_device from imaginairy.utils import get_device
from imaginairy.weight_management.conversion import cast_weights from imaginairy.weight_management.conversion import cast_weights
@ -696,7 +503,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
base_url = normalize_diffusers_repo_url(base_url) base_url = normalize_diffusers_repo_url(base_url)
if device is None: if device is None:
device = get_device() device = get_device()
vae_weights_path = download_diffusers_weights(base_url=base_url, sub="vae") vae_weights_path = download_huggingface_weights(base_url=base_url, sub="vae")
vae_weights = open_weights(vae_weights_path, device=device) vae_weights = open_weights(vae_weights_path, device=device)
vae_weights = cast_weights( vae_weights = cast_weights(
source_weights=vae_weights, source_weights=vae_weights,
@ -706,7 +513,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
dest_format=FORMAT_NAMES.REFINERS, dest_format=FORMAT_NAMES.REFINERS,
) )
unet_weights_path = download_diffusers_weights(base_url=base_url, sub="unet") unet_weights_path = download_huggingface_weights(base_url=base_url, sub="unet")
unet_weights = open_weights(unet_weights_path, device=device) unet_weights = open_weights(unet_weights_path, device=device)
unet_weights = cast_weights( unet_weights = cast_weights(
source_weights=unet_weights, source_weights=unet_weights,
@ -716,7 +523,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
dest_format=FORMAT_NAMES.REFINERS, dest_format=FORMAT_NAMES.REFINERS,
) )
text_encoder_weights_path = download_diffusers_weights( text_encoder_weights_path = download_huggingface_weights(
base_url=base_url, sub="text_encoder" base_url=base_url, sub="text_encoder"
) )
text_encoder_weights = open_weights(text_encoder_weights_path, device=device) text_encoder_weights = open_weights(text_encoder_weights_path, device=device)
@ -749,7 +556,7 @@ def load_sdxl_pipeline_from_diffusers_weights(
base_url = normalize_diffusers_repo_url(base_url) base_url = normalize_diffusers_repo_url(base_url)
translator = translators.diffusers_autoencoder_kl_to_refiners_translator() translator = translators.diffusers_autoencoder_kl_to_refiners_translator()
vae_weights_path = download_diffusers_weights( vae_weights_path = download_huggingface_weights(
base_url=base_url, sub="vae", prefer_fp16=False base_url=base_url, sub="vae", prefer_fp16=False
) )
logger.debug(f"vae: {vae_weights_path}") logger.debug(f"vae: {vae_weights_path}")
@ -762,7 +569,7 @@ def load_sdxl_pipeline_from_diffusers_weights(
del vae_weights del vae_weights
translator = translators.diffusers_unet_sdxl_to_refiners_translator() translator = translators.diffusers_unet_sdxl_to_refiners_translator()
unet_weights_path = download_diffusers_weights( unet_weights_path = download_huggingface_weights(
base_url=base_url, sub="unet", prefer_fp16=True base_url=base_url, sub="unet", prefer_fp16=True
) )
logger.debug(f"unet: {unet_weights_path}") logger.debug(f"unet: {unet_weights_path}")
@ -777,10 +584,10 @@ def load_sdxl_pipeline_from_diffusers_weights(
unet.load_state_dict(unet_weights, assign=True) unet.load_state_dict(unet_weights, assign=True)
del unet_weights del unet_weights
text_encoder_1_path = download_diffusers_weights( text_encoder_1_path = download_huggingface_weights(
base_url=base_url, sub="text_encoder" base_url=base_url, sub="text_encoder"
) )
text_encoder_2_path = download_diffusers_weights( text_encoder_2_path = download_huggingface_weights(
base_url=base_url, sub="text_encoder_2" base_url=base_url, sub="text_encoder_2"
) )
logger.debug(f"text encoder 1: {text_encoder_1_path}") logger.debug(f"text encoder 1: {text_encoder_1_path}")

@ -9,7 +9,7 @@ import numpy as np
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from imaginairy.utils.model_manager import get_cached_url_path from imaginairy.utils.downloads import get_cached_url_path
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

@ -2,8 +2,8 @@
import safetensors import safetensors
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.model_manager import ( from imaginairy.utils.model_manager import (
get_cached_url_path,
open_weights, open_weights,
resolve_model_weights_config, resolve_model_weights_config,
) )

@ -3,7 +3,7 @@ import os
import torch import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from imaginairy.utils.model_manager import get_cached_url_path from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.paths import PKG_ROOT from imaginairy.utils.paths import PKG_ROOT
sd15_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt" sd15_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt"

@ -1,8 +1,8 @@
import pytest import pytest
from imaginairy import config from imaginairy import config
from imaginairy.utils.downloads import parse_diffusers_repo_url
from imaginairy.utils.model_manager import ( from imaginairy.utils.model_manager import (
parse_diffusers_repo_url,
resolve_model_weights_config, resolve_model_weights_config,
) )

Loading…
Cancel
Save