diff --git a/imaginairy/utils/downloads.py b/imaginairy/utils/downloads.py index 927a007..3469c5c 100644 --- a/imaginairy/utils/downloads.py +++ b/imaginairy/utils/downloads.py @@ -15,6 +15,17 @@ from huggingface_hub import ( logger = logging.getLogger(__name__) +def resolve_path_or_url(path_or_url: str, category=None) -> str: + """ + Resolves a path or url to a local absolute file path + + If the path_or_url is a url, it will be downloaded to the cache directory and the path to the downloaded file will be returned. + """ + if path_or_url.startswith(("https://", "http://")): + return get_cached_url_path(url=path_or_url, category=category) + return os.path.abspath(path_or_url) + + def get_cached_url_path(url: str, category=None) -> str: """ Gets the contents of a url, but caches the response indefinitely. diff --git a/imaginairy/utils/model_manager.py b/imaginairy/utils/model_manager.py index 0b14fb4..dea9bd5 100644 --- a/imaginairy/utils/model_manager.py +++ b/imaginairy/utils/model_manager.py @@ -24,6 +24,7 @@ from imaginairy.utils.downloads import ( get_cached_url_path, is_diffusers_repo_url, normalize_diffusers_repo_url, + resolve_path_or_url, ) from imaginairy.utils.model_cache import memory_managed_model from imaginairy.utils.named_resolutions import normalize_image_size @@ -793,7 +794,7 @@ def load_stable_diffusion_compvis_weights(weights_url): def load_sdxl_compvis_weights(url): from safetensors import safe_open - weights_path = get_cached_url_path(url) + weights_path = resolve_path_or_url(url) state_dict = {} unet_state_dict = {} vae_state_dict = {} diff --git a/tests/test_utils/test_model_manager.py b/tests/test_utils/test_model_manager.py index 6052917..d0e64fb 100644 --- a/tests/test_utils/test_model_manager.py +++ b/tests/test_utils/test_model_manager.py @@ -1,7 +1,7 @@ import pytest from imaginairy import config -from imaginairy.utils.downloads import parse_diffusers_repo_url +from imaginairy.utils.downloads import parse_diffusers_repo_url, resolve_path_or_url from imaginairy.utils.model_manager import ( resolve_model_weights_config, ) @@ -59,3 +59,8 @@ hf_urls_cases = [ def test_parse_diffusers_repo_url(url, expected): result = parse_diffusers_repo_url(url) assert result == expected + + +def test_resolve_sdxl_path_or_url(): + a = "/foo/bar.safetensors" + assert resolve_path_or_url(a) == a