refactor: move download related functions to separate module (#453)

+ renames and typehints
This commit is contained in:
Bryce Drennan 2024-01-14 16:50:17 -08:00 committed by GitHub
parent 502ffbdc63
commit 601a112dc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 270 additions and 241 deletions

View File

@ -71,6 +71,27 @@ Options:
### Whats New
[See full Changelog here](./docs/changelog.md)
**14.1.0**
- 🎉 feature: make video generation smooth by adding frame interpolation
- feature: SDXL weights in the compvis format can now be used
- feature: allow video generation at any size specified by user
- feature: video generations output in "bounce" format
- feature: choose video output format: mp4, webp, or gif
- feature: fix random seed handling in video generation
- docs: auto-publish docs on push to master
- build: remove imageio dependency
- build: vendorize facexlib so we don't install its unneeded dependencies
**14.0.4**
- docs: add a documentation website at https://brycedrennan.github.io/imaginAIry/
- build: remove fairscale dependency
- fix: video generation was broken
**14.0.3**
- fix: several critical bugs with package
- tests: add a wheel smoketest to detect these issues in the future
**14.0.0**
- 🎉 video generation using [Stable Video Diffusion](https://github.com/Stability-AI/generative-models)
- add `--videogen` to any image generation to create a short video from the generated image
@ -86,23 +107,7 @@ cutting edge features (SDXL, image prompts, etc) which will be added to imaginai
For example `--size 720p --seed 1` and `--size 1080p --seed 1` will produce the same image for SD15
- 🎉 feature: loading diffusers based models now supported. Example `--model https://huggingface.co/ainz/diseny-pixar --model-architecture sd15`
- 🎉 feature: qrcode controlnet!
- feature: generate word images automatically. great for use with qrcode controlnet: `imagine "flowers" --gif --size hd --control-mode qrcode --control-image "textimg='JOY' font_color=white background_color=gray" -r 10`
- feature: opendalle 1.1 added. `--model opendalle` to use it
- feature: added `--size` parameter for more intuitive sizing (e.g. 512, 256x256, 4k, uhd, FHD, VGA, etc)
- feature: detect if wrong torch version is installed and provide instructions on how to install proper version
- feature: better logging output: color, error handling
- feature: support for pytorch 2.0
- feature: command line output significantly cleaned up and easier to read
- feature: adds --composition-strength parameter to cli (#416)
- performance: lower memory usage for upscaling
- performance: lower memory usage at startup
- performance: add sliced attention to several models (lowers memory use)
- fix: simpler memory management that avoids some of the previous bugs
- deprecated: support for python 3.8, 3.9
- deprecated: support for torch 1.13
- deprecated: support for Stable Diffusion versions 1.4, 2.0, and 2.1
- deprecated: image training
- broken: samplers other than ddim
### Run API server and StableStudio web interface (alpha)
Generate images via API or web interface. Much smaller featureset compared to the command line tool.

View File

@ -13,11 +13,17 @@
- ✅ add type checker
- ✅ add interface for loading diffusers weights
- ✅ SDXL support
- sdxl inpainting
- sdxl inpainting
- t2i adapters
- image prompts
- embedding inputs
- save complete metadata to image
- recreate image from metadata
- auto-incoporate https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
- only output the main image unless some flag is set
- allow selection of output video format
- ✅ allow selection of output video format
- test on python 3.11
- allow specification of filename format
- chain multiple operations together imggen => videogen
- https://github.com/pallets/click/tree/main/examples/imagepipe
@ -57,7 +63,7 @@
- ✅ set up ci (test/lint/format)
- ✅ unified pipeline (txt2img & img2img combined)
- ✅ setup parallel testing
- add docs
- add docs
- 🚫 remove yaml config
- 🚫 delete more unused code
- faster latent logging https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/9

View File

@ -28,7 +28,7 @@ from imaginairy.utils import (
platform_appropriate_autocast,
)
from imaginairy.utils.animations import make_bounce_animation
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.named_resolutions import normalize_image_size
from imaginairy.utils.paths import PKG_ROOT

View File

@ -9,7 +9,7 @@ from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.vendored.blip.blip import BLIP_Decoder, load_checkpoint
device = get_device()

View File

@ -8,7 +8,7 @@ import torch
from PIL import Image
from torchvision.transforms.functional import normalize
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.vendored.basicsr.img_util import img2tensor, tensor2img
from imaginairy.vendored.codeformer.codeformer_arch import CodeFormer
from imaginairy.vendored.facexlib.utils.face_restoration_helper import FaceRestoreHelper

View File

@ -5,8 +5,8 @@ import torch
from PIL import Image
from imaginairy.utils import get_device
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.vendored.basicsr.rrdbnet_arch import RRDBNet
from imaginairy.vendored.realesrgan import RealESRGANer

View File

@ -8,8 +8,8 @@ import torch.nn.functional as F
from torch import nn
from imaginairy.utils import get_device, platform_appropriate_autocast
from imaginairy.utils.downloads import hf_hub_download
from imaginairy.utils.log_utils import log_latent
from imaginairy.utils.model_manager import hf_hub_download
from imaginairy.vendored import k_diffusion as K
from imaginairy.vendored.k_diffusion import layers
from imaginairy.vendored.k_diffusion.models.image_v1 import ImageDenoiserModelV1

View File

@ -15,7 +15,7 @@ from torch.nn import functional as F
from tqdm import tqdm
from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
from .msssim import ssim_matlab
from .RIFE_HDv3 import Model

View File

@ -7,7 +7,7 @@ import numpy as np
import torch
from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
class Network(torch.nn.Module):

View File

@ -11,8 +11,8 @@ from scipy.ndimage.filters import gaussian_filter
from torch import nn
from imaginairy.utils import get_device
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.img_utils import torch_image_to_openvcv_img
from imaginairy.utils.model_manager import get_cached_url_path
def pad_right_down_corner(img, stride, padValue):

View File

@ -2,7 +2,7 @@
import torch
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
class BaseModel(torch.nn.Module):

View File

@ -0,0 +1,211 @@
import logging
import os
import re
import urllib.parse
from functools import lru_cache, wraps
import requests
from huggingface_hub import (
HfFileSystem,
HfFolder,
hf_hub_download as _hf_hub_download,
try_to_load_from_cache,
)
logger = logging.getLogger(__name__)
def get_cached_url_path(url: str, category=None) -> str:
"""
Gets the contents of a url, but caches the response indefinitely.
While we attempt to use the cached_path from huggingface transformers, we fall back
to our own implementation if the url does not provide an etag header, which `cached_path`
requires. We also skip the `head` call that `cached_path` makes on every call if the file
is already cached.
"""
if url.startswith("https://huggingface.co"):
try:
return huggingface_cached_path(url)
except (OSError, ValueError):
pass
filename = url.split("/")[-1]
dest = get_cache_dir()
if category:
dest = os.path.join(dest, category)
os.makedirs(dest, exist_ok=True)
# Replace possibly illegal destination path characters
safe_filename = re.sub('[*<>:"|?]', "_", filename)
dest_path = os.path.join(dest, safe_filename)
if os.path.exists(dest_path):
return dest_path
# check if it's saved at previous path and rename it
old_dest_path = os.path.join(dest, filename)
if os.path.exists(old_dest_path):
os.rename(old_dest_path, dest_path)
return dest_path
r = requests.get(url)
with open(dest_path, "wb") as f:
f.write(r.content)
return dest_path
def check_huggingface_url_authorized(url: str) -> None:
if not url.startswith("https://huggingface.co/"):
return None
token = HfFolder.get_token()
headers = {}
if token is not None:
headers["authorization"] = f"Bearer {token}"
response = requests.head(url, allow_redirects=True, headers=headers, timeout=5)
if response.status_code == 401:
msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
raise HuggingFaceAuthorizationError(msg)
return None
@wraps(_hf_hub_download)
def hf_hub_download(*args, **kwargs):
"""
backwards compatible wrapper for huggingface's hf_hub_download.
they changed the argument name from `use_auth_token` to `token`
"""
try:
return _hf_hub_download(*args, **kwargs)
except TypeError as e:
if "unexpected keyword argument 'token'" in str(e):
kwargs["use_auth_token"] = kwargs.pop("token")
return _hf_hub_download(*args, **kwargs)
raise
def huggingface_cached_path(url: str) -> str:
# bypass all the HEAD calls done by the default `cached_path`
repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url)
dest_path = try_to_load_from_cache(
repo_id=repo, revision=commit_hash, filename=filepath
)
if not dest_path:
check_huggingface_url_authorized(url)
token = HfFolder.get_token()
logger.info(f"Downloading {url} from huggingface")
dest_path = hf_hub_download(
repo_id=repo, revision=commit_hash, filename=filepath, token=token
)
# make a refs folder so caching works
# work-around for
# https://github.com/huggingface/huggingface_hub/pull/1306
# https://github.com/brycedrennan/imaginAIry/issues/171
refs_url = dest_path[: dest_path.index("/snapshots/")] + "/refs/"
os.makedirs(refs_url, exist_ok=True)
return dest_path
def extract_huggingface_repo_commit_file_from_url(url):
parsed_url = urllib.parse.urlparse(url)
path_components = parsed_url.path.strip("/").split("/")
repo = "/".join(path_components[0:2])
assert path_components[2] == "resolve"
commit_hash = path_components[3]
filepath = "/".join(path_components[4:])
return repo, commit_hash, filepath
def download_huggingface_weights(
base_url: str, sub: str, filename=None, prefer_fp16=True
) -> str:
"""
Downloads weights from huggingface and returns the path to the downloaded file
Given a huggingface repo url, folder, and optional filename, download the weights to the cache directory and return the path
"""
if filename is None:
# select which weights to download. prefer fp16 safetensors
data = parse_diffusers_repo_url(base_url)
fs = HfFileSystem()
filepaths = fs.ls(
f"{data['author']}/{data['repo']}/{sub}", revision=data["ref"], detail=False
)
filepath = choose_huggingface_weights(filepaths, prefer_fp16=prefer_fp16)
if not filepath:
msg = f"Could not find any weights in {base_url}/{sub}"
raise ValueError(msg)
filename = filepath.split("/")[-1]
url = f"{base_url}{sub}/{filename}".replace("/tree/", "/resolve/")
new_path = get_cached_url_path(url, category="weights")
return new_path
def choose_huggingface_weights(filenames: list[str], prefer_fp16=True) -> str | None:
"""
Chooses the best weights file from a list of filenames
Prefers safetensors format and fp16 dtype
"""
extension_priority = (".safetensors", ".bin", ".pth", ".pt")
# filter out any files that don't have a valid extension
filenames = [f for f in filenames if any(f.endswith(e) for e in extension_priority)]
filenames_and_extension = [(f, os.path.splitext(f)[1]) for f in filenames]
# sort by priority
if prefer_fp16:
filenames_and_extension.sort(
key=lambda x: ("fp16" not in x[0], extension_priority.index(x[1]))
)
else:
filenames_and_extension.sort(
key=lambda x: ("fp16" in x[0], extension_priority.index(x[1]))
)
if filenames_and_extension:
return filenames_and_extension[0][0]
return None
@lru_cache
def get_cache_dir() -> str:
xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
if xdg_cache_home is None:
user_home = os.getenv("HOME", None)
if user_home:
xdg_cache_home = os.path.join(user_home, ".cache")
if xdg_cache_home is not None:
return os.path.join(xdg_cache_home, "imaginairy")
return os.path.join(os.path.dirname(__file__), ".cached-aimg")
class HuggingFaceAuthorizationError(RuntimeError):
pass
hf_repo_url_pattern = re.compile(
r"https://huggingface\.co/(?P<author>[^/]+)/(?P<repo>[^/]+)(/tree/(?P<ref>[a-z0-9]+))?/?$"
)
def parse_diffusers_repo_url(url: str) -> dict[str, str]:
match = hf_repo_url_pattern.match(url)
return match.groupdict() if match else {}
def is_diffusers_repo_url(url: str) -> bool:
result = bool(parse_diffusers_repo_url(url))
logger.debug(f"{url} is diffusers repo url: {result}")
return result
def normalize_diffusers_repo_url(url: str) -> str:
data = parse_diffusers_repo_url(url)
ref = data["ref"] or "main"
normalized_url = (
f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/"
)
return normalized_url

View File

@ -2,19 +2,10 @@
import logging
import os
import re
import sys
import urllib.parse
from functools import lru_cache, wraps
from functools import lru_cache
import requests
import torch
from huggingface_hub import (
HfFileSystem,
HfFolder,
hf_hub_download as _hf_hub_download,
try_to_load_from_cache,
)
from omegaconf import OmegaConf
from safetensors.torch import load_file
@ -27,6 +18,13 @@ from imaginairy.modules.refiners_sd import (
StableDiffusion_XL_Inpainting,
)
from imaginairy.utils import clear_gpu_cache, get_device, instantiate_from_config
from imaginairy.utils.downloads import (
HuggingFaceAuthorizationError,
download_huggingface_weights,
get_cached_url_path,
is_diffusers_repo_url,
normalize_diffusers_repo_url,
)
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.named_resolutions import normalize_image_size
from imaginairy.utils.paths import PKG_ROOT
@ -54,10 +52,6 @@ logger = logging.getLogger(__name__)
MOST_RECENTLY_LOADED_MODEL = None
class HuggingFaceAuthorizationError(RuntimeError):
pass
def load_state_dict(weights_location, half_mode=False, device=None):
if device is None:
device = get_device()
@ -118,13 +112,6 @@ def load_model_from_config(config, weights_location, half_mode=False):
return model
def add_controlnet(base_state_dict, controlnet_state_dict):
"""Merges a base sd15 model with a controlnet model."""
for key in controlnet_state_dict:
base_state_dict[key] = controlnet_state_dict[key]
return base_state_dict
def get_diffusion_model(
weights_location=iconfig.DEFAULT_MODEL_WEIGHTS,
config_path="configs/stable-diffusion-v1.yaml",
@ -222,31 +209,6 @@ def get_diffusion_model_refiners(
return sd
hf_repo_url_pattern = re.compile(
r"https://huggingface\.co/(?P<author>[^/]+)/(?P<repo>[^/]+)(/tree/(?P<ref>[a-z0-9]+))?/?$"
)
def parse_diffusers_repo_url(url: str) -> dict[str, str]:
match = hf_repo_url_pattern.match(url)
return match.groupdict() if match else {}
def is_diffusers_repo_url(url: str) -> bool:
result = bool(parse_diffusers_repo_url(url))
logger.debug(f"{url} is diffusers repo url: {result}")
return result
def normalize_diffusers_repo_url(url: str) -> str:
data = parse_diffusers_repo_url(url)
ref = data["ref"] or "main"
normalized_url = (
f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/"
)
return normalized_url
@lru_cache(maxsize=1)
def _get_diffusion_model_refiners(
weights_location: str,
@ -529,161 +491,6 @@ def get_current_diffusion_model():
return MOST_RECENTLY_LOADED_MODEL
def get_cache_dir():
xdg_cache_home = os.getenv("XDG_CACHE_HOME", None)
if xdg_cache_home is None:
user_home = os.getenv("HOME", None)
if user_home:
xdg_cache_home = os.path.join(user_home, ".cache")
if xdg_cache_home is not None:
return os.path.join(xdg_cache_home, "imaginairy")
return os.path.join(os.path.dirname(__file__), ".cached-aimg")
def get_cached_url_path(url, category=None):
"""
Gets the contents of a url, but caches the response indefinitely.
While we attempt to use the cached_path from huggingface transformers, we fall back
to our own implementation if the url does not provide an etag header, which `cached_path`
requires. We also skip the `head` call that `cached_path` makes on every call if the file
is already cached.
"""
try:
if url.startswith("https://huggingface.co"):
return huggingface_cached_path(url)
except (OSError, ValueError):
pass
filename = url.split("/")[-1]
dest = get_cache_dir()
if category:
dest = os.path.join(dest, category)
os.makedirs(dest, exist_ok=True)
# Replace possibly illegal destination path characters
safe_filename = re.sub('[*<>:"|?]', "_", filename)
dest_path = os.path.join(dest, safe_filename)
if os.path.exists(dest_path):
return dest_path
# check if it's saved at previous path and rename it
old_dest_path = os.path.join(dest, filename)
if os.path.exists(old_dest_path):
os.rename(old_dest_path, dest_path)
return dest_path
r = requests.get(url)
with open(dest_path, "wb") as f:
f.write(r.content)
return dest_path
def check_huggingface_url_authorized(url):
if not url.startswith("https://huggingface.co/"):
return None
token = HfFolder.get_token()
headers = {}
if token is not None:
headers["authorization"] = f"Bearer {token}"
response = requests.head(url, allow_redirects=True, headers=headers, timeout=5)
if response.status_code == 401:
msg = "Unauthorized access to HuggingFace model. This model requires a huggingface token. Please login to HuggingFace or set HUGGING_FACE_HUB_TOKEN to your User Access Token. See https://huggingface.co/docs/huggingface_hub/quick-start#login for more information"
raise HuggingFaceAuthorizationError(msg)
return None
@wraps(_hf_hub_download)
def hf_hub_download(*args, **kwargs):
"""
backwards compatible wrapper for huggingface's hf_hub_download.
they changed the argument name from `use_auth_token` to `token`
"""
try:
return _hf_hub_download(*args, **kwargs)
except TypeError as e:
if "unexpected keyword argument 'token'" in str(e):
kwargs["use_auth_token"] = kwargs.pop("token")
return _hf_hub_download(*args, **kwargs)
raise
def huggingface_cached_path(url):
# bypass all the HEAD calls done by the default `cached_path`
repo, commit_hash, filepath = extract_huggingface_repo_commit_file_from_url(url)
dest_path = try_to_load_from_cache(
repo_id=repo, revision=commit_hash, filename=filepath
)
if not dest_path:
check_huggingface_url_authorized(url)
token = HfFolder.get_token()
logger.info(f"Downloading {url} from huggingface")
dest_path = hf_hub_download(
repo_id=repo, revision=commit_hash, filename=filepath, token=token
)
# make a refs folder so caching works
# work-around for
# https://github.com/huggingface/huggingface_hub/pull/1306
# https://github.com/brycedrennan/imaginAIry/issues/171
refs_url = dest_path[: dest_path.index("/snapshots/")] + "/refs/"
os.makedirs(refs_url, exist_ok=True)
return dest_path
def extract_huggingface_repo_commit_file_from_url(url):
parsed_url = urllib.parse.urlparse(url)
path_components = parsed_url.path.strip("/").split("/")
repo = "/".join(path_components[0:2])
assert path_components[2] == "resolve"
commit_hash = path_components[3]
filepath = "/".join(path_components[4:])
return repo, commit_hash, filepath
def download_diffusers_weights(base_url, sub, filename=None, prefer_fp16=True):
if filename is None:
# select which weights to download. prefer fp16 safetensors
data = parse_diffusers_repo_url(base_url)
fs = HfFileSystem()
filepaths = fs.ls(
f"{data['author']}/{data['repo']}/{sub}", revision=data["ref"], detail=False
)
filepath = choose_diffusers_weights(filepaths, prefer_fp16=prefer_fp16)
if not filepath:
msg = f"Could not find any weights in {base_url}/{sub}"
raise ValueError(msg)
filename = filepath.split("/")[-1]
url = f"{base_url}{sub}/{filename}".replace("/tree/", "/resolve/")
new_path = get_cached_url_path(url, category="weights")
return new_path
def choose_diffusers_weights(filenames, prefer_fp16=True):
extension_priority = (".safetensors", ".bin", ".pth", ".pt")
# filter out any files that don't have a valid extension
filenames = [f for f in filenames if any(f.endswith(e) for e in extension_priority)]
filenames_and_extension = [(f, os.path.splitext(f)[1]) for f in filenames]
# sort by priority
if prefer_fp16:
filenames_and_extension.sort(
key=lambda x: ("fp16" not in x[0], extension_priority.index(x[1]))
)
else:
filenames_and_extension.sort(
key=lambda x: ("fp16" in x[0], extension_priority.index(x[1]))
)
if filenames_and_extension:
return filenames_and_extension[0][0]
return None
def load_sd15_diffusers_weights(base_url: str, device=None):
from imaginairy.utils import get_device
from imaginairy.weight_management.conversion import cast_weights
@ -696,7 +503,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
base_url = normalize_diffusers_repo_url(base_url)
if device is None:
device = get_device()
vae_weights_path = download_diffusers_weights(base_url=base_url, sub="vae")
vae_weights_path = download_huggingface_weights(base_url=base_url, sub="vae")
vae_weights = open_weights(vae_weights_path, device=device)
vae_weights = cast_weights(
source_weights=vae_weights,
@ -706,7 +513,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
dest_format=FORMAT_NAMES.REFINERS,
)
unet_weights_path = download_diffusers_weights(base_url=base_url, sub="unet")
unet_weights_path = download_huggingface_weights(base_url=base_url, sub="unet")
unet_weights = open_weights(unet_weights_path, device=device)
unet_weights = cast_weights(
source_weights=unet_weights,
@ -716,7 +523,7 @@ def load_sd15_diffusers_weights(base_url: str, device=None):
dest_format=FORMAT_NAMES.REFINERS,
)
text_encoder_weights_path = download_diffusers_weights(
text_encoder_weights_path = download_huggingface_weights(
base_url=base_url, sub="text_encoder"
)
text_encoder_weights = open_weights(text_encoder_weights_path, device=device)
@ -749,7 +556,7 @@ def load_sdxl_pipeline_from_diffusers_weights(
base_url = normalize_diffusers_repo_url(base_url)
translator = translators.diffusers_autoencoder_kl_to_refiners_translator()
vae_weights_path = download_diffusers_weights(
vae_weights_path = download_huggingface_weights(
base_url=base_url, sub="vae", prefer_fp16=False
)
logger.debug(f"vae: {vae_weights_path}")
@ -762,7 +569,7 @@ def load_sdxl_pipeline_from_diffusers_weights(
del vae_weights
translator = translators.diffusers_unet_sdxl_to_refiners_translator()
unet_weights_path = download_diffusers_weights(
unet_weights_path = download_huggingface_weights(
base_url=base_url, sub="unet", prefer_fp16=True
)
logger.debug(f"unet: {unet_weights_path}")
@ -777,10 +584,10 @@ def load_sdxl_pipeline_from_diffusers_weights(
unet.load_state_dict(unet_weights, assign=True)
del unet_weights
text_encoder_1_path = download_diffusers_weights(
text_encoder_1_path = download_huggingface_weights(
base_url=base_url, sub="text_encoder"
)
text_encoder_2_path = download_diffusers_weights(
text_encoder_2_path = download_huggingface_weights(
base_url=base_url, sub="text_encoder_2"
)
logger.debug(f"text encoder 1: {text_encoder_1_path}")

View File

@ -9,7 +9,7 @@ import numpy as np
import torch
from torch.nn import functional as F
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

View File

@ -2,8 +2,8 @@
import safetensors
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.model_manager import (
get_cached_url_path,
open_weights,
resolve_model_weights_config,
)

View File

@ -3,7 +3,7 @@ import os
import torch
from safetensors.torch import load_file, save_file
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.paths import PKG_ROOT
sd15_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt"

View File

@ -1,8 +1,8 @@
import pytest
from imaginairy import config
from imaginairy.utils.downloads import parse_diffusers_repo_url
from imaginairy.utils.model_manager import (
parse_diffusers_repo_url,
resolve_model_weights_config,
)