petals/src/petals/utils/peft.py


import re
import time
from typing import List, Optional

import bitsandbytes as bnb
import torch.nn as nn
from hivemind.utils.logging import get_logger
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.tuners import lora
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo

from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import QuantType

logger = get_logger(__name__)


def check_peft_repository(repo_id: str) -> bool:
    fs = HfFileSystem()
    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
    return len(list_of_files) > 0


def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
    """Load only the adapter tensors that belong to transformer block `block_idx` from a safetensors file."""
    tensors = dict()
    is_tensors_found = dict()
    common_layer_pattern_re = (
        r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"\.({block_idx})?\..+"
    )
    with safe_open(filepath, framework=framework, device=device) as f:
        for k in f.keys():
            if re.match(common_layer_pattern_re, k):
                is_tensors_found[block_idx] = True
                tensors[k] = f.get_tensor(k)
        if not is_tensors_found.get(block_idx, False):
            logger.warning(f"There are no PEFT weights for block {block_idx}")
    return tensors
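
# A minimal usage sketch (hedged): the file path and device index below are
# illustrative, not taken from this module; the path is expected to point at a
# downloaded SAFETENSORS_WEIGHTS_NAME file.
#
#   block_tensors = load_specific_module(3, "/path/to/adapter_model.safetensors", device=0)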


def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
    config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
    if config_path is None:
        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
    config = PeftConfig.from_json_file(config_path)

    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
    if weight_path is None:
        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
    if block_idx is None:
        return config, load_file(weight_path)
    return config, load_specific_module(block_idx, weight_path, device=device)


def load_peft(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
):
    """Download and load a PEFT adapter (config + safetensors weights), managing the disk cache and retrying on failure."""
    # TODO: Check whether safetensors loading can be moved into petals/server/from_pretrained.py and reused here
    if not check_peft_repository(repo_id):
        raise ValueError(f"Repo {repo_id} does not contain safetensors weights, which are required for safe loading")

    try:
        with allow_cache_reads(cache_dir):
            return get_adapter_from_repo(
                repo_id,
                block_idx,
                device,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=False,
            )
    except Exception:
        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)

    while True:
        try:
            with allow_cache_writes(cache_dir):
                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
                config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
                weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size

                if config_file_size is not None and weight_file_size is not None:
                    file_size = config_file_size + weight_file_size
                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch file sizes from peft repo {repo_id}")

                return get_adapter_from_repo(
                    repo_id,
                    block_idx,
                    device,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
        except Exception:
            logger.warning(
                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
            )
            time.sleep(delay)
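
# A hedged usage sketch: the repo id and cache path are hypothetical. With
# block_idx set, only that block's tensors are returned; with block_idx=None,
# the full adapter state dict is loaded.
#
#   peft_config, peft_state_dict = load_peft(
#       "some-user/some-lora-adapter",  # hypothetical adapter repo
#       block_idx=3,
#       cache_dir="/tmp/petals_cache",  # hypothetical cache location
#   )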


def create_lora_adapter(block, quant_type: QuantType):
    """Wrap every linear layer of the block with an empty LoRA layer matching the server's quantization type."""
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            lora_wrapped_child = None
            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
                continue
            if quant_type == QuantType.INT8:
                kwargs = {
                    "has_fp16_weights": False,
                    "threshold": 6.0,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = lora.Linear8bitLt(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif quant_type == QuantType.NF4:
                kwargs = {
                    "compress_statistics": True,
                    "quant_type": "nf4",
                    "blocksize": 64,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = lora.Linear4bit(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            else:
                bias = hasattr(child, "bias") and child.bias is not None
                lora_wrapped_child = lora.Linear(
                    child_name,
                    child.in_features,
                    child.out_features,
                    bias=bias,
                )
            if lora_wrapped_child:
                lora_wrapped_child.active_adapter = None
                lora_wrapped_child.weight = child.weight
                lora_wrapped_child.bias = child.bias
                for p in lora_wrapped_child.parameters():
                    p.requires_grad = False
                setattr(module, child_name, lora_wrapped_child)
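
# Hedged sketch: `block` is assumed to be an already-loaded transformer block,
# and QuantType.NONE is assumed to be the non-quantized member of petals'
# QuantType enum. This only wraps the linear layers; no adapter weights are
# attached yet.
#
#   create_lora_adapter(block, quant_type=QuantType.NONE)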


def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
    """Register the adapter in each targeted LoRA layer of the block and copy its A/B weights from the state dict."""
    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
                continue

            if child_name in peft_config["target_modules"] or (
                isinstance(peft_config["target_modules"], str)
                and re.fullmatch(peft_config["target_modules"], child_name)
            ):
                is_lora_a_loaded = False
                is_lora_b_loaded = False
                for peft_key in peft_state_dict:
                    if peft_key.find(child_name) == -1:
                        continue

                    if adapter_name not in child.lora_A:
                        child.update_layer(
                            adapter_name,
                            peft_config["r"],
                            peft_config["lora_alpha"],
                            lora_dropout=peft_config["lora_dropout"],
                            init_lora_weights=peft_config["init_lora_weights"],
                        )
                        child.train(False)
                        if peft_config["lora_dropout"] > 0:
                            logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout")
                        for p in child.parameters():
                            p.requires_grad = False

                    if peft_key.endswith(".lora_A.weight"):
                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
                        is_lora_a_loaded = True
                    elif peft_key.endswith(".lora_A.bias"):
                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
                    elif peft_key.endswith(".lora_B.weight"):
                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
                        is_lora_b_loaded = True
                    elif peft_key.endswith(".lora_B.bias"):
                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")

                if is_lora_a_loaded and is_lora_b_loaded:
logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")
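
# End-to-end sketch of how these helpers are typically combined when serving a
# block with an adapter (the adapter repo id, block object, and cache path are
# hypothetical, not defined in this module):
#
#   peft_config, peft_state_dict = load_peft(
#       "some-user/some-lora-adapter", block_idx=3, cache_dir="/tmp/petals_cache"
#   )
#   create_lora_adapter(block, quant_type=QuantType.NONE)
#   add_adapter_to_block(block, 3, "some-user/some-lora-adapter", peft_config, peft_state_dict)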