diff --git a/README.md b/README.md
index 666da50..440d24f 100644
--- a/README.md
+++ b/README.md
@@ -252,6 +252,9 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
 
 ## ChangeLog
 
+**7.6.0**
+- feature: ability to load safetensors
+
 **7.5.0**
 - feature: 🎉 outpainting. Examples: `--outpaint up10,down300,left50,right50` or `--outpaint all100` or `--outpaint u100,d200,l300,r400`
 
diff --git a/imaginairy/model_manager.py b/imaginairy/model_manager.py
index 1a02057..d1e6a44 100644
--- a/imaginairy/model_manager.py
+++ b/imaginairy/model_manager.py
@@ -11,6 +11,7 @@ import torch
 from huggingface_hub import hf_hub_download as _hf_hub_download
 from huggingface_hub import try_to_load_from_cache
 from omegaconf import OmegaConf
+from safetensors.torch import load_file
 from transformers.utils.hub import HfFolder
 
 from imaginairy import config as iconfig
@@ -81,6 +82,14 @@ class MemoryAwareModel:
         gc.collect()
 
 
+def load_tensors(tensorfile, map_location=None):
+    if tensorfile.endswith(".ckpt"):
+        return torch.load(tensorfile, map_location=map_location)
+    if tensorfile.endswith(".safetensors"):
+        return load_file(tensorfile, device=map_location)
+    raise ValueError(f"Unknown tensorfile type: {tensorfile}")
+
+
 def load_model_from_config(config, weights_location):
     if weights_location.startswith("http"):
         ckpt_path = get_cached_url_path(weights_location, category="weights")
@@ -89,7 +98,7 @@ def load_model_from_config(config, weights_location):
     logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
     pl_sd = None
     try:
-        pl_sd = torch.load(ckpt_path, map_location="cpu")
+        pl_sd = load_tensors(ckpt_path, map_location="cpu")
     except FileNotFoundError as e:
         if e.errno == 2:
             logger.error(
@@ -103,12 +112,15 @@ def load_model_from_config(config, weights_location):
             logger.warning("Corrupt checkpoint. deleting and re-downloading...")
             os.remove(ckpt_path)
             ckpt_path = get_cached_url_path(weights_location, category="weights")
-            pl_sd = torch.load(ckpt_path, map_location="cpu")
+            pl_sd = load_tensors(ckpt_path, map_location="cpu")
         if pl_sd is None:
             raise e
 
     if "global_step" in pl_sd:
         logger.debug(f"Global Step: {pl_sd['global_step']}")
-    state_dict = pl_sd["state_dict"]
+    if "state_dict" in pl_sd:
+        state_dict = pl_sd["state_dict"]
+    else:
+        state_dict = pl_sd
     model = instantiate_from_config(config.model)
     missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
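
Context on why the `state_dict` fallback in the last hunk is needed: `torch.load` on a Lightning `.ckpt` checkpoint returns a dict whose weights are nested under the `"state_dict"` key (alongside metadata such as `"global_step"`), whereas `safetensors.torch.load_file` returns the flat `{name: tensor}` mapping directly. A minimal sketch of the two shapes `load_tensors` can hand back (file names are hypothetical, for illustration only):

```python
import torch
from safetensors.torch import load_file

# Lightning checkpoint: weights nested under the "state_dict" key.
ckpt = torch.load("sd-v1-5.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]

# safetensors file: already a flat {name: tensor} mapping.
flat = load_file("sd-v1-5.safetensors", device="cpu")
state_dict = flat

# The fallback added to load_model_from_config handles both shapes:
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
```

Dispatching on the file extension in `load_tensors` keeps every call site unchanged, and any other extension raises a `ValueError` rather than guessing the format.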