You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
209 lines
8.8 KiB
Python
209 lines
8.8 KiB
Python
import re
|
|
import time
|
|
from typing import List, Optional
|
|
|
|
import bitsandbytes as bnb
|
|
import torch.nn as nn
|
|
from hivemind.utils.logging import get_logger
|
|
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
|
|
from peft.tuners import lora
|
|
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
|
|
from safetensors import safe_open
|
|
from safetensors.torch import load_file
|
|
from transformers.utils import get_file_from_repo
|
|
|
|
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
|
|
from petals.utils.misc import QuantType
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
def check_peft_repository(repo_id: str) -> bool:
    """Return True if the Hub repo ships its adapter weights as safetensors.

    Presence of the safetensors weight file is the signal that the adapter
    can be loaded without unpickling arbitrary code.
    """
    matches = HfFileSystem().glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
    return bool(matches)
|
|
|
|
|
|
def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
    """Load only the tensors of one transformer block from a safetensors file.

    Args:
        block_idx: index of the transformer block whose adapter tensors to load.
        filepath: path to the safetensors file with the full adapter state dict.
        framework: framework key passed to ``safe_open`` (defaults to PyTorch).
        device: optional device index for ``safe_open`` to place tensors on.

    Returns:
        dict mapping tensor names to tensors belonging to block ``block_idx``
        (empty, with a warning, if no matching tensors are found).
    """
    # Fix: the original built this pattern from non-raw strings ("\." etc.),
    # which are invalid escape sequences (SyntaxWarning on modern Python).
    # The resulting regex is byte-identical; we also compile it once instead
    # of re-matching the raw pattern string for every key.
    # NOTE(review): "({block_idx})?" makes the block index optional, so keys
    # without an explicit index can also match — preserved as-is, confirm intended.
    block_pattern = re.compile(
        r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"\.({block_idx})?\..+"
    )
    tensors = {}
    found_any = False  # replaces the original single-key dict flag
    with safe_open(filepath, framework=framework, device=device) as f:
        for k in f.keys():
            if block_pattern.match(k):
                found_any = True
                tensors[k] = f.get_tensor(k)
        if not found_any:
            logger.warning(f"There is no peft weights for block {block_idx}")
        return tensors
|
|
|
|
|
|
def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
    """Fetch a PEFT adapter's config and weights from a Hub repo.

    Args:
        repo_id: Hub repository id of the adapter.
        block_idx: when given, load only tensors for this transformer block;
            otherwise load the full state dict.
        device: optional device index for block-specific loading.
        **kwargs: forwarded to ``get_file_from_repo`` (revision, cache_dir, ...).

    Returns:
        A ``(PeftConfig, state_dict)`` tuple.

    Raises:
        RuntimeError: if the config or safetensors weight file is missing.
    """
    config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
    if config_path is None:
        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
    config = PeftConfig.from_json_file(config_path)

    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
    if weight_path is None:
        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")

    if block_idx is None:
        weights = load_file(weight_path)
    else:
        weights = load_specific_module(block_idx, weight_path, device=device)
    return config, weights
