Mirror of https://github.com/brycedrennan/imaginAIry (synced 2024-10-31 03:20:40 +00:00)

refactor: move download related functions to separate module (#453)
+ renames and typehints

parent 502ffbdc63
commit 601a112dc3
39 README.md
@ -71,6 +71,27 @@ Options:
### What's New

[See full Changelog here](./docs/changelog.md)

**14.1.0**

- 🎉 feature: make video generation smooth by adding frame interpolation
- feature: SDXL weights in the compvis format can now be used
- feature: allow video generation at any size specified by user
- feature: video generations output in "bounce" format
- feature: choose video output format: mp4, webp, or gif
- feature: fix random seed handling in video generation
- docs: auto-publish docs on push to master
- build: remove imageio dependency
- build: vendorize facexlib so we don't install its unneeded dependencies

**14.0.4**

- docs: add a documentation website at https://brycedrennan.github.io/imaginAIry/
- build: remove fairscale dependency
- fix: video generation was broken

**14.0.3**

- fix: several critical bugs with package
- tests: add a wheel smoketest to detect these issues in the future

**14.0.0**

- 🎉 video generation using [Stable Video Diffusion](https://github.com/Stability-AI/generative-models)
- add `--videogen` to any image generation to create a short video from the generated image
@ -86,23 +107,7 @@ cutting edge features (SDXL, image prompts, etc) which will be added to imaginai
For example `--size 720p --seed 1` and `--size 1080p --seed 1` will produce the same image for SD15
- 🎉 feature: loading diffusers-based models now supported. Example `--model https://huggingface.co/ainz/diseny-pixar --model-architecture sd15`
- 🎉 feature: qrcode controlnet!
- feature: generate word images automatically. great for use with qrcode controlnet: `imagine "flowers" --gif --size hd --control-mode qrcode --control-image "textimg='JOY' font_color=white background_color=gray" -r 10`
- feature: opendalle 1.1 added. `--model opendalle` to use it
- feature: added `--size` parameter for more intuitive sizing (e.g. 512, 256x256, 4k, uhd, FHD, VGA, etc)
- feature: detect if wrong torch version is installed and provide instructions on how to install the proper version
- feature: better logging output: color, error handling
- feature: support for pytorch 2.0
- feature: command line output significantly cleaned up and easier to read
- feature: adds --composition-strength parameter to cli (#416)
- performance: lower memory usage for upscaling
- performance: lower memory usage at startup
- performance: add sliced attention to several models (lowers memory use)
- fix: simpler memory management that avoids some of the previous bugs
- deprecated: support for python 3.8, 3.9
- deprecated: support for torch 1.13
- deprecated: support for Stable Diffusion versions 1.4, 2.0, and 2.1
- deprecated: image training
- broken: samplers other than ddim

### Run API server and StableStudio web interface (alpha)

Generate images via API or web interface. Much smaller featureset compared to the command line tool.
12 docs/todo.md
@ -13,11 +13,17 @@
- ✅ add type checker
- ✅ add interface for loading diffusers weights
- ✅ SDXL support
- sdxl inpainting
- ✅ sdxl inpainting
- t2i adapters
- image prompts
- embedding inputs
- save complete metadata to image
- recreate image from metadata
- auto-incorporate https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
- only output the main image unless some flag is set
- allow selection of output video format
- ✅ allow selection of output video format
- test on python 3.11
- allow specification of filename format
- chain multiple operations together imggen => videogen
- https://github.com/pallets/click/tree/main/examples/imagepipe
@ -57,7 +63,7 @@
- ✅ set up ci (test/lint/format)
- ✅ unified pipeline (txt2img & img2img combined)
- ✅ setup parallel testing
- add docs
- ✅ add docs
- 🚫 remove yaml config
- 🚫 delete more unused code
- faster latent logging https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/9
@ -28,7 +28,7 @@ from imaginairy.utils import (
    platform_appropriate_autocast,
)
from imaginairy.utils.animations import make_bounce_animation
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.named_resolutions import normalize_image_size
from imaginairy.utils.paths import PKG_ROOT
@ -9,7 +9,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.vendored.blip.blip import BLIP_Decoder, load_checkpoint

device = get_device()
@ -8,7 +8,7 @@ import torch
from PIL import Image
from torchvision.transforms.functional import normalize

from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.vendored.basicsr.img_util import img2tensor, tensor2img
from imaginairy.vendored.codeformer.codeformer_arch import CodeFormer
from imaginairy.vendored.facexlib.utils.face_restoration_helper import FaceRestoreHelper
@ -5,8 +5,8 @@ import torch
from PIL import Image

from imaginairy.utils import get_device
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.vendored.basicsr.rrdbnet_arch import RRDBNet
from imaginairy.vendored.realesrgan import RealESRGANer
@ -8,8 +8,8 @@ import torch.nn.functional as F
from torch import nn

from imaginairy.utils import get_device, platform_appropriate_autocast
from imaginairy.utils.downloads import hf_hub_download
from imaginairy.utils.log_utils import log_latent
from imaginairy.utils.model_manager import hf_hub_download
from imaginairy.vendored import k_diffusion as K
from imaginairy.vendored.k_diffusion import layers
from imaginairy.vendored.k_diffusion.models.image_v1 import ImageDenoiserModelV1
@ -15,7 +15,7 @@ from torch.nn import functional as F
from tqdm import tqdm

from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path

from .msssim import ssim_matlab
from .RIFE_HDv3 import Model
@ -7,7 +7,7 @@ import numpy as np
import torch

from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path


class Network(torch.nn.Module):
@ -11,8 +11,8 @@ from scipy.ndimage.filters import gaussian_filter
from torch import nn

from imaginairy.utils import get_device
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.img_utils import torch_image_to_openvcv_img
from imaginairy.utils.model_manager import get_cached_url_path


def pad_right_down_corner(img, stride, padValue):
@ -2,7 +2,7 @@
import torch

from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path


class BaseModel(torch.nn.Module):
211 imaginairy/utils/downloads.py (new file)
@ -0,0 +1,211 @@
import logging
import os
import re
import urllib.parse
from functools import lru_cache, wraps

import requests
from huggingface_hub import (
    HfFileSystem,
    HfFolder,
    hf_hub_download as _hf_hub_download,
    try_to_load_from_cache,
)

logger = logging.getLogger(__name__)


def get_cached_url_path(url: str, category=None) -> str:
    """
    Gets the contents of a url, but caches the response indefinitely.

    While we attempt to use the cached_path from huggingface transformers, we fall back
    to our own implementation if the url does not provide an etag header, which `cached_path`
    requires. We also skip the `head` call that `cached_path` makes on every call if the file
    is already cached.
    """
    if url.startswith("https://huggingface.co"):
        try:
            return huggingface_cached_path(url)
        except (OSError, ValueError):
            pass
    filename = url.split("/")[-1]
    dest = get_cache_dir()
    if category:
        dest = os.path.join(dest, category)
    os.makedirs(dest, exist_ok=True)

    # Replace possibly illegal destination path characters
    safe_filename = re.sub('[*<>:"|?]', "_", filename)
    dest_path = os.path.join(dest, safe_filename)
    if os.path.exists(dest_path):
        return dest_path

    # check if it's saved at previous path and rename it
    old_dest_path = os.path.join(dest, filename)
    if os.path.exists(old_dest_path):
        os.rename(old_dest_path, dest_path)
        return dest_path

    r = requests.get(url)

    with open(dest_path, "wb") as f:
        f.write(r.content)
    return dest_path
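
A minimal usage sketch (illustrative only; assumes network access and uses the SD 1.5 checkpoint URL referenced in this commit's tests):

# downloads on the first call, returns the cached local path on later calls
sd15_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt"
ckpt_path = get_cached_url_path(sd15_url, category="weights")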

def check_huggingface_url_authorized(url: str) -> None:
    if not url.startswith("https://huggingface.co/"):
        return None
    token = HfFolder.get_token()
    headers = {}
    if token is not None:
        headers["authorization"] = f"Bearer {token}"
    response = requests.head(url, allow_redirects=True, headers=headers, timeout=5)
    if response.status_code == 401:
        msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
        raise HuggingFaceAuthorizationError(msg)
    return None


@wraps(_hf_hub_download)
def hf_hub_download(*args, **kwargs):
    """
    backwards compatible wrapper for huggingface's hf_hub_download.

    they changed the argument name from `use_auth_token` to `token`
    """
    try:
        return _hf_hub_download(*args, **kwargs)
    except TypeError as e:
        if "unexpected keyword argument 'token'" in str(e):
            kwargs["use_auth_token"] = kwargs.pop("token")
            return _hf_hub_download(*args, **kwargs)
        raise


def huggingface_cached_path(url: str) -> str:
    # bypass all the HEAD calls done by the default `cached_path`
    repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url)
    dest_path = try_to_load_from_cache(
        repo_id=repo, revision=commit_hash, filename=filepath
    )
    if not dest_path:
        check_huggingface_url_authorized(url)
        token = HfFolder.get_token()
        logger.info(f"Downloading {url} from huggingface")
        dest_path = hf_hub_download(
            repo_id=repo, revision=commit_hash, filename=filepath, token=token
        )
        # make a refs folder so caching works
        # work-around for
        # https://github.com/huggingface/huggingface_hub/pull/1306
        # https://github.com/brycedrennan/imaginAIry/issues/171
        refs_url = dest_path[: dest_path.index("/snapshots/")] + "/refs/"
        os.makedirs(refs_url, exist_ok=True)
    return dest_path
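
For context, the refs work-around above assumes the standard huggingface_hub cache layout; the paths below are illustrative, not produced verbatim by this code:

# dest_path e.g. <hf cache>/models--runwayml--stable-diffusion-v1-5/snapshots/<commit>/v1-5-pruned-emaonly.ckpt
# refs_url  e.g. <hf cache>/models--runwayml--stable-diffusion-v1-5/refs/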

def extract_huggingface_repo_commit_file_from_url(url):
    parsed_url = urllib.parse.urlparse(url)
    path_components = parsed_url.path.strip("/").split("/")

    repo = "/".join(path_components[0:2])
    assert path_components[2] == "resolve"
    commit_hash = path_components[3]
    filepath = "/".join(path_components[4:])

    return repo, commit_hash, filepath
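
As an illustration, the checkpoint URL used in this commit's tests decomposes as follows (values shown are what the function would return):

# url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt"
# repo        -> "runwayml/stable-diffusion-v1-5"
# commit_hash -> "889b629140e71758e1e0006e355c331a5744b4bf"
# filepath    -> "v1-5-pruned-emaonly.ckpt"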

def download_huggingface_weights(
    base_url: str, sub: str, filename=None, prefer_fp16=True
) -> str:
    """
    Downloads weights from huggingface and returns the path to the downloaded file

    Given a huggingface repo url, folder, and optional filename, download the weights to the cache directory and return the path
    """
    if filename is None:
        # select which weights to download. prefer fp16 safetensors
        data = parse_diffusers_repo_url(base_url)
        fs = HfFileSystem()
        filepaths = fs.ls(
            f"{data['author']}/{data['repo']}/{sub}", revision=data["ref"], detail=False
        )
        filepath = choose_huggingface_weights(filepaths, prefer_fp16=prefer_fp16)
        if not filepath:
            msg = f"Could not find any weights in {base_url}/{sub}"
            raise ValueError(msg)
        filename = filepath.split("/")[-1]
    url = f"{base_url}{sub}/{filename}".replace("/tree/", "/resolve/")
    new_path = get_cached_url_path(url, category="weights")
    return new_path
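
A hedged sketch of how the model loader later in this commit uses this function (the repo URL is the diffusers example from the README; network access assumed, and normalize_diffusers_repo_url is defined further down in this module):

# illustrative only; mirrors the load_sd15_diffusers_weights calls in model_manager.py below
base_url = normalize_diffusers_repo_url("https://huggingface.co/ainz/diseny-pixar")
vae_path = download_huggingface_weights(base_url=base_url, sub="vae")
unet_path = download_huggingface_weights(base_url=base_url, sub="unet")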

def choose_huggingface_weights(filenames: list[str], prefer_fp16=True) -> str | None:
    """
    Chooses the best weights file from a list of filenames

    Prefers safetensors format and fp16 dtype
    """
    extension_priority = (".safetensors", ".bin", ".pth", ".pt")
    # filter out any files that don't have a valid extension
    filenames = [f for f in filenames if any(f.endswith(e) for e in extension_priority)]
    filenames_and_extension = [(f, os.path.splitext(f)[1]) for f in filenames]
    # sort by priority
    if prefer_fp16:
        filenames_and_extension.sort(
            key=lambda x: ("fp16" not in x[0], extension_priority.index(x[1]))
        )
    else:
        filenames_and_extension.sort(
            key=lambda x: ("fp16" in x[0], extension_priority.index(x[1]))
        )
    if filenames_and_extension:
        return filenames_and_extension[0][0]
    return None
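
A small worked example of the selection logic (filenames are hypothetical but follow the usual diffusers naming):

candidates = [
    "unet/config.json",  # dropped: not a weights-file extension
    "unet/diffusion_pytorch_model.bin",
    "unet/diffusion_pytorch_model.fp16.safetensors",
]
# with prefer_fp16=True the sort key is (not "fp16" in name, extension rank),
# so choose_huggingface_weights(candidates) returns the fp16 safetensors file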

@lru_cache
def get_cache_dir() -> str:
    xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
    if xdg_cache_home is None:
        user_home = os.getenv("HOME", None)
        if user_home:
            xdg_cache_home = os.path.join(user_home, ".cache")

    if xdg_cache_home is not None:
        return os.path.join(xdg_cache_home, "imaginairy")

    return os.path.join(os.path.dirname(__file__), ".cached-aimg")


class HuggingFaceAuthorizationError(RuntimeError):
    pass


hf_repo_url_pattern = re.compile(
    r"https://huggingface\.co/(?P<author>[^/]+)/(?P<repo>[^/]+)(/tree/(?P<ref>[a-z0-9]+))?/?$"
)


def parse_diffusers_repo_url(url: str) -> dict[str, str]:
    match = hf_repo_url_pattern.match(url)
    return match.groupdict() if match else {}


def is_diffusers_repo_url(url: str) -> bool:
    result = bool(parse_diffusers_repo_url(url))
    logger.debug(f"{url} is diffusers repo url: {result}")
    return result


def normalize_diffusers_repo_url(url: str) -> str:
    data = parse_diffusers_repo_url(url)
    ref = data["ref"] or "main"
    normalized_url = (
        f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/"
    )
    return normalized_url
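
To illustrate the URL handling above, using the diffusers repo example from the README (expected values per the regex and the f-string):

# parse_diffusers_repo_url("https://huggingface.co/ainz/diseny-pixar")
#   -> {"author": "ainz", "repo": "diseny-pixar", "ref": None}
# normalize_diffusers_repo_url("https://huggingface.co/ainz/diseny-pixar")
#   -> "https://huggingface.co/ainz/diseny-pixar/tree/main/"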
@ -2,19 +2,10 @@
import logging
import os
import re
import sys
import urllib.parse
from functools import lru_cache, wraps
from functools import lru_cache

import requests
import torch
from huggingface_hub import (
    HfFileSystem,
    HfFolder,
    hf_hub_download as _hf_hub_download,
    try_to_load_from_cache,
)
from omegaconf import OmegaConf
from safetensors.torch import load_file
@ -27,6 +18,13 @@ from imaginairy.modules.refiners_sd import (
    StableDiffusion_XL_Inpainting,
)
from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_config
from imaginairy.utils.downloads import (
    HuggingFaceAuthorizationError,
    download_huggingface_weights,
    get_cached_url_path,
    is_diffusers_repo_url,
    normalize_diffusers_repo_url,
)
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.named_resolutions import normalize_image_size
from imaginairy.utils.paths import PKG_ROOT
@ -54,10 +52,6 @@ logger = logging.getLogger(__name__)
MOST_RECENTLY_LOADED_MODEL = None


class HuggingFaceAuthorizationError(RuntimeError):
    pass


def load_state_dict(weights_location, half_mode=False, device=None):
    if device is None:
        device = get_device()
@ -118,13 +112,6 @@ def load_model_from_config(config, weights_location, half_mode=False):
    return model


def add_controlnet(base_state_dict, controlnet_state_dict):
    """Merges a base sd15 model with a controlnet model."""
    for key in controlnet_state_dict:
        base_state_dict[key] = controlnet_state_dict[key]
    return base_state_dict


def get_diffusion_model(
    weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
    config_path="configs/stable-diffusion-v1.yaml",
@ -222,31 +209,6 @@ def get_diffusion_model_refiners(
    return sd


hf_repo_url_pattern = re.compile(
    r"https://huggingface\.co/(?P<author>[^/]+)/(?P<repo>[^/]+)(/tree/(?P<ref>[a-z0-9]+))?/?$"
)


def parse_diffusers_repo_url(url: str) -> dict[str, str]:
    match = hf_repo_url_pattern.match(url)
    return match.groupdict() if match else {}


def is_diffusers_repo_url(url: str) -> bool:
    result = bool(parse_diffusers_repo_url(url))
    logger.debug(f"{url} is diffusers repo url: {result}")
    return result


def normalize_diffusers_repo_url(url: str) -> str:
    data = parse_diffusers_repo_url(url)
    ref = data["ref"] or "main"
    normalized_url = (
        f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/"
    )
    return normalized_url


@lru_cache(maxsize=1)
def _get_diffusion_model_refiners(
    weights_location: str,
@ -529,161 +491,6 @@ def get_current_diffusion_model():
    return MOST_RECENTLY_LOADED_MODEL


def get_cache_dir():
    xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
    if xdg_cache_home is None:
        user_home = os.getenv("HOME", None)
        if user_home:
            xdg_cache_home = os.path.join(user_home, ".cache")

    if xdg_cache_home is not None:
        return os.path.join(xdg_cache_home, "imaginairy")

    return os.path.join(os.path.dirname(__file__), ".cached-aimg")


def get_cached_url_path(url, category=None):
    """
    Gets the contents of a url, but caches the response indefinitely.

    While we attempt to use the cached_path from huggingface transformers, we fall back
    to our own implementation if the url does not provide an etag header, which `cached_path`
    requires. We also skip the `head` call that `cached_path` makes on every call if the file
    is already cached.
    """
    try:
        if url.startswith("https://huggingface.co"):
            return huggingface_cached_path(url)
    except (OSError, ValueError):
        pass
    filename = url.split("/")[-1]
    dest = get_cache_dir()
    if category:
        dest = os.path.join(dest, category)
    os.makedirs(dest, exist_ok=True)

    # Replace possibly illegal destination path characters
    safe_filename = re.sub('[*<>:"|?]', "_", filename)
    dest_path = os.path.join(dest, safe_filename)
    if os.path.exists(dest_path):
        return dest_path

    # check if it's saved at previous path and rename it
    old_dest_path = os.path.join(dest, filename)
    if os.path.exists(old_dest_path):
        os.rename(old_dest_path, dest_path)
        return dest_path

    r = requests.get(url)

    with open(dest_path, "wb") as f:
        f.write(r.content)
    return dest_path


def check_huggingface_url_authorized(url):
    if not url.startswith("https://huggingface.co/"):
        return None
    token = HfFolder.get_token()
    headers = {}
    if token is not None:
        headers["authorization"] = f"Bearer {token}"
    response = requests.head(url, allow_redirects=True, headers=headers, timeout=5)
    if response.status_code == 401:
        msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
        raise HuggingFaceAuthorizationError(msg)
    return None


@wraps(_hf_hub_download)
def hf_hub_download(*args, **kwargs):
    """
    backwards compatible wrapper for huggingface's hf_hub_download.

    they changed the argument name from `use_auth_token` to `token`
    """
    try:
        return _hf_hub_download(*args, **kwargs)
    except TypeError as e:
        if "unexpected keyword argument 'token'" in str(e):
            kwargs["use_auth_token"] = kwargs.pop("token")
            return _hf_hub_download(*args, **kwargs)
        raise


def huggingface_cached_path(url):
    # bypass all the HEAD calls done by the default `cached_path`
    repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url)
    dest_path = try_to_load_from_cache(
        repo_id=repo, revision=commit_hash, filename=filepath
    )
    if not dest_path:
        check_huggingface_url_authorized(url)
        token = HfFolder.get_token()
        logger.info(f"Downloading {url} from huggingface")
        dest_path = hf_hub_download(
            repo_id=repo, revision=commit_hash, filename=filepath, token=token
        )
        # make a refs folder so caching works
        # work-around for
        # https://github.com/huggingface/huggingface_hub/pull/1306
        # https://github.com/brycedrennan/imaginAIry/issues/171
        refs_url = dest_path[: dest_path.index("/snapshots/")] + "/refs/"
        os.makedirs(refs_url, exist_ok=True)
    return dest_path


def extract_huggingface_repo_commit_file_from_url(url):
    parsed_url = urllib.parse.urlparse(url)
    path_components = parsed_url.path.strip("/").split("/")

    repo = "/".join(path_components[0:2])
    assert path_components[2] == "resolve"
    commit_hash = path_components[3]
    filepath = "/".join(path_components[4:])

    return repo, commit_hash, filepath


def download_diffusers_weights(base_url, sub, filename=None, prefer_fp16=True):
    if filename is None:
        # select which weights to download. prefer fp16 safetensors
        data = parse_diffusers_repo_url(base_url)
        fs = HfFileSystem()
        filepaths = fs.ls(
            f"{data['author']}/{data['repo']}/{sub}", revision=data["ref"], detail=False
        )
        filepath = choose_diffusers_weights(filepaths, prefer_fp16=prefer_fp16)
        if not filepath:
            msg = f"Could not find any weights in {base_url}/{sub}"
            raise ValueError(msg)
        filename = filepath.split("/")[-1]
    url = f"{base_url}{sub}/{filename}".replace("/tree/", "/resolve/")
    new_path = get_cached_url_path(url, category="weights")
    return new_path


def choose_diffusers_weights(filenames, prefer_fp16=True):
    extension_priority = (".safetensors", ".bin", ".pth", ".pt")
    # filter out any files that don't have a valid extension
    filenames = [f for f in filenames if any(f.endswith(e) for e in extension_priority)]
    filenames_and_extension = [(f, os.path.splitext(f)[1]) for f in filenames]
    # sort by priority
    if prefer_fp16:
        filenames_and_extension.sort(
            key=lambda x: ("fp16" not in x[0], extension_priority.index(x[1]))
        )
    else:
        filenames_and_extension.sort(
            key=lambda x: ("fp16" in x[0], extension_priority.index(x[1]))
        )
    if filenames_and_extension:
        return filenames_and_extension[0][0]
    return None


def load_sd15_diffusers_weights(base_url: str, device=None):
    from imaginairy.utils import get_device
    from imaginairy.weight_management.conversion import cast_weights
@ -696,7 +503,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
    base_url = normalize_diffusers_repo_url(base_url)
    if device is None:
        device = get_device()
    vae_weights_path = download_diffusers_weights(base_url=base_url, sub="vae")
    vae_weights_path = download_huggingface_weights(base_url=base_url, sub="vae")
    vae_weights = open_weights(vae_weights_path, device=device)
    vae_weights = cast_weights(
        source_weights=vae_weights,
@ -706,7 +513,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
        dest_format=FORMAT_NAMES.REFINERS,
    )

    unet_weights_path = download_diffusers_weights(base_url=base_url, sub="unet")
    unet_weights_path = download_huggingface_weights(base_url=base_url, sub="unet")
    unet_weights = open_weights(unet_weights_path, device=device)
    unet_weights = cast_weights(
        source_weights=unet_weights,
@ -716,7 +523,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
        dest_format=FORMAT_NAMES.REFINERS,
    )

    text_encoder_weights_path = download_diffusers_weights(
    text_encoder_weights_path = download_huggingface_weights(
        base_url=base_url, sub="text_encoder"
    )
    text_encoder_weights = open_weights(text_encoder_weights_path, device=device)
@ -749,7 +556,7 @@ def load_sdxl_pipeline_from_diffusers_weights(
    base_url = normalize_diffusers_repo_url(base_url)

    translator = translators.diffusers_autoencoder_kl_to_refiners_translator()
    vae_weights_path = download_diffusers_weights(
    vae_weights_path = download_huggingface_weights(
        base_url=base_url, sub="vae", prefer_fp16=False
    )
    logger.debug(f"vae: {vae_weights_path}")
@ -762,7 +569,7 @@ def load_sdxl_pipeline_from_diffusers_weights(
    del vae_weights

    translator = translators.diffusers_unet_sdxl_to_refiners_translator()
    unet_weights_path = download_diffusers_weights(
    unet_weights_path = download_huggingface_weights(
        base_url=base_url, sub="unet", prefer_fp16=True
    )
    logger.debug(f"unet: {unet_weights_path}")
@ -777,10 +584,10 @@ def load_sdxl_pipeline_from_diffusers_weights(
    unet.load_state_dict(unet_weights, assign=True)
    del unet_weights

    text_encoder_1_path = download_diffusers_weights(
    text_encoder_1_path = download_huggingface_weights(
        base_url=base_url, sub="text_encoder"
    )
    text_encoder_2_path = download_diffusers_weights(
    text_encoder_2_path = download_huggingface_weights(
        base_url=base_url, sub="text_encoder_2"
    )
    logger.debug(f"text encoder 1: {text_encoder_1_path}")
@ -9,7 +9,7 @@ import numpy as np
import torch
from torch.nn import functional as F

from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@ -2,8 +2,8 @@
import safetensors

from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.model_manager import (
    get_cached_url_path,
    open_weights,
    resolve_model_weights_config,
)
@ -3,7 +3,7 @@ import os
import torch
from safetensors.torch import load_file, save_file

from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.paths import PKG_ROOT

sd15_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt"
@ -1,8 +1,8 @@
import pytest

from imaginairy import config
from imaginairy.utils.downloads import parse_diffusers_repo_url
from imaginairy.utils.model_manager import (
    parse_diffusers_repo_url,
    resolve_model_weights_config,
)