fix: allow referencing local paths for sdxl model weights

Addresses https://github.com/brycedrennan/imaginAIry/issues/484
pull/464/head
Bryce 4 weeks ago committed by Bryce Drennan
parent 3c1c695f76
commit 3a9a3974ce

@ -15,6 +15,17 @@ from huggingface_hub import (
logger = logging.getLogger(__name__)
def resolve_path_or_url(path_or_url: str, category=None) -> str:
"""
Resolves a path or url to a local absolute file path
If the path_or_url is a url, it will be downloaded to the cache directory and the path to the downloaded file will be returned.
"""
if path_or_url.startswith(("https://", "http://")):
return get_cached_url_path(url=path_or_url, category=category)
return os.path.abspath(path_or_url)
def get_cached_url_path(url: str, category=None) -> str:
"""
Gets the contents of a url, but caches the response indefinitely.

@ -24,6 +24,7 @@ from imaginairy.utils.downloads import (
get_cached_url_path,
is_diffusers_repo_url,
normalize_diffusers_repo_url,
resolve_path_or_url,
)
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.named_resolutions import normalize_image_size
@ -793,7 +794,7 @@ def load_stable_diffusion_compvis_weights(weights_url):
def load_sdxl_compvis_weights(url):
from safetensors import safe_open
weights_path = get_cached_url_path(url)
weights_path = resolve_path_or_url(url)
state_dict = {}
unet_state_dict = {}
vae_state_dict = {}

@ -1,7 +1,7 @@
import pytest
from imaginairy import config
from imaginairy.utils.downloads import parse_diffusers_repo_url
from imaginairy.utils.downloads import parse_diffusers_repo_url, resolve_path_or_url
from imaginairy.utils.model_manager import (
resolve_model_weights_config,
)
@ -59,3 +59,8 @@ hf_urls_cases = [
def test_parse_diffusers_repo_url(url, expected):
result = parse_diffusers_repo_url(url)
assert result == expected
def test_resolve_sdxl_path_or_url():
a = "/foo/bar.safetensors"
assert resolve_path_or_url(a) == a

Loading…
Cancel
Save