# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 85 lines
# 3.0 KiB

import contextlib
import json
import os
import re
import tempfile
from contextvars import ContextVar
from typing import List, Optional, Tuple, Union
from hivemind.utils.logging import get_logger
from transformers import BloomPreTrainedModel, modeling_utils
from petals.utils.version import get_compatible_model_repo
# Module-level logger; used below to report how many checkpoint shards are loaded.
logger = get_logger(__name__)
class FromPretrainedMixin:
    """Mixin that overrides ``from_pretrained()`` with Petals-specific defaults.

    The override resolves the model repo via ``get_compatible_model_repo()``,
    fills in defaults for ``low_cpu_mem_usage`` and ``torch_dtype``, and runs the
    parent ``from_pretrained()`` inside ``ignore_keys(...)`` so that the patterns in
    ``cls._keys_to_ignore_on_load_unexpected`` are visible to the patched shard loader.
    """

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, os.PathLike, None],
        *args,
        low_cpu_mem_usage: Optional[bool] = None,
        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
        **kwargs,
    ):
        # NOTE: no docstring here on purpose; it is assigned right after the
        # definition from the upstream docstring (see below).
        model_name_or_path = get_compatible_model_repo(model_name_or_path)
        if low_cpu_mem_usage is None:
            low_cpu_mem_usage = True
        if torch_dtype is None:
            # The amended docstring below advertises "auto" as the Petals default,
            # so apply it whenever the caller did not choose a dtype explicitly.
            torch_dtype = "auto"

        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
            return super().from_pretrained(
                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
            )

    # Reuse the upstream docstring, amended to mention the Petals defaults.
    # The docstring is set on the underlying function (``__func__``): bound methods
    # read ``__doc__`` from the function, not from the classmethod wrapper.
    # The ``None`` check keeps the module importable when docstrings are stripped (-OO).
    if BloomPreTrainedModel.from_pretrained.__doc__ is not None:
        from_pretrained.__func__.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
            "low_cpu_mem_usage(`bool`, *optional*)",
            "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
        ).replace(
            "torch_dtype (`str` or `torch.dtype`, *optional*)",
            'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
        )
# Regex patterns of parameter names whose shards should be skipped while loading a
# checkpoint. ``None`` (the default) means "no filtering".
_ignored_keys = ContextVar("ignored_keys", default=None)


@contextlib.contextmanager
def ignore_keys(patterns: Optional[List[str]]):
    """Temporarily publish ``patterns`` through the ``_ignored_keys`` context variable.

    Args:
        patterns: patterns to store in ``_ignored_keys`` for the duration of the
            ``with`` block. ``None`` is stored as-is, which readers of the variable
            treat as "filtering disabled".

    The previous value is restored on exit, including when the body raises, so
    nested uses unwind correctly.
    """
    token = _ignored_keys.set(patterns)
    try:
        yield
    finally:
        # Reset via the token (not by assigning None) to restore whatever value an
        # enclosing ignore_keys() block may have set.
        _ignored_keys.reset(token)
def patched_get_checkpoint_shard_files(
    pretrained_model_name_or_path, index_filename, *args, **kwargs
) -> Tuple[List[str], dict]:
    """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys.

    When ``_ignored_keys`` holds a list of patterns, the shard index at
    ``index_filename`` is read, every parameter whose name matches any of the
    patterns is dropped from its ``"weight_map"``, and the filtered index is written
    to a temporary directory. The original implementation is then called with that
    temporary index instead. When ``_ignored_keys`` is ``None``, the call is
    forwarded unchanged.

    Args:
        pretrained_model_name_or_path: forwarded as-is to the original function.
        index_filename: path to the JSON shard index (must contain ``"weight_map"``).
        *args, **kwargs: forwarded as-is to the original function.

    Returns:
        Whatever the original ``get_checkpoint_shard_files()`` returns.
    """
    # Read the context variable once so the check and the filtering see the same value.
    ignored_patterns = _ignored_keys.get()
    should_ignore_keys = ignored_patterns is not None
    tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
    with tempdir_ctx as tempdir:
        if should_ignore_keys:
            with open(index_filename) as f:
                index = json.load(f)
            n_original_shards = len(set(index["weight_map"].values()))

            # Keep a parameter only if none of the patterns is found in its name.
            index["weight_map"] = {
                param_name: filename
                for param_name, filename in index["weight_map"].items()
                if all(re.search(pattern, param_name) is None for pattern in ignored_patterns)
            }
            n_loaded_shards = len(set(index["weight_map"].values()))
            logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

            # Replace the original index with a patched JSON, where ignored keys are removed
            index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json")
            with open(index_filename, "w") as f:
                json.dump(index, f)

        # Must run inside the ``with`` block: the temporary index is deleted on exit.
        return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
# Keep a reference to the current modeling_utils implementation first, so that
# patched_get_checkpoint_shard_files() can delegate to it ...
original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
# ... then install the patched version in its place (module-level side effect on import).
modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files