mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
601a112dc3
+ renames and typehints
137 lines
4.4 KiB
Python
137 lines
4.4 KiB
Python
"""Functions for managing model weights information"""
|
|
|
|
import safetensors
|
|
|
|
from imaginairy.utils.downloads import get_cached_url_path
|
|
from imaginairy.utils.model_manager import (
|
|
open_weights,
|
|
resolve_model_weights_config,
|
|
)
|
|
from imaginairy.weight_management import utils
|
|
from imaginairy.weight_management.pattern_collapse import find_state_dict_key_patterns
|
|
from imaginairy.weight_management.utils import save_model_info
|
|
|
|
|
|
def save_compvis_patterns():
    """Save key lists and key patterns for each SD1.5 component of a CompVis checkpoint.

    Resolves the "openjourney-v1" weights config and obtains a local copy of the
    weights file through ``get_cached_url_path``. It then reads only the tensor
    names via ``safetensors.safe_open``. Those names are partitioned by key
    prefix (text encoder, VAE, UNet, in that order), and each subset is passed
    to ``save_weight_info`` in CompVis format.
    """
    resolved = resolve_model_weights_config(
        model_weights="openjourney-v1",
    )
    local_weights = get_cached_url_path(resolved.weights_location, category="weights")

    with safetensors.safe_open(local_weights, "pytorch") as reader:
        every_key = reader.keys()

    def _save_component(key_prefix, component):
        # Keep only the tensor names that belong to this component's sub-module.
        matching = [name for name in every_key if name.startswith(key_prefix)]
        save_weight_info(
            model_name=utils.MODEL_NAMES.SD15,
            component_name=component,
            format_name=utils.FORMAT_NAMES.COMPVIS,
            weights_keys=matching,
        )

    _save_component(
        "cond_stage_model.transformer.text_model", utils.COMPONENT_NAMES.TEXT_ENCODER
    )
    _save_component("first_stage_model", utils.COMPONENT_NAMES.VAE)
    _save_component("model.diffusion_model", utils.COMPONENT_NAMES.UNET)
|
|
|
|
|
|
def save_diffusers_patterns():
    """Save key lists and key patterns for SD1.5 components in diffusers layout.

    For each component (VAE, then UNet, then text encoder), hand the URL of its
    fp16 safetensors file on the runwayml/stable-diffusion-v1-5 Hugging Face
    repo to ``save_weight_info``. That function fetches the file and reads its
    keys.
    """

    def _save_from_url(component, url):
        save_weight_info(
            model_name=utils.MODEL_NAMES.SD15,
            component_name=component,
            format_name=utils.FORMAT_NAMES.DIFFUSERS,
            weights_url=url,
        )

    _save_from_url(
        utils.COMPONENT_NAMES.VAE,
        "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors",
    )
    _save_from_url(
        utils.COMPONENT_NAMES.UNET,
        "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors",
    )
    _save_from_url(
        utils.COMPONENT_NAMES.TEXT_ENCODER,
        "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/text_encoder/model.fp16.safetensors",
    )
|
|
|
|
|
|
def save_lora_patterns(
    filepath="/Users/bryce/projects/sandbox-img-gen/refiners/weights/pytorch_lora_weights-refiners.safetensors",
):
    """Save key lists and key patterns for SD1.5 LoRA weights in two formats.

    First, the keys of a local refiners-format LoRA file are saved. Then the
    pokemon-lora diffusers-format weights are referenced by URL and processed
    by ``save_weight_info``.

    Args:
        filepath: Local path to a LoRA weights file in refiners format. The
            default is the original author's machine-specific path and is kept
            only for backward compatibility. Pass an explicit path when running
            on any other machine.
    """
    # Tensors are loaded onto the CPU only so that their key names can be read.
    state_dict = open_weights(filepath, device="cpu")

    save_weight_info(
        model_name=utils.MODEL_NAMES.SD15,
        component_name=utils.COMPONENT_NAMES.LORA,
        format_name=utils.FORMAT_NAMES.REFINERS,
        weights_keys=list(state_dict.keys()),
    )

    save_weight_info(
        model_name=utils.MODEL_NAMES.SD15,
        component_name=utils.COMPONENT_NAMES.LORA,
        format_name=utils.FORMAT_NAMES.DIFFUSERS,
        weights_url="https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin",
    )
|
|
|
|
|
|
def save_weight_info(
    model_name, component_name, format_name, weights_url=None, weights_keys=None
):
    """Save a component's state-dict key names and their collapsed key patterns.

    Two records are written through ``save_model_info``:
    - ``info_type="weights_keys"``, holding the raw key list.
    - ``info_type="patterns"``, holding the result of
      ``find_state_dict_key_patterns`` on those keys.

    Args:
        model_name: Model identifier forwarded to ``save_model_info``.
        component_name: Component identifier (callers pass text encoder, VAE,
            UNet, or LoRA names).
        format_name: Weights-format identifier forwarded to ``save_model_info``.
        weights_url: URL of a weights file. It is used only when
            ``weights_keys`` is not given. The file is obtained with
            ``get_cached_url_path`` and opened on the CPU to read its keys.
        weights_keys: Iterable of state-dict key names. When both arguments are
            given, this takes precedence and ``weights_url`` is ignored.

    Raises:
        ValueError: If neither ``weights_keys`` nor ``weights_url`` is provided.
    """
    if weights_keys is None and weights_url is None:
        msg = "Either weights_keys or weights_url must be provided"
        raise ValueError(msg)

    if weights_keys is None:
        weights_path = get_cached_url_path(weights_url, category="weights")
        state_dict = open_weights(weights_path, device="cpu")
        weights_keys = list(state_dict.keys())
    else:
        # Materialize the keys because they are consumed twice below. A one-shot
        # iterator would otherwise be exhausted before pattern extraction.
        weights_keys = list(weights_keys)

    save_model_info(
        model_name=model_name,
        component_name=component_name,
        format_name=format_name,
        info_type="weights_keys",
        data=weights_keys,
    )

    patterns = find_state_dict_key_patterns(weights_keys)
    save_model_info(
        model_name=model_name,
        component_name=component_name,
        format_name=format_name,
        info_type="patterns",
        data=patterns,
    )
|
|
|
|
|
|
def save_patterns():
    """Run the currently enabled weight-pattern extraction routines.

    Only the LoRA routine is active. The CompVis and diffusers routines are
    commented out here and can be re-enabled by uncommenting them.
    """
    # Disabled routines, kept for easy re-enabling:
    # save_compvis_patterns()
    # save_diffusers_patterns()
    save_lora_patterns()


if __name__ == "__main__":
    # Script entry point: extract and save patterns for the enabled routines.
    save_patterns()
|