@@ -11,6 +11,7 @@ import torch
 from huggingface_hub import hf_hub_download as _hf_hub_download
 from huggingface_hub import try_to_load_from_cache
 from omegaconf import OmegaConf
+from safetensors.torch import load_file
 from transformers.utils.hub import HfFolder

 from imaginairy import config as iconfig
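
For context, a minimal sketch of the safetensors round trip that this new import enables; the file name and tensors are illustrative, not part of this change:

import torch
from safetensors.torch import load_file, save_file

# Save a flat dict of tensors, then load it back onto the CPU.
tensors = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
save_file(tensors, "example.safetensors")                   # illustrative path
restored = load_file("example.safetensors", device="cpu")
assert torch.equal(tensors["bias"], restored["bias"])
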
@@ -81,6 +82,14 @@ class MemoryAwareModel:
         gc.collect()


+def load_tensors(tensorfile, map_location=None):
+    if tensorfile.endswith(".ckpt"):
+        return torch.load(tensorfile, map_location=map_location)
+    if tensorfile.endswith(".safetensors"):
+        return load_file(tensorfile, device=map_location)
+    raise ValueError(f"Unknown tensorfile type: {tensorfile}")
+
+
 def load_model_from_config(config, weights_location):
     if weights_location.startswith("http"):
         ckpt_path = get_cached_url_path(weights_location, category="weights")
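
A quick sketch of how the new helper dispatches on file extension; the checkpoint paths below are illustrative local files, not part of this diff:

# .ckpt goes through torch.load (pickle) and typically nests weights under a
# "state_dict" key; .safetensors comes back as a flat tensor dict.
ckpt_sd = load_tensors("sd-v1-5.ckpt", map_location="cpu")
flat_sd = load_tensors("sd-v1-5.safetensors", map_location="cpu")
# load_tensors("weights.bin")  # any other extension raises ValueError
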
@@ -89,7 +98,7 @@ def load_model_from_config(config, weights_location):
     logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
     pl_sd = None
     try:
-        pl_sd = torch.load(ckpt_path, map_location="cpu")
+        pl_sd = load_tensors(ckpt_path, map_location="cpu")
     except FileNotFoundError as e:
         if e.errno == 2:
             logger.error(
@@ -103,12 +112,15 @@ def load_model_from_config(config, weights_location):
             logger.warning("Corrupt checkpoint. deleting and re-downloading...")
             os.remove(ckpt_path)
             ckpt_path = get_cached_url_path(weights_location, category="weights")
-            pl_sd = torch.load(ckpt_path, map_location="cpu")
+            pl_sd = load_tensors(ckpt_path, map_location="cpu")
     if pl_sd is None:
         raise e
     if "global_step" in pl_sd:
         logger.debug(f"Global Step: {pl_sd['global_step']}")
-    state_dict = pl_sd["state_dict"]
+    if "state_dict" in pl_sd:
+        state_dict = pl_sd["state_dict"]
+    else:
+        state_dict = pl_sd
     model = instantiate_from_config(config.model)
     missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
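
The state_dict fallback in the last hunk is what makes plain .safetensors files loadable: they are already a flat state dict, whereas .ckpt files usually wrap one. A minimal illustration of the branch, with made-up keys:

import torch

ckpt_style = {"global_step": 840000, "state_dict": {"model.weight": torch.zeros(1)}}  # .ckpt-like shape (illustrative)
flat_style = {"model.weight": torch.zeros(1)}                                         # .safetensors-like shape (illustrative)

for pl_sd in (ckpt_style, flat_style):
    # Same fallback as in load_model_from_config above.
    state_dict = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    assert "model.weight" in state_dict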