Compare commits

...

2 Commits

@ -94,6 +94,8 @@ def generate_video(
output_fps = default(output_fps, fps_id)
model_name = model_name.lower().replace("_", "-")
video_model_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(model_name, None)
if video_model_config is None:
msg = f"Version {model_name} does not exist."

@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
@click.option(
"--model",
default="svd",
help="Model to use. One of: svd, svd_xt, svd_image_decoder, svd_xt_image_decoder",
help="Model to use. One of: svd, svd-xt, svd-image-decoder, svd-xt-image-decoder",
)
@click.option(
"--fps", default=6, type=int, help="FPS for the AI to target when generating video"

@ -15,6 +15,17 @@ from huggingface_hub import (
logger = logging.getLogger(__name__)
def resolve_path_or_url(path_or_url: str, category=None) -> str:
"""
Resolves a path or url to a local absolute file path
If the path_or_url is a url, it will be downloaded to the cache directory and the path to the downloaded file will be returned.
"""
if path_or_url.startswith(("https://", "http://")):
return get_cached_url_path(url=path_or_url, category=category)
return os.path.abspath(path_or_url)
def get_cached_url_path(url: str, category=None) -> str:
"""
Gets the contents of a url, but caches the response indefinitely.

@ -24,6 +24,7 @@ from imaginairy.utils.downloads import (
get_cached_url_path,
is_diffusers_repo_url,
normalize_diffusers_repo_url,
resolve_path_or_url,
)
from imaginairy.utils.model_cache import memory_managed_model
from imaginairy.utils.named_resolutions import normalize_image_size
@ -793,7 +794,7 @@ def load_stable_diffusion_compvis_weights(weights_url):
def load_sdxl_compvis_weights(url):
from safetensors import safe_open
weights_path = get_cached_url_path(url)
weights_path = resolve_path_or_url(url)
state_dict = {}
unet_state_dict = {}
vae_state_dict = {}

@ -1,7 +1,7 @@
import pytest
from imaginairy import config
from imaginairy.utils.downloads import parse_diffusers_repo_url
from imaginairy.utils.downloads import parse_diffusers_repo_url, resolve_path_or_url
from imaginairy.utils.model_manager import (
resolve_model_weights_config,
)
@ -59,3 +59,8 @@ hf_urls_cases = [
def test_parse_diffusers_repo_url(url, expected):
result = parse_diffusers_repo_url(url)
assert result == expected
def test_resolve_sdxl_path_or_url():
a = "/foo/bar.safetensors"
assert resolve_path_or_url(a) == a

Loading…
Cancel
Save