feature: support loading diffusers folder-based models from huggingface (#427)

Bryce Drennan 2023-12-21 14:24:35 -08:00 committed by GitHub
parent 50e796a3b7
commit a2c38b3ec0
6 changed files with 356 additions and 45 deletions


@@ -150,6 +150,20 @@ MODEL_WEIGHT_CONFIGS = [
weights_location="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors",
defaults={"negative_prompt": "poor quality"},
),
ModelWeightsConfig(
name="Modern Disney",
aliases=["modern-disney", "modi", "modi15", "modern-disney-15"],
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
weights_location="https://huggingface.co/nitrosocke/mo-di-diffusion/tree/e3106d24aa8c37bf856257daea2ae789eabc4d70/",
defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
),
ModelWeightsConfig(
name="Modern Disney",
aliases=["redshift-diffusion", "red", "redshift-diffusion-15", "red15"],
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
weights_location="https://huggingface.co/nitrosocke/redshift-diffusion/tree/80837fe18df05807861ab91c3bad3693c9342e4c/",
defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
),
# Video Weights
ModelWeightsConfig(
name="Stable Video Diffusion",

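The two new entries above point weights_location at a Hugging Face repo tree (a diffusers-style folder) rather than at a single .safetensors file. A minimal sketch of how one of the new aliases resolves, assuming resolve_model_weights_config accepts the alias directly, the same way the tests later in this commit exercise other aliases:

from imaginairy.utils.model_manager import resolve_model_weights_config

# Resolve one of the newly added aliases to its ModelWeightsConfig.
weights_config = resolve_model_weights_config("modern-disney")

# Expected: the folder-style repo URL configured above, i.e.
# https://huggingface.co/nitrosocke/mo-di-diffusion/tree/e3106d24aa8c37bf856257daea2ae789eabc4d70/
print(weights_config.weights_location)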

@@ -10,6 +10,7 @@ from functools import lru_cache, wraps
import requests
import torch
from huggingface_hub import (
HfFileSystem,
HfFolder,
hf_hub_download as _hf_hub_download,
try_to_load_from_cache,
@@ -239,6 +240,29 @@ def get_diffusion_model_refiners(
)
hf_repo_url_pattern = re.compile(
r"https://huggingface\.co/(?P<author>[^/]+)/(?P<repo>[^/]+)(/tree/(?P<ref>[a-z0-9]+))?/?$"
)
def parse_diffusers_repo_url(url: str) -> dict[str, str]:
match = hf_repo_url_pattern.match(url)
return match.groupdict() if match else {}
def is_diffusers_repo_url(url: str) -> bool:
return bool(parse_diffusers_repo_url(url))
def normalize_diffusers_repo_url(url: str) -> str:
data = parse_diffusers_repo_url(url)
ref = data["ref"] or "main"
normalized_url = (
f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/"
)
return normalized_url
@lru_cache(maxsize=1)
def _get_diffusion_model_refiners(
weights_location: str,
@@ -260,12 +284,18 @@ def _get_diffusion_model_refiners(
global MOST_RECENTLY_LOADED_MODEL
device = device or get_device()
(
vae_weights,
unet_weights,
text_encoder_weights,
) = load_stable_diffusion_compvis_weights(weights_location)
if is_diffusers_repo_url(weights_location):
(
vae_weights,
unet_weights,
text_encoder_weights,
) = load_stable_diffusion_diffusers_weights(weights_location)
else:
(
vae_weights,
unet_weights,
text_encoder_weights,
) = load_stable_diffusion_compvis_weights(weights_location)
StableDiffusionCls: type[LatentDiffusionModel]
if for_inpainting:
@@ -529,15 +559,40 @@ def extract_huggingface_repo_commit_file_from_url(url):
return repo, commit_hash, filepath
def download_diffusers_weights(repo, sub, filename):
from imaginairy.utils.model_manager import get_cached_url_path
def download_diffusers_weights(base_url, sub, filename=None):
if filename is None:
# select which weights to download. prefer fp16 safetensors
data = parse_diffusers_repo_url(base_url)
fs = HfFileSystem()
filepaths = fs.ls(
f"{data['author']}/{data['repo']}/{sub}", revision=data["ref"], detail=False
)
filepath = choose_diffusers_weights(filepaths)
if not filepath:
msg = f"Could not find any weights in {base_url}/{sub}"
raise ValueError(msg)
filename = filepath.split("/")[-1]
url = f"{base_url}{sub}/{filename}".replace("/tree/", "/resolve/")
new_path = get_cached_url_path(url, category="weights")
return new_path
url = f"https://huggingface.co/{repo}/resolve/main/{sub}/{filename}"
return get_cached_url_path(url, category="weights")
def choose_diffusers_weights(filenames):
extension_priority = (".safetensors", ".bin", ".pth", ".pt")
# filter out any files that don't have a valid extension
filenames = [f for f in filenames if any(f.endswith(e) for e in extension_priority)]
filenames_and_extension = [(f, os.path.splitext(f)[1]) for f in filenames]
# sort by priority
filenames_and_extension.sort(
key=lambda x: ("fp16" not in x[0], extension_priority.index(x[1]))
)
if filenames_and_extension:
return filenames_and_extension[0][0]
return None
@lru_cache
def load_stable_diffusion_diffusers_weights(diffusers_repo, device=None):
def load_stable_diffusion_diffusers_weights(base_url: str, device=None):
from imaginairy.utils import get_device
from imaginairy.weight_management.conversion import cast_weights
from imaginairy.weight_management.utils import (
@@ -546,11 +601,10 @@ def load_stable_diffusion_diffusers_weights(diffusers_repo, device=None):
MODEL_NAMES,
)
base_url = normalize_diffusers_repo_url(base_url)
if device is None:
device = get_device()
vae_weights_path = download_diffusers_weights(
repo=diffusers_repo, sub="vae", filename="diffusion_pytorch_model.safetensors"
)
vae_weights_path = download_diffusers_weights(base_url=base_url, sub="vae")
vae_weights = open_weights(vae_weights_path, device=device)
vae_weights = cast_weights(
source_weights=vae_weights,
@@ -560,9 +614,7 @@ def load_stable_diffusion_diffusers_weights(diffusers_repo, device=None):
dest_format=FORMAT_NAMES.REFINERS,
)
unet_weights_path = download_diffusers_weights(
repo=diffusers_repo, sub="unet", filename="diffusion_pytorch_model.safetensors"
)
unet_weights_path = download_diffusers_weights(base_url=base_url, sub="unet")
unet_weights = open_weights(unet_weights_path, device=device)
unet_weights = cast_weights(
source_weights=unet_weights,
@@ -573,7 +625,7 @@
)
text_encoder_weights_path = download_diffusers_weights(
repo=diffusers_repo, sub="text_encoder", filename="model.safetensors"
base_url=base_url, sub="text_encoder"
)
text_encoder_weights = open_weights(text_encoder_weights_path, device=device)
text_encoder_weights = cast_weights(
@@ -613,7 +665,6 @@ def open_weights(filepath, device=None):
def load_stable_diffusion_compvis_weights(weights_url):
from imaginairy.utils import get_device
from imaginairy.utils.model_manager import get_cached_url_path
from imaginairy.weight_management.conversion import cast_weights
from imaginairy.weight_management.utils import (
COMPONENT_NAMES,

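Taken together, the helpers added above do three things: parse_diffusers_repo_url / is_diffusers_repo_url recognize repo-tree URLs, normalize_diffusers_repo_url pins them to an explicit ref (defaulting to main), and choose_diffusers_weights picks which file to fetch from a component folder, preferring fp16 safetensors. A small illustration, assuming all of these helpers are importable from imaginairy.utils.model_manager, where parse_diffusers_repo_url lives per the tests below:

from imaginairy.utils.model_manager import (
    choose_diffusers_weights,
    is_diffusers_repo_url,
    normalize_diffusers_repo_url,
    parse_diffusers_repo_url,
)

url = "https://huggingface.co/nitrosocke/mo-di-diffusion/tree/e3106d24aa8c37bf856257daea2ae789eabc4d70/"
assert is_diffusers_repo_url(url)
print(parse_diffusers_repo_url(url))
# {'author': 'nitrosocke', 'repo': 'mo-di-diffusion', 'ref': 'e3106d24aa8c37bf856257daea2ae789eabc4d70'}

# URLs without an explicit ref get pinned to "main"
print(normalize_diffusers_repo_url("https://huggingface.co/prompthero/zoom-v3"))
# https://huggingface.co/prompthero/zoom-v3/tree/main/

# fp16 safetensors are preferred over full-precision safetensors and .bin files
# (hypothetical listing of a unet component folder)
candidates = [
    "nitrosocke/mo-di-diffusion/unet/diffusion_pytorch_model.bin",
    "nitrosocke/mo-di-diffusion/unet/diffusion_pytorch_model.fp16.safetensors",
    "nitrosocke/mo-di-diffusion/unet/diffusion_pytorch_model.safetensors",
]
print(choose_diffusers_weights(candidates))
# nitrosocke/mo-di-diffusion/unet/diffusion_pytorch_model.fp16.safetensors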

@@ -0,0 +1,200 @@
{
"ignorable_prefixes": [
"text_model.embeddings"
],
"mapping": {
"text_model.embeddings.token_embedding": "Sum.TokenEncoder",
"text_model.embeddings.position_embedding": "Sum.PositionalEncoder.Embedding",
"text_model.encoder.layers.0.layer_norm1": "TransformerLayer_1.Residual_1.LayerNorm",
"text_model.encoder.layers.0.layer_norm2": "TransformerLayer_1.Residual_2.LayerNorm",
"text_model.encoder.layers.1.layer_norm1": "TransformerLayer_2.Residual_1.LayerNorm",
"text_model.encoder.layers.1.layer_norm2": "TransformerLayer_2.Residual_2.LayerNorm",
"text_model.encoder.layers.2.layer_norm1": "TransformerLayer_3.Residual_1.LayerNorm",
"text_model.encoder.layers.2.layer_norm2": "TransformerLayer_3.Residual_2.LayerNorm",
"text_model.encoder.layers.3.layer_norm1": "TransformerLayer_4.Residual_1.LayerNorm",
"text_model.encoder.layers.3.layer_norm2": "TransformerLayer_4.Residual_2.LayerNorm",
"text_model.encoder.layers.4.layer_norm1": "TransformerLayer_5.Residual_1.LayerNorm",
"text_model.encoder.layers.4.layer_norm2": "TransformerLayer_5.Residual_2.LayerNorm",
"text_model.encoder.layers.5.layer_norm1": "TransformerLayer_6.Residual_1.LayerNorm",
"text_model.encoder.layers.5.layer_norm2": "TransformerLayer_6.Residual_2.LayerNorm",
"text_model.encoder.layers.6.layer_norm1": "TransformerLayer_7.Residual_1.LayerNorm",
"text_model.encoder.layers.6.layer_norm2": "TransformerLayer_7.Residual_2.LayerNorm",
"text_model.encoder.layers.7.layer_norm1": "TransformerLayer_8.Residual_1.LayerNorm",
"text_model.encoder.layers.7.layer_norm2": "TransformerLayer_8.Residual_2.LayerNorm",
"text_model.encoder.layers.8.layer_norm1": "TransformerLayer_9.Residual_1.LayerNorm",
"text_model.encoder.layers.8.layer_norm2": "TransformerLayer_9.Residual_2.LayerNorm",
"text_model.encoder.layers.9.layer_norm1": "TransformerLayer_10.Residual_1.LayerNorm",
"text_model.encoder.layers.9.layer_norm2": "TransformerLayer_10.Residual_2.LayerNorm",
"text_model.encoder.layers.10.layer_norm1": "TransformerLayer_11.Residual_1.LayerNorm",
"text_model.encoder.layers.10.layer_norm2": "TransformerLayer_11.Residual_2.LayerNorm",
"text_model.encoder.layers.11.layer_norm1": "TransformerLayer_12.Residual_1.LayerNorm",
"text_model.encoder.layers.11.layer_norm2": "TransformerLayer_12.Residual_2.LayerNorm",
"text_model.encoder.layers.12.layer_norm1": "TransformerLayer_13.Residual_1.LayerNorm",
"text_model.encoder.layers.12.layer_norm2": "TransformerLayer_13.Residual_2.LayerNorm",
"text_model.encoder.layers.13.layer_norm1": "TransformerLayer_14.Residual_1.LayerNorm",
"text_model.encoder.layers.13.layer_norm2": "TransformerLayer_14.Residual_2.LayerNorm",
"text_model.encoder.layers.14.layer_norm1": "TransformerLayer_15.Residual_1.LayerNorm",
"text_model.encoder.layers.14.layer_norm2": "TransformerLayer_15.Residual_2.LayerNorm",
"text_model.encoder.layers.15.layer_norm1": "TransformerLayer_16.Residual_1.LayerNorm",
"text_model.encoder.layers.15.layer_norm2": "TransformerLayer_16.Residual_2.LayerNorm",
"text_model.encoder.layers.16.layer_norm1": "TransformerLayer_17.Residual_1.LayerNorm",
"text_model.encoder.layers.16.layer_norm2": "TransformerLayer_17.Residual_2.LayerNorm",
"text_model.encoder.layers.17.layer_norm1": "TransformerLayer_18.Residual_1.LayerNorm",
"text_model.encoder.layers.17.layer_norm2": "TransformerLayer_18.Residual_2.LayerNorm",
"text_model.encoder.layers.18.layer_norm1": "TransformerLayer_19.Residual_1.LayerNorm",
"text_model.encoder.layers.18.layer_norm2": "TransformerLayer_19.Residual_2.LayerNorm",
"text_model.encoder.layers.19.layer_norm1": "TransformerLayer_20.Residual_1.LayerNorm",
"text_model.encoder.layers.19.layer_norm2": "TransformerLayer_20.Residual_2.LayerNorm",
"text_model.encoder.layers.20.layer_norm1": "TransformerLayer_21.Residual_1.LayerNorm",
"text_model.encoder.layers.20.layer_norm2": "TransformerLayer_21.Residual_2.LayerNorm",
"text_model.encoder.layers.21.layer_norm1": "TransformerLayer_22.Residual_1.LayerNorm",
"text_model.encoder.layers.21.layer_norm2": "TransformerLayer_22.Residual_2.LayerNorm",
"text_model.encoder.layers.22.layer_norm1": "TransformerLayer_23.Residual_1.LayerNorm",
"text_model.encoder.layers.22.layer_norm2": "TransformerLayer_23.Residual_2.LayerNorm",
"text_model.final_layer_norm": "LayerNorm",
"text_model.encoder.layers.0.self_attn.q_proj": "TransformerLayer_1.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.0.self_attn.k_proj": "TransformerLayer_1.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.0.self_attn.v_proj": "TransformerLayer_1.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.0.self_attn.out_proj": "TransformerLayer_1.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.1.self_attn.q_proj": "TransformerLayer_2.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.1.self_attn.k_proj": "TransformerLayer_2.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.1.self_attn.v_proj": "TransformerLayer_2.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.1.self_attn.out_proj": "TransformerLayer_2.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.2.self_attn.q_proj": "TransformerLayer_3.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.2.self_attn.k_proj": "TransformerLayer_3.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.2.self_attn.v_proj": "TransformerLayer_3.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.2.self_attn.out_proj": "TransformerLayer_3.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.3.self_attn.q_proj": "TransformerLayer_4.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.3.self_attn.k_proj": "TransformerLayer_4.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.3.self_attn.v_proj": "TransformerLayer_4.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.3.self_attn.out_proj": "TransformerLayer_4.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.4.self_attn.q_proj": "TransformerLayer_5.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.4.self_attn.k_proj": "TransformerLayer_5.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.4.self_attn.v_proj": "TransformerLayer_5.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.4.self_attn.out_proj": "TransformerLayer_5.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.5.self_attn.q_proj": "TransformerLayer_6.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.5.self_attn.k_proj": "TransformerLayer_6.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.5.self_attn.v_proj": "TransformerLayer_6.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.5.self_attn.out_proj": "TransformerLayer_6.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.6.self_attn.q_proj": "TransformerLayer_7.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.6.self_attn.k_proj": "TransformerLayer_7.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.6.self_attn.v_proj": "TransformerLayer_7.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.6.self_attn.out_proj": "TransformerLayer_7.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.7.self_attn.q_proj": "TransformerLayer_8.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.7.self_attn.k_proj": "TransformerLayer_8.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.7.self_attn.v_proj": "TransformerLayer_8.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.7.self_attn.out_proj": "TransformerLayer_8.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.8.self_attn.q_proj": "TransformerLayer_9.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.8.self_attn.k_proj": "TransformerLayer_9.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.8.self_attn.v_proj": "TransformerLayer_9.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.8.self_attn.out_proj": "TransformerLayer_9.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.9.self_attn.q_proj": "TransformerLayer_10.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.9.self_attn.k_proj": "TransformerLayer_10.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.9.self_attn.v_proj": "TransformerLayer_10.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.9.self_attn.out_proj": "TransformerLayer_10.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.10.self_attn.q_proj": "TransformerLayer_11.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.10.self_attn.k_proj": "TransformerLayer_11.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.10.self_attn.v_proj": "TransformerLayer_11.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.10.self_attn.out_proj": "TransformerLayer_11.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.11.self_attn.q_proj": "TransformerLayer_12.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.11.self_attn.k_proj": "TransformerLayer_12.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.11.self_attn.v_proj": "TransformerLayer_12.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.11.self_attn.out_proj": "TransformerLayer_12.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.12.self_attn.q_proj": "TransformerLayer_13.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.12.self_attn.k_proj": "TransformerLayer_13.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.12.self_attn.v_proj": "TransformerLayer_13.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.12.self_attn.out_proj": "TransformerLayer_13.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.13.self_attn.q_proj": "TransformerLayer_14.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.13.self_attn.k_proj": "TransformerLayer_14.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.13.self_attn.v_proj": "TransformerLayer_14.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.13.self_attn.out_proj": "TransformerLayer_14.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.14.self_attn.q_proj": "TransformerLayer_15.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.14.self_attn.k_proj": "TransformerLayer_15.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.14.self_attn.v_proj": "TransformerLayer_15.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.14.self_attn.out_proj": "TransformerLayer_15.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.15.self_attn.q_proj": "TransformerLayer_16.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.15.self_attn.k_proj": "TransformerLayer_16.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.15.self_attn.v_proj": "TransformerLayer_16.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.15.self_attn.out_proj": "TransformerLayer_16.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.16.self_attn.q_proj": "TransformerLayer_17.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.16.self_attn.k_proj": "TransformerLayer_17.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.16.self_attn.v_proj": "TransformerLayer_17.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.16.self_attn.out_proj": "TransformerLayer_17.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.17.self_attn.q_proj": "TransformerLayer_18.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.17.self_attn.k_proj": "TransformerLayer_18.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.17.self_attn.v_proj": "TransformerLayer_18.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.17.self_attn.out_proj": "TransformerLayer_18.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.18.self_attn.q_proj": "TransformerLayer_19.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.18.self_attn.k_proj": "TransformerLayer_19.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.18.self_attn.v_proj": "TransformerLayer_19.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.18.self_attn.out_proj": "TransformerLayer_19.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.19.self_attn.q_proj": "TransformerLayer_20.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.19.self_attn.k_proj": "TransformerLayer_20.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.19.self_attn.v_proj": "TransformerLayer_20.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.19.self_attn.out_proj": "TransformerLayer_20.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.20.self_attn.q_proj": "TransformerLayer_21.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.20.self_attn.k_proj": "TransformerLayer_21.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.20.self_attn.v_proj": "TransformerLayer_21.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.20.self_attn.out_proj": "TransformerLayer_21.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.21.self_attn.q_proj": "TransformerLayer_22.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.21.self_attn.k_proj": "TransformerLayer_22.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.21.self_attn.v_proj": "TransformerLayer_22.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.21.self_attn.out_proj": "TransformerLayer_22.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.22.self_attn.q_proj": "TransformerLayer_23.Residual_1.SelfAttention.Distribute.Linear_1",
"text_model.encoder.layers.22.self_attn.k_proj": "TransformerLayer_23.Residual_1.SelfAttention.Distribute.Linear_2",
"text_model.encoder.layers.22.self_attn.v_proj": "TransformerLayer_23.Residual_1.SelfAttention.Distribute.Linear_3",
"text_model.encoder.layers.22.self_attn.out_proj": "TransformerLayer_23.Residual_1.SelfAttention.Linear",
"text_model.encoder.layers.0.mlp.fc1": "TransformerLayer_1.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.1.mlp.fc1": "TransformerLayer_2.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.2.mlp.fc1": "TransformerLayer_3.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.3.mlp.fc1": "TransformerLayer_4.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.4.mlp.fc1": "TransformerLayer_5.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.5.mlp.fc1": "TransformerLayer_6.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.6.mlp.fc1": "TransformerLayer_7.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.7.mlp.fc1": "TransformerLayer_8.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.8.mlp.fc1": "TransformerLayer_9.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.9.mlp.fc1": "TransformerLayer_10.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.10.mlp.fc1": "TransformerLayer_11.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.11.mlp.fc1": "TransformerLayer_12.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.12.mlp.fc1": "TransformerLayer_13.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.13.mlp.fc1": "TransformerLayer_14.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.14.mlp.fc1": "TransformerLayer_15.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.15.mlp.fc1": "TransformerLayer_16.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.16.mlp.fc1": "TransformerLayer_17.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.17.mlp.fc1": "TransformerLayer_18.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.18.mlp.fc1": "TransformerLayer_19.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.19.mlp.fc1": "TransformerLayer_20.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.20.mlp.fc1": "TransformerLayer_21.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.21.mlp.fc1": "TransformerLayer_22.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.22.mlp.fc1": "TransformerLayer_23.Residual_2.FeedForward.Linear_1",
"text_model.encoder.layers.0.mlp.fc2": "TransformerLayer_1.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.1.mlp.fc2": "TransformerLayer_2.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.2.mlp.fc2": "TransformerLayer_3.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.3.mlp.fc2": "TransformerLayer_4.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.4.mlp.fc2": "TransformerLayer_5.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.5.mlp.fc2": "TransformerLayer_6.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.6.mlp.fc2": "TransformerLayer_7.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.7.mlp.fc2": "TransformerLayer_8.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.8.mlp.fc2": "TransformerLayer_9.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.9.mlp.fc2": "TransformerLayer_10.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.10.mlp.fc2": "TransformerLayer_11.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.11.mlp.fc2": "TransformerLayer_12.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.12.mlp.fc2": "TransformerLayer_13.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.13.mlp.fc2": "TransformerLayer_14.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.14.mlp.fc2": "TransformerLayer_15.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.15.mlp.fc2": "TransformerLayer_16.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.16.mlp.fc2": "TransformerLayer_17.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.17.mlp.fc2": "TransformerLayer_18.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.18.mlp.fc2": "TransformerLayer_19.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.19.mlp.fc2": "TransformerLayer_20.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.20.mlp.fc2": "TransformerLayer_21.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.21.mlp.fc2": "TransformerLayer_22.Residual_2.FeedForward.Linear_2",
"text_model.encoder.layers.22.mlp.fc2": "TransformerLayer_23.Residual_2.FeedForward.Linear_2"
},
"source_aliases": {}
}

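The new JSON file above maps diffusers text-encoder module names to their refiners counterparts; the actual conversion goes through cast_weights from imaginairy.weight_management.conversion, as the model_manager changes show. As a rough sketch only (the helper and file names here are hypothetical, not the project's converter), a prefix map of this shape can be applied to a state dict like so:

def remap_keys(state_dict, mapping):
    """Rename state-dict keys via a source-prefix -> dest-prefix mapping.

    Illustrative only: match the longest mapped prefix and keep the
    remaining suffix (e.g. ".weight" / ".bias") unchanged.
    """
    prefixes = sorted(mapping, key=len, reverse=True)
    remapped = {}
    for key, tensor in state_dict.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                remapped[mapping[prefix] + key[len(prefix):]] = tensor
                break
        else:
            remapped[key] = tensor  # keys without a mapped prefix pass through
    return remapped

# usage sketch (file name assumed):
# import json
# with open("text_encoder_diffusers_to_refiners.json") as f:
#     mapping = json.load(f)["mapping"]
# refiners_sd = remap_keys(diffusers_sd, mapping)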

@@ -56,6 +56,13 @@ async def test_list_models():
response = client.get("/api/stablestudio/models")
assert response.status_code == 200
expected_model_ids = {"sd15", "openjourney-v1", "openjourney-v2", "openjourney-v4"}
expected_model_ids = {
"sd15",
"openjourney-v1",
"openjourney-v2",
"openjourney-v4",
"modern-disney",
"redshift-diffusion",
}
model_ids = {m["id"] for m in response.json()}
assert model_ids == expected_model_ids


@ -1,25 +0,0 @@
from imaginairy import config
from imaginairy.utils.model_manager import resolve_model_weights_config
def test_resolved_paths():
"""Test that the resolved model path is correct."""
model_weights_config = resolve_model_weights_config(config.DEFAULT_MODEL_WEIGHTS)
assert config.DEFAULT_MODEL_WEIGHTS.lower() in model_weights_config.aliases
assert (
config.DEFAULT_MODEL_ARCHITECTURE in model_weights_config.architecture.aliases
)
model_weights_config = resolve_model_weights_config(
model_weights="foo.ckpt",
default_model_architecture="sd15",
)
print(model_weights_config)
assert model_weights_config.aliases == []
assert "sd15" in model_weights_config.architecture.aliases
model_weights_config = resolve_model_weights_config(
model_weights="foo.ckpt", default_model_architecture="sd15", for_inpainting=True
)
assert model_weights_config.aliases == []
assert "sd15-inpaint" in model_weights_config.architecture.aliases


@@ -0,0 +1,64 @@
import pytest
from imaginairy import config
from imaginairy.utils.model_manager import (
parse_diffusers_repo_url,
resolve_model_weights_config,
)
def test_resolved_paths():
"""Test that the resolved model path is correct."""
model_weights_config = resolve_model_weights_config(config.DEFAULT_MODEL_WEIGHTS)
assert config.DEFAULT_MODEL_WEIGHTS.lower() in model_weights_config.aliases
assert (
config.DEFAULT_MODEL_ARCHITECTURE in model_weights_config.architecture.aliases
)
model_weights_config = resolve_model_weights_config(
model_weights="foo.ckpt",
default_model_architecture="sd15",
)
assert model_weights_config.aliases == []
assert "sd15" in model_weights_config.architecture.aliases
model_weights_config = resolve_model_weights_config(
model_weights="foo.ckpt", default_model_architecture="sd15", for_inpainting=True
)
assert model_weights_config.aliases == []
assert "sd15-inpaint" in model_weights_config.architecture.aliases
hf_urls_cases = [
("", {}),
(
"https://huggingface.co/prompthero/zoom-v3/",
{"author": "prompthero", "repo": "zoom-v3", "ref": None},
),
(
"https://huggingface.co/prompthero/zoom-v3",
{"author": "prompthero", "repo": "zoom-v3", "ref": None},
),
(
"https://huggingface.co/prompthero/zoom-v3/tree/main",
{"author": "prompthero", "repo": "zoom-v3", "ref": "main"},
),
(
"https://huggingface.co/prompthero/zoom-v3/tree/main/",
{"author": "prompthero", "repo": "zoom-v3", "ref": "main"},
),
(
"https://huggingface.co/prompthero/zoom-v3/tree/6027e2fe2343bf0ed09a5883e027506950f182ed/",
{
"author": "prompthero",
"repo": "zoom-v3",
"ref": "6027e2fe2343bf0ed09a5883e027506950f182ed",
},
),
]
@pytest.mark.parametrize(("url", "expected"), hf_urls_cases)
def test_parse_diffusers_repo_url(url, expected):
result = parse_diffusers_repo_url(url)
assert result == expected