mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
feature: support loading diffusers folder/based models from huggingface (#427)
This commit is contained in:
parent
50e796a3b7
commit
a2c38b3ec0
@ -150,6 +150,20 @@ MODEL_WEIGHT_CONFIGS = [
|
||||
weights_location="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors",
|
||||
defaults={"negative_prompt": "poor quality"},
|
||||
),
|
||||
ModelWeightsConfig(
|
||||
name="Modern Disney",
|
||||
aliases=["modern-disney", "modi", "modi15", "modern-disney-15"],
|
||||
architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
|
||||
weights_location="https://huggingface.co/nitrosocke/mo-di-diffusion/tree/e3106d24aa8c37bf856257daea2ae789eabc4d70/",
|
||||
defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
|
||||
),
|
||||
ModelWeightsConfig(
    # Display name for the nitrosocke/redshift-diffusion weights referenced
    # below; it must be distinct from the "Modern Disney" (mo-di) entry.
    name="Redshift Diffusion",
    aliases=["redshift-diffusion", "red", "redshift-diffusion-15", "red15"],
    architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"],
    # Diffusers-style repo URL pinned to a specific commit.
    weights_location="https://huggingface.co/nitrosocke/redshift-diffusion/tree/80837fe18df05807861ab91c3bad3693c9342e4c/",
    defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT},
),
|
||||
# Video Weights
|
||||
ModelWeightsConfig(
|
||||
name="Stable Video Diffusion",
|
||||
|
@ -10,6 +10,7 @@ from functools import lru_cache, wraps
|
||||
import requests
|
||||
import torch
|
||||
from huggingface_hub import (
|
||||
HfFileSystem,
|
||||
HfFolder,
|
||||
hf_hub_download as _hf_hub_download,
|
||||
try_to_load_from_cache,
|
||||
@ -239,6 +240,29 @@ def get_diffusion_model_refiners(
|
||||
)
|
||||
|
||||
|
||||
# Matches a Hugging Face repo URL, optionally pinned to a revision:
#   https://huggingface.co/<author>/<repo>[/tree/<ref>][/]
# <ref> is any single path segment, so branch or tag names containing
# dashes, dots, underscores or uppercase letters (e.g. "non-ema") are
# accepted as well as commit hashes. URLs with further path segments
# (e.g. ".../resolve/<ref>/<file>") deliberately do not match.
hf_repo_url_pattern = re.compile(
    r"https://huggingface\.co/(?P<author>[^/]+)/(?P<repo>[^/]+)(/tree/(?P<ref>[^/]+))?/?$"
)


def parse_diffusers_repo_url(url: str) -> "dict[str, str | None]":
    """Split a Hugging Face repo URL into its parts.

    Returns:
        A dict with the keys ``author``, ``repo`` and ``ref`` (``ref`` is
        ``None`` when the URL has no ``/tree/<ref>`` part), or an empty dict
        when ``url`` is not a repo URL.
    """
    match = hf_repo_url_pattern.match(url)
    return match.groupdict() if match else {}


def is_diffusers_repo_url(url: str) -> bool:
    """Return True when ``url`` points at a Hugging Face repo (not a file)."""
    return bool(parse_diffusers_repo_url(url))


def normalize_diffusers_repo_url(url: str) -> str:
    """Return the canonical ``.../<author>/<repo>/tree/<ref>/`` form of ``url``.

    A URL without an explicit revision is pinned to ``main``. The result
    always ends with a trailing slash.

    Raises:
        ValueError: If ``url`` is not a Hugging Face repo URL.
    """
    data = parse_diffusers_repo_url(url)
    if not data:
        # Fail with a descriptive error instead of an opaque KeyError on
        # the empty dict returned for non-matching URLs.
        msg = f"Not a Hugging Face repo url: {url}"
        raise ValueError(msg)
    ref = data["ref"] or "main"
    return f"https://huggingface.co/{data['author']}/{data['repo']}/tree/{ref}/"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_diffusion_model_refiners(
|
||||
weights_location: str,
|
||||
@ -260,12 +284,18 @@ def _get_diffusion_model_refiners(
|
||||
global MOST_RECENTLY_LOADED_MODEL
|
||||
|
||||
device = device or get_device()
|
||||
|
||||
(
|
||||
vae_weights,
|
||||
unet_weights,
|
||||
text_encoder_weights,
|
||||
) = load_stable_diffusion_compvis_weights(weights_location)
|
||||
if is_diffusers_repo_url(weights_location):
|
||||
(
|
||||
vae_weights,
|
||||
unet_weights,
|
||||
text_encoder_weights,
|
||||
) = load_stable_diffusion_diffusers_weights(weights_location)
|
||||
else:
|
||||
(
|
||||
vae_weights,
|
||||
unet_weights,
|
||||
text_encoder_weights,
|
||||
) = load_stable_diffusion_compvis_weights(weights_location)
|
||||
|
||||
StableDiffusionCls: type[LatentDiffusionModel]
|
||||
if for_inpainting:
|
||||
@ -529,15 +559,40 @@ def extract_huggingface_repo_commit_file_from_url(url):
|
||||
return repo, commit_hash, filepath
|
||||
|
||||
|
||||
def download_diffusers_weights(repo, sub, filename):
|
||||
from imaginairy.utils.model_manager import get_cached_url_path
|
||||
def download_diffusers_weights(base_url, sub, filename=None):
    """Download one component's weights from a diffusers-style Hugging Face repo.

    Args:
        base_url: Repo URL of the form
            ``https://huggingface.co/<author>/<repo>/tree/<ref>/``. The
            trailing slash is optional.
        sub: Sub-folder of the repo that holds the component (e.g. ``"unet"``).
        filename: Name of the weights file inside ``sub``. When ``None`` the
            folder is listed and a file is picked by
            ``choose_diffusers_weights``.

    Returns:
        The value returned by ``get_cached_url_path`` for the resolved file URL.

    Raises:
        ValueError: If ``filename`` is ``None`` and ``base_url`` is not a
            recognised repo URL, or no weights file is found in ``sub``.
    """
    # Normalise the trailing slash so the joins below never produce
    # "...//unet" or "...mainunet".
    base_url = base_url.rstrip("/")

    if filename is None:
        # select which weights to download. prefer fp16 safetensors
        data = parse_diffusers_repo_url(base_url)
        if not data:
            msg = f"Not a Hugging Face repo url: {base_url}"
            raise ValueError(msg)
        fs = HfFileSystem()
        filepaths = fs.ls(
            f"{data['author']}/{data['repo']}/{sub}", revision=data["ref"], detail=False
        )
        filepath = choose_diffusers_weights(filepaths)
        if not filepath:
            msg = f"Could not find any weights in {base_url}/{sub}"
            raise ValueError(msg)
        # The listing yields repo-relative paths; only the file name is needed.
        filename = filepath.split("/")[-1]

    # "/tree/<ref>/" is the browse URL; "/resolve/<ref>/" serves the raw file.
    url = f"{base_url}/{sub}/{filename}".replace("/tree/", "/resolve/")
    return get_cached_url_path(url, category="weights")
|
||||
|
||||
url = f"https://huggingface.co/{repo}/resolve/main/{sub}/{filename}"
|
||||
return get_cached_url_path(url, category="weights")
|
||||
|
||||
def choose_diffusers_weights(filenames):
    """Pick the preferred weights file from a list of file paths.

    Files whose name contains ``fp16`` are preferred; within the same
    precision group the extension order is ``.safetensors``, ``.bin``,
    ``.pth``, ``.pt``. Files with any other extension are ignored. Ties keep
    the input order.

    Args:
        filenames: Iterable of file names or ``/``-separated paths.

    Returns:
        The chosen entry exactly as given in ``filenames``, or ``None`` when
        no entry has a supported extension.
    """
    extension_priority = (".safetensors", ".bin", ".pth", ".pt")

    def sort_key(path):
        # Only look at the file name: a directory, author or repo name that
        # happens to contain "fp16" must not make every file look like fp16.
        basename = os.path.basename(path)
        extension_rank = next(
            rank
            for rank, extension in enumerate(extension_priority)
            if basename.endswith(extension)
        )
        return ("fp16" not in basename, extension_rank)

    # filter out any files that don't have a valid extension
    candidates = [f for f in filenames if f.endswith(extension_priority)]
    # min() returns the first minimal item, matching a stable sort's head.
    return min(candidates, key=sort_key, default=None)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def load_stable_diffusion_diffusers_weights(diffusers_repo, device=None):
|
||||
def load_stable_diffusion_diffusers_weights(base_url: str, device=None):
|
||||
from imaginairy.utils import get_device
|
||||
from imaginairy.weight_management.conversion import cast_weights
|
||||
from imaginairy.weight_management.utils import (
|
||||
@ -546,11 +601,10 @@ def load_stable_diffusion_diffusers_weights(diffusers_repo, device=None):
|
||||
MODEL_NAMES,
|
||||
)
|
||||
|
||||
base_url = normalize_diffusers_repo_url(base_url)
|
||||
if device is None:
|
||||
device = get_device()
|
||||
vae_weights_path = download_diffusers_weights(
|
||||
repo=diffusers_repo, sub="vae", filename="diffusion_pytorch_model.safetensors"
|
||||
)
|
||||
vae_weights_path = download_diffusers_weights(base_url=base_url, sub="vae")
|
||||
vae_weights = open_weights(vae_weights_path, device=device)
|
||||
vae_weights = cast_weights(
|
||||
source_weights=vae_weights,
|
||||
@ -560,9 +614,7 @@ def load_stable_diffusion_diffusers_weights(diffusers_repo, device=None):
|
||||
dest_format=FORMAT_NAMES.REFINERS,
|
||||
)
|
||||
|
||||
unet_weights_path = download_diffusers_weights(
|
||||
repo=diffusers_repo, sub="unet", filename="diffusion_pytorch_model.safetensors"
|
||||
)
|
||||
unet_weights_path = download_diffusers_weights(base_url=base_url, sub="unet")
|
||||
unet_weights = open_weights(unet_weights_path, device=device)
|
||||
unet_weights = cast_weights(
|
||||
source_weights=unet_weights,
|
||||
@ -573,7 +625,7 @@ def load_stable_diffusion_diffusers_weights(diffusers_repo, device=None):
|
||||
)
|
||||
|
||||
text_encoder_weights_path = download_diffusers_weights(
|
||||
repo=diffusers_repo, sub="text_encoder", filename="model.safetensors"
|
||||
base_url=base_url, sub="text_encoder"
|
||||
)
|
||||
text_encoder_weights = open_weights(text_encoder_weights_path, device=device)
|
||||
text_encoder_weights = cast_weights(
|
||||
@ -613,7 +665,6 @@ def open_weights(filepath, device=None):
|
||||
|
||||
def load_stable_diffusion_compvis_weights(weights_url):
|
||||
from imaginairy.utils import get_device
|
||||
from imaginairy.utils.model_manager import get_cached_url_path
|
||||
from imaginairy.weight_management.conversion import cast_weights
|
||||
from imaginairy.weight_management.utils import (
|
||||
COMPONENT_NAMES,
|
||||
|
@ -0,0 +1,200 @@
|
||||
{
|
||||
"ignorable_prefixes": [
|
||||
"text_model.embeddings"
|
||||
],
|
||||
"mapping": {
|
||||
"text_model.embeddings.token_embedding": "Sum.TokenEncoder",
|
||||
"text_model.embeddings.position_embedding": "Sum.PositionalEncoder.Embedding",
|
||||
"text_model.encoder.layers.0.layer_norm1": "TransformerLayer_1.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.0.layer_norm2": "TransformerLayer_1.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.1.layer_norm1": "TransformerLayer_2.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.1.layer_norm2": "TransformerLayer_2.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.2.layer_norm1": "TransformerLayer_3.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.2.layer_norm2": "TransformerLayer_3.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.3.layer_norm1": "TransformerLayer_4.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.3.layer_norm2": "TransformerLayer_4.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.4.layer_norm1": "TransformerLayer_5.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.4.layer_norm2": "TransformerLayer_5.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.5.layer_norm1": "TransformerLayer_6.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.5.layer_norm2": "TransformerLayer_6.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.6.layer_norm1": "TransformerLayer_7.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.6.layer_norm2": "TransformerLayer_7.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.7.layer_norm1": "TransformerLayer_8.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.7.layer_norm2": "TransformerLayer_8.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.8.layer_norm1": "TransformerLayer_9.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.8.layer_norm2": "TransformerLayer_9.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.9.layer_norm1": "TransformerLayer_10.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.9.layer_norm2": "TransformerLayer_10.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.10.layer_norm1": "TransformerLayer_11.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.10.layer_norm2": "TransformerLayer_11.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.11.layer_norm1": "TransformerLayer_12.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.11.layer_norm2": "TransformerLayer_12.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.12.layer_norm1": "TransformerLayer_13.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.12.layer_norm2": "TransformerLayer_13.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.13.layer_norm1": "TransformerLayer_14.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.13.layer_norm2": "TransformerLayer_14.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.14.layer_norm1": "TransformerLayer_15.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.14.layer_norm2": "TransformerLayer_15.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.15.layer_norm1": "TransformerLayer_16.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.15.layer_norm2": "TransformerLayer_16.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.16.layer_norm1": "TransformerLayer_17.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.16.layer_norm2": "TransformerLayer_17.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.17.layer_norm1": "TransformerLayer_18.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.17.layer_norm2": "TransformerLayer_18.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.18.layer_norm1": "TransformerLayer_19.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.18.layer_norm2": "TransformerLayer_19.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.19.layer_norm1": "TransformerLayer_20.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.19.layer_norm2": "TransformerLayer_20.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.20.layer_norm1": "TransformerLayer_21.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.20.layer_norm2": "TransformerLayer_21.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.21.layer_norm1": "TransformerLayer_22.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.21.layer_norm2": "TransformerLayer_22.Residual_2.LayerNorm",
|
||||
"text_model.encoder.layers.22.layer_norm1": "TransformerLayer_23.Residual_1.LayerNorm",
|
||||
"text_model.encoder.layers.22.layer_norm2": "TransformerLayer_23.Residual_2.LayerNorm",
|
||||
|
||||
"text_model.final_layer_norm": "LayerNorm",
|
||||
"text_model.encoder.layers.0.self_attn.q_proj": "TransformerLayer_1.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.0.self_attn.k_proj": "TransformerLayer_1.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.0.self_attn.v_proj": "TransformerLayer_1.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.0.self_attn.out_proj": "TransformerLayer_1.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.1.self_attn.q_proj": "TransformerLayer_2.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.1.self_attn.k_proj": "TransformerLayer_2.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.1.self_attn.v_proj": "TransformerLayer_2.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.1.self_attn.out_proj": "TransformerLayer_2.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.2.self_attn.q_proj": "TransformerLayer_3.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.2.self_attn.k_proj": "TransformerLayer_3.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.2.self_attn.v_proj": "TransformerLayer_3.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.2.self_attn.out_proj": "TransformerLayer_3.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.3.self_attn.q_proj": "TransformerLayer_4.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.3.self_attn.k_proj": "TransformerLayer_4.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.3.self_attn.v_proj": "TransformerLayer_4.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.3.self_attn.out_proj": "TransformerLayer_4.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.4.self_attn.q_proj": "TransformerLayer_5.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.4.self_attn.k_proj": "TransformerLayer_5.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.4.self_attn.v_proj": "TransformerLayer_5.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.4.self_attn.out_proj": "TransformerLayer_5.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.5.self_attn.q_proj": "TransformerLayer_6.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.5.self_attn.k_proj": "TransformerLayer_6.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.5.self_attn.v_proj": "TransformerLayer_6.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.5.self_attn.out_proj": "TransformerLayer_6.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.6.self_attn.q_proj": "TransformerLayer_7.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.6.self_attn.k_proj": "TransformerLayer_7.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.6.self_attn.v_proj": "TransformerLayer_7.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.6.self_attn.out_proj": "TransformerLayer_7.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.7.self_attn.q_proj": "TransformerLayer_8.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.7.self_attn.k_proj": "TransformerLayer_8.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.7.self_attn.v_proj": "TransformerLayer_8.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.7.self_attn.out_proj": "TransformerLayer_8.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.8.self_attn.q_proj": "TransformerLayer_9.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.8.self_attn.k_proj": "TransformerLayer_9.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.8.self_attn.v_proj": "TransformerLayer_9.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.8.self_attn.out_proj": "TransformerLayer_9.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.9.self_attn.q_proj": "TransformerLayer_10.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.9.self_attn.k_proj": "TransformerLayer_10.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.9.self_attn.v_proj": "TransformerLayer_10.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.9.self_attn.out_proj": "TransformerLayer_10.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.10.self_attn.q_proj": "TransformerLayer_11.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.10.self_attn.k_proj": "TransformerLayer_11.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.10.self_attn.v_proj": "TransformerLayer_11.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.10.self_attn.out_proj": "TransformerLayer_11.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.11.self_attn.q_proj": "TransformerLayer_12.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.11.self_attn.k_proj": "TransformerLayer_12.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.11.self_attn.v_proj": "TransformerLayer_12.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.11.self_attn.out_proj": "TransformerLayer_12.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.12.self_attn.q_proj": "TransformerLayer_13.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.12.self_attn.k_proj": "TransformerLayer_13.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.12.self_attn.v_proj": "TransformerLayer_13.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.12.self_attn.out_proj": "TransformerLayer_13.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.13.self_attn.q_proj": "TransformerLayer_14.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.13.self_attn.k_proj": "TransformerLayer_14.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.13.self_attn.v_proj": "TransformerLayer_14.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.13.self_attn.out_proj": "TransformerLayer_14.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.14.self_attn.q_proj": "TransformerLayer_15.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.14.self_attn.k_proj": "TransformerLayer_15.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.14.self_attn.v_proj": "TransformerLayer_15.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.14.self_attn.out_proj": "TransformerLayer_15.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.15.self_attn.q_proj": "TransformerLayer_16.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.15.self_attn.k_proj": "TransformerLayer_16.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.15.self_attn.v_proj": "TransformerLayer_16.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.15.self_attn.out_proj": "TransformerLayer_16.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.16.self_attn.q_proj": "TransformerLayer_17.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.16.self_attn.k_proj": "TransformerLayer_17.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.16.self_attn.v_proj": "TransformerLayer_17.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.16.self_attn.out_proj": "TransformerLayer_17.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.17.self_attn.q_proj": "TransformerLayer_18.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.17.self_attn.k_proj": "TransformerLayer_18.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.17.self_attn.v_proj": "TransformerLayer_18.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.17.self_attn.out_proj": "TransformerLayer_18.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.18.self_attn.q_proj": "TransformerLayer_19.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.18.self_attn.k_proj": "TransformerLayer_19.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.18.self_attn.v_proj": "TransformerLayer_19.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.18.self_attn.out_proj": "TransformerLayer_19.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.19.self_attn.q_proj": "TransformerLayer_20.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.19.self_attn.k_proj": "TransformerLayer_20.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.19.self_attn.v_proj": "TransformerLayer_20.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.19.self_attn.out_proj": "TransformerLayer_20.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.20.self_attn.q_proj": "TransformerLayer_21.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.20.self_attn.k_proj": "TransformerLayer_21.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.20.self_attn.v_proj": "TransformerLayer_21.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.20.self_attn.out_proj": "TransformerLayer_21.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.21.self_attn.q_proj": "TransformerLayer_22.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.21.self_attn.k_proj": "TransformerLayer_22.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.21.self_attn.v_proj": "TransformerLayer_22.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.21.self_attn.out_proj": "TransformerLayer_22.Residual_1.SelfAttention.Linear",
|
||||
"text_model.encoder.layers.22.self_attn.q_proj": "TransformerLayer_23.Residual_1.SelfAttention.Distribute.Linear_1",
|
||||
"text_model.encoder.layers.22.self_attn.k_proj": "TransformerLayer_23.Residual_1.SelfAttention.Distribute.Linear_2",
|
||||
"text_model.encoder.layers.22.self_attn.v_proj": "TransformerLayer_23.Residual_1.SelfAttention.Distribute.Linear_3",
|
||||
"text_model.encoder.layers.22.self_attn.out_proj": "TransformerLayer_23.Residual_1.SelfAttention.Linear",
|
||||
|
||||
|
||||
|
||||
|
||||
"text_model.encoder.layers.0.mlp.fc1": "TransformerLayer_1.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.1.mlp.fc1": "TransformerLayer_2.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.2.mlp.fc1": "TransformerLayer_3.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.3.mlp.fc1": "TransformerLayer_4.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.4.mlp.fc1": "TransformerLayer_5.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.5.mlp.fc1": "TransformerLayer_6.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.6.mlp.fc1": "TransformerLayer_7.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.7.mlp.fc1": "TransformerLayer_8.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.8.mlp.fc1": "TransformerLayer_9.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.9.mlp.fc1": "TransformerLayer_10.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.10.mlp.fc1": "TransformerLayer_11.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.11.mlp.fc1": "TransformerLayer_12.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.12.mlp.fc1": "TransformerLayer_13.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.13.mlp.fc1": "TransformerLayer_14.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.14.mlp.fc1": "TransformerLayer_15.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.15.mlp.fc1": "TransformerLayer_16.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.16.mlp.fc1": "TransformerLayer_17.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.17.mlp.fc1": "TransformerLayer_18.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.18.mlp.fc1": "TransformerLayer_19.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.19.mlp.fc1": "TransformerLayer_20.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.20.mlp.fc1": "TransformerLayer_21.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.21.mlp.fc1": "TransformerLayer_22.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.22.mlp.fc1": "TransformerLayer_23.Residual_2.FeedForward.Linear_1",
|
||||
"text_model.encoder.layers.0.mlp.fc2": "TransformerLayer_1.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.1.mlp.fc2": "TransformerLayer_2.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.2.mlp.fc2": "TransformerLayer_3.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.3.mlp.fc2": "TransformerLayer_4.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.4.mlp.fc2": "TransformerLayer_5.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.5.mlp.fc2": "TransformerLayer_6.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.6.mlp.fc2": "TransformerLayer_7.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.7.mlp.fc2": "TransformerLayer_8.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.8.mlp.fc2": "TransformerLayer_9.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.9.mlp.fc2": "TransformerLayer_10.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.10.mlp.fc2": "TransformerLayer_11.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.11.mlp.fc2": "TransformerLayer_12.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.12.mlp.fc2": "TransformerLayer_13.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.13.mlp.fc2": "TransformerLayer_14.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.14.mlp.fc2": "TransformerLayer_15.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.15.mlp.fc2": "TransformerLayer_16.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.16.mlp.fc2": "TransformerLayer_17.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.17.mlp.fc2": "TransformerLayer_18.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.18.mlp.fc2": "TransformerLayer_19.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.19.mlp.fc2": "TransformerLayer_20.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.20.mlp.fc2": "TransformerLayer_21.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.21.mlp.fc2": "TransformerLayer_22.Residual_2.FeedForward.Linear_2",
|
||||
"text_model.encoder.layers.22.mlp.fc2": "TransformerLayer_23.Residual_2.FeedForward.Linear_2"
|
||||
},
|
||||
"source_aliases": {}
|
||||
}
|
@ -56,6 +56,13 @@ async def test_list_models():
|
||||
response = client.get("/api/stablestudio/models")
|
||||
assert response.status_code == 200
|
||||
|
||||
expected_model_ids = {"sd15", "openjourney-v1", "openjourney-v2", "openjourney-v4"}
|
||||
expected_model_ids = {
|
||||
"sd15",
|
||||
"openjourney-v1",
|
||||
"openjourney-v2",
|
||||
"openjourney-v4",
|
||||
"modern-disney",
|
||||
"redshift-diffusion",
|
||||
}
|
||||
model_ids = {m["id"] for m in response.json()}
|
||||
assert model_ids == expected_model_ids
|
||||
|
@ -1,25 +0,0 @@
|
||||
from imaginairy import config
|
||||
from imaginairy.utils.model_manager import resolve_model_weights_config
|
||||
|
||||
|
||||
def test_resolved_paths():
|
||||
"""Test that the resolved model path is correct."""
|
||||
model_weights_config = resolve_model_weights_config(config.DEFAULT_MODEL_WEIGHTS)
|
||||
assert config.DEFAULT_MODEL_WEIGHTS.lower() in model_weights_config.aliases
|
||||
assert (
|
||||
config.DEFAULT_MODEL_ARCHITECTURE in model_weights_config.architecture.aliases
|
||||
)
|
||||
|
||||
model_weights_config = resolve_model_weights_config(
|
||||
model_weights="foo.ckpt",
|
||||
default_model_architecture="sd15",
|
||||
)
|
||||
print(model_weights_config)
|
||||
assert model_weights_config.aliases == []
|
||||
assert "sd15" in model_weights_config.architecture.aliases
|
||||
|
||||
model_weights_config = resolve_model_weights_config(
|
||||
model_weights="foo.ckpt", default_model_architecture="sd15", for_inpainting=True
|
||||
)
|
||||
assert model_weights_config.aliases == []
|
||||
assert "sd15-inpaint" in model_weights_config.architecture.aliases
|
64
tests/test_utils/test_model_manager.py
Normal file
64
tests/test_utils/test_model_manager.py
Normal file
@ -0,0 +1,64 @@
|
||||
import pytest
|
||||
|
||||
from imaginairy import config
|
||||
from imaginairy.utils.model_manager import (
|
||||
parse_diffusers_repo_url,
|
||||
resolve_model_weights_config,
|
||||
)
|
||||
|
||||
|
||||
def test_resolved_paths():
|
||||
"""Test that the resolved model path is correct."""
|
||||
model_weights_config = resolve_model_weights_config(config.DEFAULT_MODEL_WEIGHTS)
|
||||
assert config.DEFAULT_MODEL_WEIGHTS.lower() in model_weights_config.aliases
|
||||
assert (
|
||||
config.DEFAULT_MODEL_ARCHITECTURE in model_weights_config.architecture.aliases
|
||||
)
|
||||
|
||||
model_weights_config = resolve_model_weights_config(
|
||||
model_weights="foo.ckpt",
|
||||
default_model_architecture="sd15",
|
||||
)
|
||||
assert model_weights_config.aliases == []
|
||||
assert "sd15" in model_weights_config.architecture.aliases
|
||||
|
||||
model_weights_config = resolve_model_weights_config(
|
||||
model_weights="foo.ckpt", default_model_architecture="sd15", for_inpainting=True
|
||||
)
|
||||
assert model_weights_config.aliases == []
|
||||
assert "sd15-inpaint" in model_weights_config.architecture.aliases
|
||||
|
||||
|
||||
hf_urls_cases = [
|
||||
("", {}),
|
||||
(
|
||||
"https://huggingface.co/prompthero/zoom-v3/",
|
||||
{"author": "prompthero", "repo": "zoom-v3", "ref": None},
|
||||
),
|
||||
(
|
||||
"https://huggingface.co/prompthero/zoom-v3",
|
||||
{"author": "prompthero", "repo": "zoom-v3", "ref": None},
|
||||
),
|
||||
(
|
||||
"https://huggingface.co/prompthero/zoom-v3/tree/main",
|
||||
{"author": "prompthero", "repo": "zoom-v3", "ref": "main"},
|
||||
),
|
||||
(
|
||||
"https://huggingface.co/prompthero/zoom-v3/tree/main/",
|
||||
{"author": "prompthero", "repo": "zoom-v3", "ref": "main"},
|
||||
),
|
||||
(
|
||||
"https://huggingface.co/prompthero/zoom-v3/tree/6027e2fe2343bf0ed09a5883e027506950f182ed/",
|
||||
{
|
||||
"author": "prompthero",
|
||||
"repo": "zoom-v3",
|
||||
"ref": "6027e2fe2343bf0ed09a5883e027506950f182ed",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("url", "expected"), hf_urls_cases)
|
||||
def test_parse_diffusers_repo_url(url, expected):
|
||||
result = parse_diffusers_repo_url(url)
|
||||
assert result == expected
|
Loading…
Reference in New Issue
Block a user