|
|
|
@ -8,14 +8,17 @@ If necessary, one can rewrite this to implement a different behavior, such as:
|
|
|
|
|
"""
|
|
|
|
|
import json
|
|
|
|
|
import time
|
|
|
|
|
from contextlib import suppress
|
|
|
|
|
from typing import Dict, Optional, Union
|
|
|
|
|
|
|
|
|
|
import safetensors
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from accelerate import init_empty_weights
|
|
|
|
|
from accelerate.utils import set_module_tensor_to_device
|
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
|
from huggingface_hub import get_hf_file_metadata, hf_hub_url
|
|
|
|
|
from huggingface_hub.utils import EntryNotFoundError
|
|
|
|
|
from transformers import PretrainedConfig
|
|
|
|
|
from transformers.utils import get_file_from_repo
|
|
|
|
|
|
|
|
|
@ -90,11 +93,14 @@ def _load_state_dict_from_repo(
|
|
|
|
|
if always_needs_auth(model_name) and token is None:
|
|
|
|
|
token = True
|
|
|
|
|
|
|
|
|
|
index_file = get_file_from_repo(
|
|
|
|
|
model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
|
|
|
|
|
)
|
|
|
|
|
if index_file is not None: # Sharded model
|
|
|
|
|
with open(index_file) as f:
|
|
|
|
|
index_file = _find_index_file(model_name, revision=revision, token=token, cache_dir=cache_dir)
|
|
|
|
|
if index_file.endswith(".index.json"): # Sharded model
|
|
|
|
|
path = get_file_from_repo(model_name, filename=index_file, use_auth_token=token, cache_dir=cache_dir)
|
|
|
|
|
if path is None:
|
|
|
|
|
# _find_index_file() told that a file exists but we can't get it (e.g., it just disappeared)
|
|
|
|
|
raise ValueError(f"Failed to get file {index_file}")
|
|
|
|
|
|
|
|
|
|
with open(path) as f:
|
|
|
|
|
index = json.load(f)
|
|
|
|
|
filenames = {
|
|
|
|
|
filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
|
|
|
|
@ -102,14 +108,15 @@ def _load_state_dict_from_repo(
|
|
|
|
|
if not filenames:
|
|
|
|
|
raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
|
|
|
|
|
else: # Non-sharded model
|
|
|
|
|
filenames = {"pytorch_model.bin"}
|
|
|
|
|
filenames = {index_file}
|
|
|
|
|
logger.debug(f"Loading {block_prefix}* from {filenames}")
|
|
|
|
|
|
|
|
|
|
state_dict = {}
|
|
|
|
|
for filename in filenames:
|
|
|
|
|
shard_state_dict = _load_state_dict_from_file(
|
|
|
|
|
shard_state_dict = _load_state_dict_from_repo_file(
|
|
|
|
|
model_name,
|
|
|
|
|
filename,
|
|
|
|
|
block_prefix=block_prefix,
|
|
|
|
|
revision=revision,
|
|
|
|
|
token=token,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
@ -124,10 +131,42 @@ def _load_state_dict_from_repo(
|
|
|
|
|
return state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_state_dict_from_file(
|
|
|
|
|
INDEX_FILES = ["model.safetensors.index.json", "model.safetensors", "pytorch_model.bin.index.json", "pytorch_model.bin"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _find_index_file(
|
|
|
|
|
model_name: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str
|
|
|
|
|
) -> str:
|
|
|
|
|
# If we have cached weights (e.g., Pickle from older Petals versions), reuse them
|
|
|
|
|
for filename in INDEX_FILES:
|
|
|
|
|
path = get_file_from_repo(
|
|
|
|
|
model_name,
|
|
|
|
|
filename,
|
|
|
|
|
revision=revision,
|
|
|
|
|
use_auth_token=token,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
local_files_only=True,
|
|
|
|
|
)
|
|
|
|
|
if path is not None:
|
|
|
|
|
return filename
|
|
|
|
|
|
|
|
|
|
# If we don't, prefer Safetensors when possible
|
|
|
|
|
# (we don't download files here since we can't account for max_disk_space in case of large files)
|
|
|
|
|
for filename in INDEX_FILES:
|
|
|
|
|
with suppress(EntryNotFoundError):
|
|
|
|
|
get_hf_file_metadata(hf_hub_url(model_name, filename, revision=revision), token=token)
|
|
|
|
|
return filename
|
|
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Repo {model_name} does not contain weights in a supported format: files {INDEX_FILES} do not exist"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_state_dict_from_repo_file(
|
|
|
|
|
model_name: str,
|
|
|
|
|
filename: str,
|
|
|
|
|
*,
|
|
|
|
|
block_prefix: Optional[str] = None,
|
|
|
|
|
revision: Optional[str] = None,
|
|
|
|
|
token: Optional[Union[str, bool]] = None,
|
|
|
|
|
cache_dir: str,
|
|
|
|
@ -146,7 +185,7 @@ def _load_state_dict_from_file(
|
|
|
|
|
local_files_only=True,
|
|
|
|
|
)
|
|
|
|
|
if path is not None:
|
|
|
|
|
return torch.load(path, map_location="cpu")
|
|
|
|
|
return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
|
|
|
|
|
except Exception:
|
|
|
|
|
logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)
|
|
|
|
|
|
|
|
|
@ -171,7 +210,18 @@ def _load_state_dict_from_file(
|
|
|
|
|
)
|
|
|
|
|
if path is None:
|
|
|
|
|
raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
|
|
|
|
|
return torch.load(path, map_location="cpu")
|
|
|
|
|
return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
|
|
|
|
|
time.sleep(delay)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_state_dict_from_local_file(path: str, *, block_prefix: Optional[str] = None) -> StateDict:
|
|
|
|
|
if path.endswith(".bin"):
|
|
|
|
|
return torch.load(path, map_location="cpu")
|
|
|
|
|
|
|
|
|
|
if path.endswith(".safetensors"):
|
|
|
|
|
with safetensors.safe_open(path, framework="pt", device="cpu") as f:
|
|
|
|
|
return {key: f.get_tensor(key) for key in f.keys() if block_prefix is None or key.startswith(block_prefix)}
|
|
|
|
|
|
|
|
|
|
raise ValueError(f"Unknown weight format: {path}")
|
|
|
|
|