# petals/src/petals/utils/peft.py

import contextlib
import re
import time
from typing import Optional, Sequence, Union

import bitsandbytes as bnb
import torch
import torch.nn as nn
import transformers
from accelerate import init_empty_weights
from hivemind.utils.logging import get_logger
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.config import PeftConfig
from peft.tuners import lora
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo

from petals.server.block_utils import get_model_block, resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import get_size_in_bytes

logger = get_logger(__name__)


def check_peft_repository(repo_id: str) -> bool:
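    """Return True if the repo stores its adapter weights in safetensors format (required for safe loading)."""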
    return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")


def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
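    """Load only those tensors of a safetensors file that belong to transformer block `block_idx`."""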
    tensors = dict()
    is_tensors_found = dict()
    # Match parameter names that contain one of peft's common per-layer container names
    # (e.g. "layers" or "h") followed by this block's index
    common_layer_pattern_re = (
        r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"\.({block_idx})?\..+"
    )
    with safe_open(filepath, framework=framework, device=device) as f:
        for k in f.keys():
            if re.match(common_layer_pattern_re, k):
                is_tensors_found[block_idx] = True
                tensors[k] = f.get_tensor(k)
        if not is_tensors_found.get(block_idx, False):
            logger.warning(f"There are no PEFT weights for block {block_idx}")
        return tensors


def get_adapter_from_repo(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    token: Optional[Union[str, bool]] = None,
    **kwargs,
):
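    """Fetch a PEFT adapter's config and safetensors weights from `repo_id`; load a single block if `block_idx` is given."""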
    config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
    if config_path is None:
        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
    config = PeftConfig.from_json_file(config_path)

    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs)
    if weight_path is None:
        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
    if block_idx is None:
        return config, load_file(weight_path)
    return config, load_specific_module(block_idx, weight_path, device=device)


def load_peft(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
):
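    """Load a PEFT adapter from the local cache if possible; otherwise download it, retrying every `delay` seconds."""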
    # TODO: Check if it is possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
    if not check_peft_repository(repo_id):
        raise ValueError(f"Repo {repo_id} does not contain safetensors weights, which are required for safe loading")

    try:
        with allow_cache_reads(cache_dir):
            return get_adapter_from_repo(
                repo_id,
                block_idx,
                device,
                revision=revision,
                token=token,
                cache_dir=cache_dir,
                local_files_only=False,
            )
    except Exception:
        logger.warning(f"Cache for PEFT weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)

    while True:
        try:
            with allow_cache_writes(cache_dir):
                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
                config_file_size = get_hf_file_metadata(config_url, token=token).size
                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
                weight_file_size = get_hf_file_metadata(weight_url, token=token).size

                # `size` may be None if the Hub does not report it; only free disk space when both sizes are known
                if config_file_size is not None and weight_file_size is not None:
                    file_size = config_file_size + weight_file_size
                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch file sizes from PEFT repo {repo_id}")

                return get_adapter_from_repo(
                    repo_id,
                    block_idx,
                    device,
                    revision=revision,
                    token=token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
        except Exception:
            logger.warning(
                f"Failed to load PEFT weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
            )
            time.sleep(delay)


class AdapterContextMixin:
    """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""

    ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
    _context_active_adapter = ADAPTER_NOT_SET
    @staticmethod
    @contextlib.contextmanager
    def using_adapter(active_adapter: Optional[str]):
        prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
        try:
            yield
        finally:
            AdapterContextMixin._context_active_adapter = prev

    @property
    def active_adapter(self):
        if self._context_active_adapter == self.ADAPTER_NOT_SET:
            logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug")
        return self._context_active_adapter

    @active_adapter.setter
    def active_adapter(self, value: Optional[str]):
        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter"


using_adapter = AdapterContextMixin.using_adapter
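
# Typical usage (a minimal sketch): wrap a forward pass so that every LoRA layer in a block
# resolves `active_adapter` to the same adapter name for the duration of the call:
#
#     with using_adapter("adapter-name"):
#         outputs = block(hidden_states)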


class LoraLinear(AdapterContextMixin, lora.Linear):
    """LoRA linear layer that uses an adapter selected via using_adapter"""


class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
    """LoRA linear 8-bit with outliers that uses an adapter selected via using_adapter"""


class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
    """LoRA linear 4-bit that uses an adapter selected via using_adapter"""


def create_lora_adapter(block, quant_type: QuantType):
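    """Replace every supported linear layer in `block` with a frozen LoRA wrapper matching `quant_type`."""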
    for _, module in block.named_modules():
        for child_name, child in module.named_children():
            lora_wrapped_child = None
            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
                continue
            if quant_type == QuantType.INT8:
                kwargs = {
                    "has_fp16_weights": False,
                    "threshold": 6.0,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = LoraLinear8bitLt(
                    AdapterContextMixin.ADAPTER_NOT_SET,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif quant_type == QuantType.NF4:
                kwargs = {
                    "compress_statistics": True,
                    "quant_type": "nf4",
                    "blocksize": 64,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = LoraLinear4bit(
                    AdapterContextMixin.ADAPTER_NOT_SET,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
                lora_wrapped_child.compute_dtype = child.compute_dtype
            else:
                bias = hasattr(child, "bias") and child.bias is not None
                lora_wrapped_child = LoraLinear(
                    AdapterContextMixin.ADAPTER_NOT_SET,
                    child.in_features,
                    child.out_features,
                    bias=bias,
                )
            if lora_wrapped_child:
                # Share the base layer's parameters with the wrapper (no copy), then freeze everything
                lora_wrapped_child.weight = child.weight
                lora_wrapped_child.bias = child.bias
                for p in lora_wrapped_child.parameters():
                    p.requires_grad = False
                setattr(module, child_name, lora_wrapped_child)


def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
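    """Load one LoRA adapter's weights into the LoRA-wrapped linear layers of `block`."""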
assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
if peft_config["lora_dropout"] > 0:
logger.info(f"Adapter {adapter_name} has dropout enabled, this server will disable dropout")
for _, module in block.named_modules():
for child_name, child in module.named_children():
if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
continue
if child_name in peft_config["target_modules"] or (
isinstance(peft_config["target_modules"], str)
and re.fullmatch(peft_config["target_modules"], child_name)
):
is_lora_a_loaded = False
is_lora_b_loaded = False
for peft_key in peft_state_dict:
if child_name not in peft_key:
continue
if adapter_name not in child.lora_A:
child.update_layer(
adapter_name,
peft_config["r"],
peft_config["lora_alpha"],
lora_dropout=peft_config["lora_dropout"],
init_lora_weights=peft_config["init_lora_weights"],
)
child.train(False)
for p in child.parameters():
p.requires_grad = False
if peft_key.endswith(".lora_A.weight"):
child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]
is_lora_a_loaded = True
elif peft_key.endswith(".lora_A.bias"):
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
elif peft_key.endswith(".lora_B.weight"):
child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]
is_lora_b_loaded = True
elif peft_key.endswith(".lora_B.bias"):
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
if is_lora_a_loaded and is_lora_b_loaded:
logger.debug(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}")
elif is_lora_a_loaded or is_lora_b_loaded:
raise ValueError(f"Invalid adapter {adapter_name} for block {block_index}.{child_name}")
logger.info(f"Loaded adapter {adapter_name} for block {block_index}")


def estimate_adapter_memory_per_block(
    block_config: transformers.PretrainedConfig,
    torch_dtype: Optional[torch.dtype],
    adapters: Sequence[str],
    **load_peft_kwargs,
) -> int:
    """Get the number of extra bytes used to store a set of adapters per given block"""
    with init_empty_weights(include_buffers=True):
        block = get_model_block(block_config)
        base_block_parameters = sum(p.numel() for p in block.parameters())
        create_lora_adapter(block, quant_type=QuantType.NONE)

        for adapter in adapters:
            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
            assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
            add_adapter_to_block(
                block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
            )

        adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters

    bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
    return adapter_parameters * bytes_per_parameter
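
# Example (a minimal sketch; the model and adapter repo names below are placeholders):
#
#     config = transformers.AutoConfig.from_pretrained("model/repo")
#     extra_bytes = estimate_adapter_memory_per_block(
#         config, torch.float16, ["user/lora-adapter"], cache_dir="/path/to/cache"
#     )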