feature: ability to load safetensors

This commit is contained in:
Bryce 2023-01-17 22:43:23 -08:00 committed by Bryce Drennan
parent c94dac76dc
commit 9b1d130f93
2 changed files with 18 additions and 3 deletions

View File

@ -252,6 +252,9 @@ docker run -it --gpus all -v $HOME/.cache/huggingface:/root/.cache/huggingface -
## ChangeLog ## ChangeLog
**7.6.0**
- feature: ability to load safetensors
**7.5.0** **7.5.0**
- feature: 🎉 outpainting. Examples: `--outpaint up10,down300,left50,right50` or `--outpaint all100` or `--outpaint u100,d200,l300,r400` - feature: 🎉 outpainting. Examples: `--outpaint up10,down300,left50,right50` or `--outpaint all100` or `--outpaint u100,d200,l300,r400`

View File

@ -11,6 +11,7 @@ import torch
from huggingface_hub import hf_hub_download as _hf_hub_download from huggingface_hub import hf_hub_download as _hf_hub_download
from huggingface_hub import try_to_load_from_cache from huggingface_hub import try_to_load_from_cache
from omegaconf import OmegaConf from omegaconf import OmegaConf
from safetensors.torch import load_file
from transformers.utils.hub import HfFolder from transformers.utils.hub import HfFolder
from imaginairy import config as iconfig from imaginairy import config as iconfig
@ -81,6 +82,14 @@ class MemoryAwareModel:
gc.collect() gc.collect()
def load_tensors(tensorfile, map_location=None):
if tensorfile.endswith(".ckpt"):
return torch.load(tensorfile, map_location=map_location)
if tensorfile.endswith(".safetensors"):
return load_file(tensorfile, device=map_location)
raise ValueError(f"Unknown tensorfile type: {tensorfile}")
def load_model_from_config(config, weights_location): def load_model_from_config(config, weights_location):
if weights_location.startswith("http"): if weights_location.startswith("http"):
ckpt_path = get_cached_url_path(weights_location, category="weights") ckpt_path = get_cached_url_path(weights_location, category="weights")
@ -89,7 +98,7 @@ def load_model_from_config(config, weights_location):
logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...") logger.info(f"Loading model {ckpt_path} onto {get_device()} backend...")
pl_sd = None pl_sd = None
try: try:
pl_sd = torch.load(ckpt_path, map_location="cpu") pl_sd = load_tensors(ckpt_path, map_location="cpu")
except FileNotFoundError as e: except FileNotFoundError as e:
if e.errno == 2: if e.errno == 2:
logger.error( logger.error(
@ -103,12 +112,15 @@ def load_model_from_config(config, weights_location):
logger.warning("Corrupt checkpoint. deleting and re-downloading...") logger.warning("Corrupt checkpoint. deleting and re-downloading...")
os.remove(ckpt_path) os.remove(ckpt_path)
ckpt_path = get_cached_url_path(weights_location, category="weights") ckpt_path = get_cached_url_path(weights_location, category="weights")
pl_sd = torch.load(ckpt_path, map_location="cpu") pl_sd = load_tensors(ckpt_path, map_location="cpu")
if pl_sd is None: if pl_sd is None:
raise e raise e
if "global_step" in pl_sd: if "global_step" in pl_sd:
logger.debug(f"Global Step: {pl_sd['global_step']}") logger.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
state_dict = pl_sd["state_dict"] state_dict = pl_sd["state_dict"]
else:
state_dict = pl_sd
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)