Merge branch 'main' into memcache_touchup

pull/434/head
justheuristic committed via GitHub
commit 0846b4bb95

@@ -33,10 +33,10 @@ python_requires = >=3.8
install_requires =
torch>=1.12
bitsandbytes==0.41.1
accelerate>=0.20.3
accelerate>=0.22.0
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3
transformers>=4.31.0,<5.0.0
transformers>=4.31.0,<5.0.0 # if you change this, please also change version assert in petals/__init__.py
speedtest-cli==2.1.3
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
hivemind==1.1.9
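
For reference, the comment added next to the transformers pin points to a runtime version check in petals/__init__.py. A minimal sketch of that kind of assert (the exact bounds and message here are assumptions, not the actual source):

    from packaging import version
    import transformers

    # Hypothetical sketch: keep these bounds in sync with setup.cfg
    assert version.parse("4.31.0") <= version.parse(transformers.__version__) < version.parse("5.0.0"), (
        "Please install a supported transformers version: pip install 'transformers>=4.31.0,<5.0.0'"
    )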

@@ -8,14 +8,17 @@ If necessary, one can rewrite this to implement a different behavior, such as:
"""
import json
import time
from contextlib import suppress
from typing import Dict, Optional, Union
import safetensors
import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from hivemind.utils.logging import get_logger
from huggingface_hub import get_hf_file_metadata, hf_hub_url
from huggingface_hub.utils import EntryNotFoundError
from transformers import PretrainedConfig
from transformers.utils import get_file_from_repo
@@ -61,7 +64,7 @@ def load_pretrained_block(
)
# dummy load, check that keys match
report = block.load_state_dict(state_dict, strict=True)
report = block.load_state_dict(state_dict, strict=False)
assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
for param_name, _ in block.named_parameters():
@@ -71,7 +74,8 @@ def load_pretrained_block(
param = param.to(torch_dtype)
set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
logger.info(f"Loaded {model_name} block {block_index}, {report}")
logger.info(f"Loaded {model_name} block {block_index}")
logger.debug(f"Details: {report}")
return block
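
For context, load_state_dict(strict=False) returns a report of mismatched keys instead of raising, which is what lets the code above tolerate unexpected keys while still asserting that no block weights are missing. A minimal standalone sketch with a toy module (the module is illustrative, not a Petals block):

    import torch.nn as nn

    block = nn.Linear(4, 4)
    state_dict = {"weight": block.weight.detach().clone()}  # deliberately omit "bias"

    report = block.load_state_dict(state_dict, strict=False)  # does not raise
    print(report.missing_keys)     # ['bias']
    print(report.unexpected_keys)  # []
    # With strict=True the same call would raise a RuntimeError; with strict=False
    # the caller can assert on report.missing_keys and log the rest, as above.
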
@@ -90,11 +94,14 @@ def _load_state_dict_from_repo(
if always_needs_auth(model_name) and token is None:
token = True
index_file = get_file_from_repo(
model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
)
if index_file is not None: # Sharded model
with open(index_file) as f:
index_file = _find_index_file(model_name, revision=revision, token=token, cache_dir=cache_dir)
if index_file.endswith(".index.json"): # Sharded model
path = get_file_from_repo(model_name, filename=index_file, use_auth_token=token, cache_dir=cache_dir)
if path is None:
# _find_index_file() reported that the file exists, but we can't get it (e.g., it just disappeared)
raise ValueError(f"Failed to get file {index_file}")
with open(path) as f:
index = json.load(f)
filenames = {
filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
@@ -102,14 +109,15 @@ def _load_state_dict_from_repo(
if not filenames:
raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
else: # Non-sharded model
filenames = {"pytorch_model.bin"}
filenames = {index_file}
logger.debug(f"Loading {block_prefix}* from {filenames}")
state_dict = {}
for filename in filenames:
shard_state_dict = _load_state_dict_from_file(
shard_state_dict = _load_state_dict_from_repo_file(
model_name,
filename,
block_prefix=block_prefix,
revision=revision,
token=token,
cache_dir=cache_dir,
@@ -124,10 +132,42 @@ def _load_state_dict_from_repo(
return state_dict
def _load_state_dict_from_file(
INDEX_FILES = ["model.safetensors.index.json", "model.safetensors", "pytorch_model.bin.index.json", "pytorch_model.bin"]
def _find_index_file(
model_name: str, *, revision: Optional[str] = None, token: Optional[Union[str, bool]] = None, cache_dir: str
) -> str:
# If we have cached weights (e.g., Pickle from older Petals versions), reuse them
for filename in INDEX_FILES:
path = get_file_from_repo(
model_name,
filename,
revision=revision,
use_auth_token=token,
cache_dir=cache_dir,
local_files_only=True,
)
if path is not None:
return filename
# If we don't, prefer Safetensors when possible
# (we don't download files here since we can't account for max_disk_space in case of large files)
for filename in INDEX_FILES:
with suppress(EntryNotFoundError):
get_hf_file_metadata(hf_hub_url(model_name, filename, revision=revision), token=token)
return filename
raise ValueError(
f"Repo {model_name} does not contain weights in a supported format: files {INDEX_FILES} do not exist"
)
def _load_state_dict_from_repo_file(
model_name: str,
filename: str,
*,
block_prefix: Optional[str] = None,
revision: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
cache_dir: str,
@@ -146,7 +186,7 @@ def _load_state_dict_from_file(
local_files_only=True,
)
if path is not None:
return torch.load(path, map_location="cpu")
return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
except Exception:
logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)
@@ -171,7 +211,18 @@ def _load_state_dict_from_file(
)
if path is None:
raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
return torch.load(path, map_location="cpu")
return _load_state_dict_from_local_file(path, block_prefix=block_prefix)
except Exception as e:
logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
time.sleep(delay)
def _load_state_dict_from_local_file(path: str, *, block_prefix: Optional[str] = None) -> StateDict:
if path.endswith(".bin"):
return torch.load(path, map_location="cpu")
if path.endswith(".safetensors"):
with safetensors.safe_open(path, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys() if block_prefix is None or key.startswith(block_prefix)}
raise ValueError(f"Unknown weight format: {path}")

@@ -157,15 +157,15 @@ class AdapterContextMixin:
using_adapter = AdapterContextMixin.using_adapter
class LoraLinear(lora.Linear, AdapterContextMixin):
class LoraLinear(AdapterContextMixin, lora.Linear):
"""LoRA linear layer that uses adapter selected via using_adapter"""
class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
"""LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
"""LoRA linear 4-bit that uses adapter selected via using_adapter"""