|
|
|
|
|
|
def load_peft(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
):
    """Load a PEFT adapter from the HF Hub, using the local cache with retries.

    First tries a cache read; on failure (corrupted cache), frees disk space
    as needed and re-downloads, retrying forever with ``delay`` seconds
    between attempts.

    Args:
        repo_id: Hub repository id of the adapter (must contain safetensors).
        block_idx: optional transformer-block index to load a subset of weights.
        device: optional device index for block-specific loading.
        revision: optional Hub revision.
        use_auth_token: optional Hub auth token.
        cache_dir: local cache directory (required keyword).
        max_disk_space: optional cap for the cache size, in bytes.
        delay: seconds to sleep between failed download attempts.

    Returns:
        ``(PeftConfig, state_dict)`` from ``get_adapter_from_repo``.

    Raises:
        ValueError: if the repo has no safetensors weights (unsafe to load).
    """
    # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
    if not check_peft_repository(repo_id):
        raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")

    try:
        with allow_cache_reads(cache_dir):
            return get_adapter_from_repo(
                repo_id,
                block_idx,
                device,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=False,
            )
    except Exception:
        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)

    while True:
        try:
            with allow_cache_writes(cache_dir):
                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
                config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
                weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size

                # Bug fix: metadata ``.size`` may be None; the original summed the
                # two sizes *before* its None check, so a missing size raised
                # TypeError (swallowed by the retry loop) and the warning branch
                # was unreachable. Check both sizes first instead.
                if config_file_size is not None and weight_file_size is not None:
                    free_disk_space_for(
                        config_file_size + weight_file_size, cache_dir=cache_dir, max_disk_space=max_disk_space
                    )
                else:
                    logger.warning(f"Failed to fetch size from peft repo {repo_id}")

                return get_adapter_from_repo(
                    repo_id,
                    block_idx,
                    device,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
        except Exception:  # unused ``as e`` binding removed
            logger.warning(
                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
            )
            time.sleep(delay)
|
|
|
|
|
|
def create_lora_adapter(block, quant_type: QuantType):
    """Replace every linear layer in ``block`` with a LoRA-capable wrapper.

    Walks all modules of ``block`` and swaps each ``nn.Linear`` /
    ``bnb.nn.Linear8bitLt`` / ``bnb.nn.Linear4bit`` child for the matching
    ``peft.tuners.lora`` layer (chosen by ``quant_type``), reusing the
    original weight/bias tensors and freezing all parameters. No adapter is
    activated here (``active_adapter`` is set to ``None``); adapters are
    attached later via ``add_adapter_to_block``.
    """
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            lora_wrapped_child = None
            # Only wrap linear layers (full-precision or bitsandbytes-quantized).
            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
                continue
            if quant_type == QuantType.INT8:
                kwargs = {
                    "has_fp16_weights": False,
                    "threshold": 6.0,  # LLM.int8() outlier threshold
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = lora.Linear8bitLt(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif quant_type == QuantType.NF4:
                kwargs = {
                    "compress_statistics": True,
                    "quant_type": "nf4",
                    "blocksize": 64,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = lora.Linear4bit(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            else:
                # Unquantized path: plain LoRA linear layer.
                bias = hasattr(child, "bias") and child.bias is not None
                lora_wrapped_child = lora.Linear(
                    child_name,
                    child.in_features,
                    child.out_features,
                    bias=bias,
                )
            if lora_wrapped_child:
                lora_wrapped_child.active_adapter = None
                # Reuse the original tensors rather than copying them.
                lora_wrapped_child.weight = child.weight
                lora_wrapped_child.bias = child.bias
                # Base weights stay frozen; only adapter weights would train.
                for p in lora_wrapped_child.parameters():
                    p.requires_grad = False
                setattr(module, child_name, lora_wrapped_child)
|
|
|
|
|
|
def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
    """Load LoRA adapter weights into a block previously prepared by ``create_lora_adapter``.

    For every LoRA-wrapped linear layer whose name matches the adapter's
    ``target_modules``, lazily registers ``adapter_name`` on the layer
    (on first matching state-dict key) and copies the adapter's A/B matrices
    from ``peft_state_dict``. Only LORA-type configs are supported; adapters
    with biases on lora_A/lora_B are rejected.
    """
    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            # Only layers already wrapped by create_lora_adapter can host adapters.
            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
                continue

            # target_modules may be a list of names or a regex string.
            if child_name in peft_config["target_modules"] or (
                isinstance(peft_config["target_modules"], str)
                and re.fullmatch(peft_config["target_modules"], child_name)
            ):
                is_lora_a_loaded = False
                is_lora_b_loaded = False
                for peft_key in peft_state_dict:
                    # Skip keys that belong to other target modules.
                    if peft_key.find(child_name) == -1:
                        continue

                    # Lazily create the adapter slot on first matching key.
                    if adapter_name not in child.lora_A:
                        child.update_layer(
                            adapter_name,
                            peft_config["r"],
                            peft_config["lora_alpha"],
                            lora_dropout=peft_config["lora_dropout"],
                            init_lora_weights=peft_config["init_lora_weights"],
                        )
                        # Inference server: eval mode, dropout off, nothing trainable.
                        child.train(False)
                        if peft_config["lora_dropout"] > 0:
                            logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout")
                        for p in child.parameters():
                            p.requires_grad = False

                    if peft_key.endswith(".lora_A.weight"):
                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
                        is_lora_a_loaded = True
                    elif peft_key.endswith(".lora_A.bias"):
                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
                    elif peft_key.endswith(".lora_B.weight"):
                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
                        is_lora_b_loaded = True
                    elif peft_key.endswith(".lora_B.bias"):
                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")

                # NOTE(review): if only one of A/B is found, loading fails silently
                # (no log, no error) — confirm whether that should warn instead.
                if is_lora_a_loaded and is_lora_b_loaded:
                    logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")