feature: ability to load safetensors

pull/170/head
Bryce 2 years ago committed by Bryce Drennan
parent c94dac76dc
commit 9b1d130f93

@ -252,6 +252,9 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog
**7.6.0**
- feature: ability to load safetensors
**7.5.0**
- feature: 🎉 outpainting. Examples: `--outpaint up10,down300,left50,right50` or `--outpaint all100` or `--outpaint u100,d200,l300,r400`

@ -11,6 +11,7 @@ import torch
from huggingface_hub import hf_hub_download as _hf_hub_download
from huggingface_hub import try_to_load_from_cache
from omegaconf import OmegaConf
from safetensors.torch import load_file
from transformers.utils.hub import HfFolder
from imaginairy import config as iconfig
@ -81,6 +82,14 @@ class MemoryAwareModel:
gc.collect()
def load_tensors(tensorfile, map_location=None):
if tensorfile.endswith(".ckpt"):
return torch.load(tensorfile, map_location=map_location)
if tensorfile.endswith(".safetensors"):
return load_file(tensorfile, device=map_location)
raise ValueError(f"Unknown tensorfile type: {tensorfile}")
def load_model_from_config(config, weights_location):
if weights_location.startswith("http"):
ckpt_path = get_cached_url_path(weights_location, category="weights")
@ -89,7 +98,7 @@ def load_model_from_config(config, weights_location):
logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
pl_sd = None
try:
pl_sd = torch.load(ckpt_path, map_location="cpu")
pl_sd = load_tensors(ckpt_path, map_location="cpu")
except FileNotFoundError as e:
if e.errno == 2:
logger.error(
@ -103,12 +112,15 @@ def load_model_from_config(config, weights_location):
logger.warning("Corrupt checkpoint. deleting and re-downloading...")
os.remove(ckpt_path)
ckpt_path = get_cached_url_path(weights_location, category="weights")
pl_sd = torch.load(ckpt_path, map_location="cpu")
pl_sd = load_tensors(ckpt_path, map_location="cpu")
if pl_sd is None:
raise e
if "global_step" in pl_sd:
logger.debug(f"Global Step: {pl_sd['global_step']}")
state_dict = pl_sd["state_dict"]
if "state_dict" in pl_sd:
state_dict = pl_sd["state_dict"]
else:
state_dict = pl_sd
model = instantiate_from_config(config.model)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

Loading…
Cancel
Save