Support peft LoRA adapters (#335)
Implement an option to deploy PEFT adapters to a server. Clients can set active_adapter=... to use these adapters. --------- Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com> Co-authored-by: justheuristic <justheuristic@gmail.com>pull/343/head
parent
dfc6578c8e
commit
b9f0a5467f
@ -0,0 +1,208 @@
|
||||
import re
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch.nn as nn
|
||||
from hivemind.utils.logging import get_logger
|
||||
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
|
||||
from peft.tuners import lora
|
||||
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
|
||||
from safetensors import safe_open
|
||||
from safetensors.torch import load_file
|
||||
from transformers.utils import get_file_from_repo
|
||||
|
||||
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
|
||||
from petals.utils.misc import QuantType
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def check_peft_repository(repo_id: str) -> bool:
    """Return True if the Hub repo ships safetensors weights (i.e. is safe to load)."""
    # HfFileSystem.glob hits the Hub; an empty match list means no safetensors file
    matches = HfFileSystem().glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
    return bool(matches)
|
||||
|
||||
|
||||
def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
    """Load only the tensors belonging to one transformer block from a safetensors file.

    Args:
        block_idx: index of the block whose adapter tensors should be extracted.
        filepath: path to the safetensors file with the full adapter state dict.
        framework: safetensors framework key (default "pt" for PyTorch tensors).
        device: device to load the tensors onto (safetensors device spec).

    Returns:
        Dict mapping state-dict keys to tensors for the requested block.
        Logs a warning (and returns an empty/partial dict) if nothing matched.
    """
    # Fix: use raw strings for the regex fragments; "\." in a plain string is an
    # invalid escape sequence (SyntaxWarning on modern Python). Also renamed the
    # misspelled local `common_layer_patter_re`.
    common_layer_pattern_re = (
        r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"\.({block_idx})?\..+"
    )
    tensors = {}
    found_for_block = False  # a plain bool replaces the original single-key dict
    with safe_open(filepath, framework=framework, device=device) as f:
        for k in f.keys():
            if re.match(common_layer_pattern_re, k):
                found_for_block = True
                tensors[k] = f.get_tensor(k)
        if not found_for_block:
            logger.warning(f"There is no peft weights for block {block_idx}")
        return tensors
|
||||
|
||||
|
||||
def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
    """Fetch a PEFT adapter (config + safetensors weights) from a Hub repo.

    If ``block_idx`` is given, only that block's tensors are loaded; otherwise
    the full state dict is returned. Extra ``kwargs`` are forwarded to
    ``get_file_from_repo`` (revision, cache_dir, auth token, ...).

    Raises:
        RuntimeError: if the config or the safetensors file is missing in the repo.
    """

    def fetch_required(filename: str) -> str:
        # get_file_from_repo returns None when the file is absent in the repo
        path = get_file_from_repo(repo_id, filename, **kwargs)
        if path is None:
            raise RuntimeError(f"File {filename} does not exist in repo {repo_id}")
        return path

    config = PeftConfig.from_json_file(fetch_required(CONFIG_NAME))
    weight_path = fetch_required(SAFETENSORS_WEIGHTS_NAME)
    if block_idx is None:
        return config, load_file(weight_path)
    return config, load_specific_module(block_idx, weight_path, device=device)
|
||||
|
||||
|
||||
def load_peft(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
):
    """Load a PEFT adapter from the HF Hub with local caching, retrying forever on Hub errors.

    Args:
        repo_id: Hub repo with the adapter; must contain safetensors weights.
        block_idx: if given, load only that block's tensors (see ``load_specific_module``).
        device: device to load the tensors onto.
        revision: optional Hub revision.
        use_auth_token: optional Hub auth token.
        cache_dir: local cache directory (required, keyword-only).
        max_disk_space: cache size budget passed to ``free_disk_space_for``.
        delay: seconds to wait between download retries.

    Returns:
        ``(config, state_dict)`` as produced by ``get_adapter_from_repo``.

    Raises:
        ValueError: if the repo has no safetensors weights (unsafe to load).
    """
    # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here

    if not check_peft_repository(repo_id):
        raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")

    # Fast path: serve from the local cache if it is intact.
    try:
        with allow_cache_reads(cache_dir):
            return get_adapter_from_repo(
                repo_id,
                block_idx,
                device,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=False,
            )
    except Exception:
        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)

    while True:
        try:
            with allow_cache_writes(cache_dir):
                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
                config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
                weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size

                # Bug fix: either .size may be None; the original summed first, so the
                # addition raised TypeError and the fallback warning was unreachable
                # (the error was silently absorbed by the retry loop below).
                if config_file_size is not None and weight_file_size is not None:
                    file_size = config_file_size + weight_file_size
                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch size from peft repo {repo_id}")

                return get_adapter_from_repo(
                    repo_id,
                    block_idx,
                    device,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
        except Exception:
            logger.warning(
                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
            )
            time.sleep(delay)
|
||||
|
||||
|
||||
def create_lora_adapter(block, quant_type: QuantType):
    """Replace every linear layer in `block` with an inactive LoRA-capable wrapper.

    The wrapper class is chosen to match the server's quantization mode
    (int8 / nf4 / plain). Each wrapper starts with ``active_adapter = None``
    and shares the original layer's weight/bias tensors, so the block's
    forward pass is unchanged until an adapter is installed via
    ``add_adapter_to_block``.
    """
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            lora_wrapped_child = None
            # Only (possibly quantized) linear layers can host LoRA adapters.
            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
                continue
            if quant_type == QuantType.INT8:
                kwargs = {
                    "has_fp16_weights": False,
                    "threshold": 6.0,  # LLM.int8() outlier threshold used by bitsandbytes
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = lora.Linear8bitLt(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif quant_type == QuantType.NF4:
                kwargs = {
                    "compress_statistics": True,
                    "quant_type": "nf4",
                    "blocksize": 64,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = lora.Linear4bit(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            else:
                bias = hasattr(child, "bias") and child.bias is not None
                lora_wrapped_child = lora.Linear(
                    child_name,
                    child.in_features,
                    child.out_features,
                    bias=bias,
                )
            if lora_wrapped_child:
                # No adapter is active yet; reuse (not copy) the original parameters
                # and freeze everything — servers never train the base weights.
                lora_wrapped_child.active_adapter = None
                lora_wrapped_child.weight = child.weight
                lora_wrapped_child.bias = child.bias
                for p in lora_wrapped_child.parameters():
                    p.requires_grad = False
                # Mutates the module we are iterating over; named_children() was
                # materialized per-module above, so the swap is applied in place.
                setattr(module, child_name, lora_wrapped_child)
|
||||
|
||||
|
||||
def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
    """Install LoRA A/B weights from `peft_state_dict` into the wrappers of `block`.

    Expects the block to have been prepared by ``create_lora_adapter``; only
    layers named in ``peft_config["target_modules"]`` (a list of names or a
    regex string) receive weights. Raises NotImplementedError for adapters
    that carry lora_A/lora_B biases.
    """
    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            # Only LoRA wrapper layers (created earlier) can receive adapter weights.
            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
                continue

            # target_modules is either a collection of layer names or a single regex.
            if child_name in peft_config["target_modules"] or (
                isinstance(peft_config["target_modules"], str)
                and re.fullmatch(peft_config["target_modules"], child_name)
            ):
                is_lora_a_loaded = False
                is_lora_b_loaded = False
                for peft_key in peft_state_dict:
                    # Skip state-dict entries that belong to other layers.
                    if peft_key.find(child_name) == -1:
                        continue

                    # Lazily register the adapter on first matching key only.
                    if adapter_name not in child.lora_A:
                        child.update_layer(
                            adapter_name,
                            peft_config["r"],
                            peft_config["lora_alpha"],
                            lora_dropout=peft_config["lora_dropout"],
                            init_lora_weights=peft_config["init_lora_weights"],
                        )
                        # Inference-only server: eval mode, no gradients.
                        child.train(False)
                        if peft_config["lora_dropout"] > 0:
                            logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout")
                        for p in child.parameters():
                            p.requires_grad = False

                    if peft_key.endswith(".lora_A.weight"):
                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
                        is_lora_a_loaded = True
                    elif peft_key.endswith(".lora_A.bias"):
                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
                    elif peft_key.endswith(".lora_B.weight"):
                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
                        is_lora_b_loaded = True
                    elif peft_key.endswith(".lora_B.bias"):
                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")

                # Only report success when both halves of the low-rank pair arrived.
                if is_lora_a_loaded and is_lora_b_loaded:
                    logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")
|
@ -0,0 +1,66 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from petals.utils.peft import check_peft_repository, load_peft
|
||||
|
||||
# Hub fixtures: the "unsafe" repo has no safetensors weights, the "safe" one does
# (see check_peft_repository, which keys off SAFETENSORS_WEIGHTS_NAME presence).
UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
# NOTE(review): appears unused in this module — tests use the tmpdir fixture instead.
TMP_CACHE_DIR = "tmp_cache/"
|
||||
|
||||
|
||||
def clear_dir(path_to_dir):
    """Empty a directory by removing it (with all contents) and recreating it."""
    shutil.rmtree(path_to_dir)  # delete the tree, contents and all
    os.mkdir(path_to_dir)  # recreate as a fresh, empty directory
|
||||
|
||||
|
||||
def dir_empty(path_to_dir):
    """Return True when the directory contains no entries at all."""
    # An empty listing is falsy, so no explicit length comparison is needed.
    return not os.listdir(path_to_dir)
|
||||
|
||||
|
||||
@pytest.mark.forked
def test_check_peft():
    """Repos without safetensors must be rejected; repos with them accepted."""
    # Fixed assertion message: the constant is UNSAFE_PEFT_REPO, not "NOSAFE_PEFT_REPO".
    assert not check_peft_repository(UNSAFE_PEFT_REPO), "UNSAFE_PEFT_REPO is safe to load."
    assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
|
||||
|
||||
|
||||
@pytest.mark.forked
def test_load_noncached(tmpdir):
    # Start from an empty cache so we can observe exactly what load_peft downloads.
    clear_dir(tmpdir)
    # A repo without safetensors must be rejected (load_peft raises ValueError
    # before downloading anything — any Exception subclass is acceptable here).
    with pytest.raises(Exception):
        load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)

    assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"

    # A safetensors repo should download and populate the cache.
    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)

    assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
|
||||
|
||||
|
||||
@pytest.mark.forked
def test_load_cached(tmpdir):
    # Pre-populate the cache, then check that load_peft can serve from it
    # (exercises the allow_cache_reads fast path).
    clear_dir(tmpdir)
    snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)

    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
|
||||
|
||||
|
||||
@pytest.mark.forked
def test_load_layer_exists(tmpdir):
    clear_dir(tmpdir)

    # Per-block loading for a block index that exists in the adapter should succeed.
    load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)
|
||||
|
||||
|
||||
@pytest.mark.forked
def test_load_layer_nonexists(tmpdir):
    clear_dir(tmpdir)

    # An out-of-range block index must not raise: load_specific_module only logs
    # a warning and returns whatever (possibly nothing) matched.
    load_peft(
        SAFE_PEFT_REPO,
        block_idx=1337,
        cache_dir=tmpdir,
    )
|
Loading…
Reference in New Issue