You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
imaginAIry/imaginairy/weight_management/generate_weight_info.py

137 lines
4.4 KiB
Python

"""Functions for managing model weights information"""
import safetensors
from imaginairy.utils.downloads import get_cached_url_path
from imaginairy.utils.model_manager import (
open_weights,
resolve_model_weights_config,
)
from imaginairy.weight_management import utils
from imaginairy.weight_management.pattern_collapse import find_state_dict_key_patterns
from imaginairy.weight_management.utils import save_model_info
def save_compvis_patterns():
    """Record weight keys and key patterns for the CompVis-format SD15 parts.

    Resolves the "openjourney-v1" weights config, obtains a local path for its
    weights via ``get_cached_url_path``, reads the key names from the
    safetensors file, and then calls ``save_weight_info`` once per component
    (text encoder, VAE, UNet) with the keys that start with that component's
    prefix.
    """
    config = resolve_model_weights_config(model_weights="openjourney-v1")
    checkpoint_path = get_cached_url_path(
        config.weights_location, category="weights"
    )
    with safetensors.safe_open(checkpoint_path, "pytorch") as checkpoint:
        all_keys = checkpoint.keys()

    # (component, key prefix) pairs; handled in this order.
    component_prefixes = (
        (
            utils.COMPONENT_NAMES.TEXT_ENCODER,
            "cond_stage_model.transformer.text_model",
        ),
        (utils.COMPONENT_NAMES.VAE, "first_stage_model"),
        (utils.COMPONENT_NAMES.UNET, "model.diffusion_model"),
    )
    for component, prefix in component_prefixes:
        save_weight_info(
            model_name=utils.MODEL_NAMES.SD15,
            component_name=component,
            format_name=utils.FORMAT_NAMES.COMPVIS,
            weights_keys=[key for key in all_keys if key.startswith(prefix)],
        )
def save_diffusers_patterns():
    """Record weight keys and key patterns for the diffusers-format SD15 parts.

    Calls ``save_weight_info`` for the VAE, UNet, and text encoder, in that
    order, passing each component's fp16 safetensors URL as ``weights_url``.
    """
    # (component, weights URL) pairs; handled in this order.
    component_urls = (
        (
            utils.COMPONENT_NAMES.VAE,
            "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors",
        ),
        (
            utils.COMPONENT_NAMES.UNET,
            "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors",
        ),
        (
            utils.COMPONENT_NAMES.TEXT_ENCODER,
            "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/text_encoder/model.fp16.safetensors",
        ),
    )
    for component, url in component_urls:
        save_weight_info(
            model_name=utils.MODEL_NAMES.SD15,
            component_name=component,
            format_name=utils.FORMAT_NAMES.DIFFUSERS,
            weights_url=url,
        )
def save_lora_patterns(
    filepath="/Users/bryce/projects/sandbox-img-gen/refiners/weights/pytorch_lora_weights-refiners.safetensors",
):
    """Record weight keys and key patterns for SD15 LoRA weights.

    Two formats are recorded:

    * refiners format: keys are read from the local weights file at
      ``filepath``, opened with ``open_weights`` on the CPU;
    * diffusers format: keys are read from the pokemon-lora weights URL,
      which ``save_weight_info`` resolves itself.

    Args:
        filepath: Path to a refiners-format LoRA weights file. The default is
            the original author's local path and is kept only for backward
            compatibility; pass your own path when running elsewhere.
    """
    refiners_state_dict = open_weights(filepath, device="cpu")
    save_weight_info(
        model_name=utils.MODEL_NAMES.SD15,
        component_name=utils.COMPONENT_NAMES.LORA,
        format_name=utils.FORMAT_NAMES.REFINERS,
        weights_keys=list(refiners_state_dict.keys()),
    )

    save_weight_info(
        model_name=utils.MODEL_NAMES.SD15,
        component_name=utils.COMPONENT_NAMES.LORA,
        format_name=utils.FORMAT_NAMES.DIFFUSERS,
        weights_url="https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin",
    )
def save_weight_info(
    model_name, component_name, format_name, weights_url=None, weights_keys=None
):
    """Save the weight keys and their derived key patterns for one component.

    Two records are written through ``save_model_info``: one with
    ``info_type="weights_keys"`` holding the key list, and one with
    ``info_type="patterns"`` holding the result of
    ``find_state_dict_key_patterns`` on that list.

    Args:
        model_name: Model identifier passed through to ``save_model_info``.
        component_name: Component identifier passed through to
            ``save_model_info``.
        format_name: Format identifier passed through to ``save_model_info``.
        weights_url: Location of a weights file. Used only when
            ``weights_keys`` is not given; the file is resolved with
            ``get_cached_url_path`` and opened on the CPU to list its keys.
        weights_keys: Key names to record. Takes precedence over
            ``weights_url`` when both are supplied.

    Raises:
        ValueError: If neither ``weights_keys`` nor ``weights_url`` is given.
    """
    if weights_keys is None:
        if weights_url is None:
            msg = "Either weights_keys or weights_url must be provided"
            raise ValueError(msg)
        weights_path = get_cached_url_path(weights_url, category="weights")
        state_dict = open_weights(weights_path, device="cpu")
        weights_keys = list(state_dict.keys())

    save_model_info(
        model_name=model_name,
        component_name=component_name,
        format_name=format_name,
        info_type="weights_keys",
        data=weights_keys,
    )

    patterns = find_state_dict_key_patterns(weights_keys)
    save_model_info(
        model_name=model_name,
        component_name=component_name,
        format_name=format_name,
        info_type="patterns",
        data=patterns,
    )
def save_patterns():
    """Run the enabled pattern-saving steps.

    Only the LoRA step is active; the CompVis and diffusers steps are kept
    below as comments and do not run.
    """
    save_lora_patterns()
    # save_compvis_patterns()
    # save_diffusers_patterns()
# Script entry point: save the patterns only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    save_patterns()