Support loading weights from Safetensors on server (#473)

pull/476/head
Alexander Borzunov 9 months ago committed by GitHub
parent 4f850996bb
commit a9b0e9ff1a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -8,14 +8,17 @@ If necessary, one can rewrite this to implement a different behavior, such as:
""" """
import json import json
import time import time
from contextlib import suppress
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
import safetensors
import torch import torch
import torch.nn as nn import torch.nn as nn
from accelerate import init_empty_weights from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device from accelerate.utils import set_module_tensor_to_device
from hivemind.utils.logging import get_logger from hivemind.utils.logging import get_logger
from huggingface_hub import get_hf_file_metadata, hf_hub_url from huggingface_hub import get_hf_file_metadata, hf_hub_url
from huggingface_hub.utils import EntryNotFoundError
from transformers import PretrainedConfig from transformers import PretrainedConfig
from transformers.utils import get_file_from_repo from transformers.utils import get_file_from_repo
@ -90,11 +93,14 @@ def _load_state_dict_from_repo(
if always_needs_auth(model_name) and token is None: if always_needs_auth(model_name) and token is None:
token = True token = True
index_file = get_file_from_repo( index_file = _find_index_file(model_name, revision=revision, token=token, cache_dir=cache_dir)
model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir if index_file.endswith(".index.json"): # Sharded model
) path = get_file_from_repo(model_name, filename=index_file, use_auth_token=token, cache_dir=cache_dir)
if index_file is not None: # Sharded model if path is None:
with open(index_file) as f: # _find_index_file() told that a file exists but we can't get it (e.g., it just disappeared)
raise ValueError(f"Failed to get file {index_file}")
with open(path) as f:
index = json.load(f) index = json.load(f)
filenames = { filenames = {
filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix) filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
@ -102,14 +108,15 @@ def _load_state_dict_from_repo(
if not filenames: if not filenames:
raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}") raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
else: # Non-sharded model else: # Non-sharded model
filenames = {"pytorch_model.bin"} filenames = {index_file}
logger.debug(f"Loading {block_prefix}* from {filenames}") logger.debug(f"Loading {block_prefix}* from {filenames}")
state_dict = {} state_dict = {}
for filename in filenames: for filename in filenames:
shard_state_dict = _load_state_dict_from_file( shard_state_dict = _load_state_dict_from_repo_file(
model_name, model_name,
filename, filename,
block_prefix=block_prefix,
revision=revision, revision=revision,
token=token, token=token,
cache_dir=cache_dir, cache_dir=cache_dir,
@ -124,10 +131,42 @@ def _load_state_dict_from_repo(
return state_dict return state_dict
def _load_state_dict_from_file( INDEX_FILES = ["model.safetensors.index.json", "model.safetensors", "pytorch_model.bin.index.json", "pytorch_model.bin"]
def _find_index_file(
    model_name: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str
) -> str:
    """
    Pick which weights file (or shard index) of ``model_name`` should be used.

    Candidates are tried in the order given by ``INDEX_FILES``. A candidate that is already
    present in the local cache wins over everything else; otherwise the first candidate
    whose metadata can be fetched from the Hub is chosen. Nothing is downloaded here.

    :returns: the chosen file name (one of ``INDEX_FILES``), not a local path
    :raises ValueError: if none of the candidates is cached or found in the repo
    """
    # Pass 1: reuse whatever is already cached (e.g., Pickle weights from older Petals versions)
    cache_lookup = dict(revision=revision, use_auth_token=token, cache_dir=cache_dir, local_files_only=True)
    for candidate in INDEX_FILES:
        if get_file_from_repo(model_name, candidate, **cache_lookup) is not None:
            return candidate

    # Pass 2: nothing is cached, so ask the Hub which candidate exists. Only metadata is requested
    # (we don't download files here since we can't account for max_disk_space in case of large files)
    for candidate in INDEX_FILES:
        try:
            get_hf_file_metadata(hf_hub_url(model_name, candidate, revision=revision), token=token)
        except EntryNotFoundError:
            continue  # this candidate is missing from the repo, try the next one
        return candidate

    raise ValueError(
        f"Repo {model_name} does not contain weights in a supported format: files {INDEX_FILES} do not exist"
    )
def _load_state_dict_from_repo_file(
model_name: str, model_name: str,
filename: str, filename: str,
*, *,
block_prefix: Optional[str] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
token: Optional[Union[str, bool]] = None, token: Optional[Union[str, bool]] = None,
cache_dir: str, cache_dir: str,
@ -146,7 +185,7 @@ def _load_state_dict_from_file(
local_files_only=True, local_files_only=True,
) )
if path is not None: if path is not None:
return torch.load(path, map_location="cpu") return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
except Exception: except Exception:
logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True) logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)
@ -171,7 +210,18 @@ def _load_state_dict_from_file(
) )
if path is None: if path is None:
raise RuntimeError(f"File {filename} does not exist in repo {model_name}") raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
return torch.load(path, map_location="cpu") return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
except Exception as e: except Exception as e:
logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True) logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
time.sleep(delay) time.sleep(delay)
def _load_state_dict_from_local_file(path: str, *, block_prefix: Optional[str] = None) -> StateDict:
    """
    Load a state dict from a local weights file, choosing the loader by file extension.

    :param path: local path to a ``.bin`` (torch pickle) or ``.safetensors`` file
    :param block_prefix: if given, only parameters whose names start with this prefix are returned
        (for both formats); if None, all parameters are returned
    :returns: a mapping from parameter names to tensors loaded onto CPU
    :raises ValueError: if the file extension is neither ``.bin`` nor ``.safetensors``
    """
    if path.endswith(".bin"):
        # NOTE(security): torch.load() unpickles the file, and these files may come from a remote repo.
        # Unpickling untrusted data can execute arbitrary code; prefer Safetensors when it is available.
        state_dict = torch.load(path, map_location="cpu")
        if block_prefix is None:
            return state_dict
        # The whole file has to be deserialized anyway, but filtering keeps the result consistent
        # with the Safetensors branch and lets unrelated tensors be freed right away
        return {key: tensor for key, tensor in state_dict.items() if key.startswith(block_prefix)}

    if path.endswith(".safetensors"):
        with safetensors.safe_open(path, framework="pt", device="cpu") as f:
            # Only the requested tensors are read from the file
            return {key: f.get_tensor(key) for key in f.keys() if block_prefix is None or key.startswith(block_prefix)}

    raise ValueError(f"Unknown weight format: {path}")

Loading…
Cancel
Save