import contextlib
import re
import time
from typing import Optional, Sequence, Union

import bitsandbytes as bnb
import torch
import torch.nn as nn
import transformers
from accelerate import init_empty_weights
from hivemind.utils.logging import get_logger
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.config import PeftConfig
from peft.tuners import lora
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo

from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import get_size_in_bytes

logger = get_logger(__name__)


def check_peft_repository(repo_id: str) -> bool:
    """Check that the repo ships its adapter weights in the safetensors format."""
    fs = HfFileSystem()
    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
    return len(list_of_files) > 0


def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
    """Load only the adapter tensors that belong to the block with the given index."""
    tensors = dict()
    is_tensors_found = dict()
    common_layer_pattern_re = (
        r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"\.({block_idx})?\..+"
    )
    with safe_open(filepath, framework=framework, device=device) as f:
        for k in f.keys():
            if re.match(common_layer_pattern_re, k):
                is_tensors_found[block_idx] = True
                tensors[k] = f.get_tensor(k)
        if not is_tensors_found.get(block_idx, False):
            logger.warning(f"There are no PEFT weights for block {block_idx}")
        return tensors


def get_adapter_from_repo(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    token: Optional[Union[str, bool]] = None,
    **kwargs,
):
    config_path = get_file_from_repo(repo_id, CONFIG_NAME, use_auth_token=token, **kwargs)
    if config_path is None:
        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
    config = PeftConfig.from_json_file(config_path)

    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, use_auth_token=token, **kwargs)
    if weight_path is None:
        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
    if block_idx is None:
        return config, load_file(weight_path)
    return config, load_specific_module(block_idx, weight_path, device=device)
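# Illustrative usage (a sketch, not code from this module): download a LoRA adapter's config and the
# weights of a single block. The repo name and cache directory below are hypothetical.
#
#     config, state_dict = get_adapter_from_repo(
#         "my-org/my-lora-adapter", block_idx=3, cache_dir="./cache"
#     )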
def load_peft(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
):
    # TODO: Check whether it is possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here

    if not check_peft_repository(repo_id):
        raise ValueError(f"Repo {repo_id} does not contain safetensors weights required for safe loading")

    try:
        with allow_cache_reads(cache_dir):
            return get_adapter_from_repo(
                repo_id,
                block_idx,
                device,
                revision=revision,
                token=token,
                cache_dir=cache_dir,
                local_files_only=False,
            )
    except Exception:
        logger.warning(f"Cache for PEFT weights {repo_id} is corrupted, they will be downloaded again", exc_info=True)

    while True:
        try:
            with allow_cache_writes(cache_dir):
                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
                config_file_size = get_hf_file_metadata(config_url, token=token).size
                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
                weight_file_size = get_hf_file_metadata(weight_url, token=token).size

                if config_file_size is not None and weight_file_size is not None:
                    file_size = config_file_size + weight_file_size
                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch file sizes from PEFT repo {repo_id}")

                return get_adapter_from_repo(
                    repo_id,
                    block_idx,
                    device,
                    revision=revision,
                    token=token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
        except Exception:
            logger.warning(
                f"Failed to load PEFT weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
            )
            time.sleep(delay)


def get_estimated_peft_module_size(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
):
    weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
    return get_hf_file_metadata(weight_url, token=token).size


class AdapterContextMixin:
    """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""

    ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
    _context_active_adapter = ADAPTER_NOT_SET

    @staticmethod
    @contextlib.contextmanager
    def using_adapter(active_adapter: Optional[str]):
        prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
        try:
            yield
        finally:
            AdapterContextMixin._context_active_adapter = prev

    @property
    def active_adapter(self):
        if self._context_active_adapter == self.ADAPTER_NOT_SET:
            logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debugging")
        return self._context_active_adapter

    @active_adapter.setter
    def active_adapter(self, value: Optional[str]):
        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter"


using_adapter = AdapterContextMixin.using_adapter


class LoraLinear(AdapterContextMixin, lora.Linear):
    """LoRA linear layer that uses an adapter selected via using_adapter"""


class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
    """LoRA linear 8-bit layer with outliers that uses an adapter selected via using_adapter"""


class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
    """LoRA linear 4-bit layer that uses an adapter selected via using_adapter"""


def create_lora_adapter(block, quant_type: QuantType):
    """Replace the block's linear layers with LoRA-wrapped counterparts matching the quantization type."""
    for _, module in block.named_modules():
        for child_name, child in module.named_children():
            lora_wrapped_child = None
            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
                continue
            if quant_type == QuantType.INT8:
                kwargs = {
                    "has_fp16_weights": False,
                    "threshold": 6.0,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = LoraLinear8bitLt(
                    AdapterContextMixin.ADAPTER_NOT_SET,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif quant_type == QuantType.NF4:
                kwargs = {
                    "compress_statistics": True,
                    "quant_type": "nf4",
                    "blocksize": 64,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = LoraLinear4bit(
                    AdapterContextMixin.ADAPTER_NOT_SET,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
                lora_wrapped_child.compute_dtype = child.compute_dtype
            else:
                bias = hasattr(child, "bias") and child.bias is not None
                lora_wrapped_child = LoraLinear(
                    AdapterContextMixin.ADAPTER_NOT_SET,
                    child.in_features,
                    child.out_features,
                    bias=bias,
                )
            if lora_wrapped_child:
                lora_wrapped_child.weight = child.weight
                lora_wrapped_child.bias = child.bias
                for p in lora_wrapped_child.parameters():
                    p.requires_grad = False
                setattr(module, child_name, lora_wrapped_child)
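# Illustrative usage (a sketch of how these helpers fit together, not code from this module): wrap a
# block's linear layers, load an adapter into them with add_adapter_to_block() below, then select that
# adapter for a particular forward pass. `block`, `peft_config`, `peft_state_dict`, `hidden_states`,
# and the adapter name are hypothetical placeholders.
#
#     create_lora_adapter(block, quant_type=QuantType.NONE)
#     add_adapter_to_block(
#         block, block_index=0, adapter_name="my-adapter",
#         peft_config=peft_config, peft_state_dict=peft_state_dict,
#     )
#     with using_adapter("my-adapter"):
#         outputs = block(hidden_states)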
peft_config["lora_dropout"] > 0: logger.info(f"Adapter {adapter_name} has dropout enabled, this server will disable dropout") for _, module in block.named_modules(): for child_name, child in module.named_children(): if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)): continue if child_name in peft_config["target_modules"] or ( isinstance(peft_config["target_modules"], str) and re.fullmatch(peft_config["target_modules"], child_name) ): is_lora_a_loaded = False is_lora_b_loaded = False for peft_key in peft_state_dict: if child_name not in peft_key: continue if adapter_name not in child.lora_A: child.update_layer( adapter_name, peft_config["r"], peft_config["lora_alpha"], lora_dropout=peft_config["lora_dropout"], init_lora_weights=peft_config["init_lora_weights"], ) child.train(False) for p in child.parameters(): p.requires_grad = False if peft_key.endswith(".lora_A.weight"): child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key] is_lora_a_loaded = True elif peft_key.endswith(".lora_A.bias"): raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") elif peft_key.endswith(".lora_B.weight"): child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key] is_lora_b_loaded = True elif peft_key.endswith(".lora_B.bias"): raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") if is_lora_a_loaded and is_lora_b_loaded: logger.debug(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}") elif is_lora_a_loaded or is_lora_b_loaded: raise ValueError(f"Invalid adapter {adapter_name} for block {block_index}.{child_name}") logger.info(f"Loaded adapter {adapter_name} for block {block_index}") def remove_adapter_from_block(block, adapter_name): for _, module in block.named_modules(): for child_name, child in module.named_children(): if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)): continue if adapter_name in child.lora_A: del child.lora_A[adapter_name] if adapter_name in child.lora_B: del child.lora_B[adapter_name] # TODO: check is this needed if torch.cuda.is_available(): torch.cuda.empty_cache() def estimate_adapter_memory_per_block( block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **load_peft_kwargs, ) -> int: """Get the number of extra bytes used to store a set of adapters per given block""" with init_empty_weights(include_buffers=True): block = block_config.block_class(block_config) base_block_parameters = sum(p.numel() for p in block.parameters()) create_lora_adapter(block, quant_type=QuantType.NONE) for adapter in adapters: peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs) assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now" add_adapter_to_block( block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict ) adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype)) return adapter_parameters * bytes_per_parameter