diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py index 3468be9..c774a02 100644 --- a/src/petals/server/from_pretrained.py +++ b/src/petals/server/from_pretrained.py @@ -8,14 +8,17 @@ If necessary, one can rewrite this to implement a different behavior, such as: """ import json import time +from contextlib import suppress from typing import Dict, Optional, Union +import safetensors import torch import torch.nn as nn from accelerate import init_empty_weights from accelerate.utils import set_module_tensor_to_device from hivemind.utils.logging import get_logger from huggingface_hub import get_hf_file_metadata, hf_hub_url +from huggingface_hub.utils import EntryNotFoundError from transformers import PretrainedConfig from transformers.utils import get_file_from_repo @@ -90,11 +93,14 @@ def _load_state_dict_from_repo( if always_needs_auth(model_name) and token is None: token = True - index_file = get_file_from_repo( - model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir - ) - if index_file is not None: # Sharded model - with open(index_file) as f: + index_file = _find_index_file(model_name, revision=revision, token=token, cache_dir=cache_dir) + if index_file.endswith(".index.json"): # Sharded model + path = get_file_from_repo(model_name, filename=index_file, use_auth_token=token, cache_dir=cache_dir) + if path is None: + # _find_index_file() told that a file exists but we can't get it (e.g., it just disappeared) + raise ValueError(f"Failed to get file {index_file}") + + with open(path) as f: index = json.load(f) filenames = { filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix) @@ -102,14 +108,15 @@ def _load_state_dict_from_repo( if not filenames: raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}") else: # Non-sharded model - filenames = {"pytorch_model.bin"} + filenames = {index_file} 
logger.debug(f"Loading {block_prefix}* from {filenames}") state_dict = {} for filename in filenames: - shard_state_dict = _load_state_dict_from_file( + shard_state_dict = _load_state_dict_from_repo_file( model_name, filename, + block_prefix=block_prefix, revision=revision, token=token, cache_dir=cache_dir, @@ -124,10 +131,42 @@ def _load_state_dict_from_repo( return state_dict -def _load_state_dict_from_file( +INDEX_FILES = ["model.safetensors.index.json", "model.safetensors", "pytorch_model.bin.index.json", "pytorch_model.bin"] + + +def _find_index_file( + model_name: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str +) -> str: + # If we have cached weights (e.g., Pickle from older Petals versions), reuse them + for filename in INDEX_FILES: + path = get_file_from_repo( + model_name, + filename, + revision=revision, + use_auth_token=token, + cache_dir=cache_dir, + local_files_only=True, + ) + if path is not None: + return filename + + # If we don't, prefer Safetensors when possible + # (we don't download files here since we can't account for max_disk_space in case of large files) + for filename in INDEX_FILES: + with suppress(EntryNotFoundError): + get_hf_file_metadata(hf_hub_url(model_name, filename, revision=revision), token=token) + return filename + + raise ValueError( + f"Repo {model_name} does not contain weights in a supported format: files {INDEX_FILES} do not exist" + ) + + +def _load_state_dict_from_repo_file( model_name: str, filename: str, *, + block_prefix: Optional[str] = None, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str, @@ -146,7 +185,7 @@ def _load_state_dict_from_file( local_files_only=True, ) if path is not None: - return torch.load(path, map_location="cpu") + return _load_state_dict_from_local_file(path, block_prefix=block_prefix) except Exception: logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True) @@ 
-171,7 +210,18 @@ def _load_state_dict_from_file( ) if path is None: raise RuntimeError(f"File {filename} does not exist in repo {model_name}") - return torch.load(path, map_location="cpu") + return _load_state_dict_from_local_file(path, block_prefix=block_prefix) except Exception as e: logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True) time.sleep(delay) + + +def _load_state_dict_from_local_file(path: str, *, block_prefix: Optional[str] = None) -> StateDict: + if path.endswith(".bin"): + return torch.load(path, map_location="cpu") + + if path.endswith(".safetensors"): + with safetensors.safe_open(path, framework="pt", device="cpu") as f: + return {key: f.get_tensor(key) for key in f.keys() if block_prefix is None or key.startswith(block_prefix)} + + raise ValueError(f"Unknown weight format: {path}")