From f1e1b051d0407d65b7e9e5608e5f2cde93f7cb02 Mon Sep 17 00:00:00 2001
From: Artem Chumachenko
Date: Mon, 22 Jul 2024 22:54:43 +0200
Subject: [PATCH] Update peft dependency, fix initialization and inference with
 new peft (#557)

* Make fixes

* lib number

* Fix inference without adapter

* Fix trainability

* Fix versions

* style

* Update comments

Co-authored-by: Max Ryabinin

* Remove unnecessary todo

---------

Co-authored-by: Max Ryabinin
Co-authored-by: justheuristic
---
 setup.cfg                         |  2 +-
 src/petals/utils/convert_block.py |  2 +-
 src/petals/utils/peft.py          | 89 +++++++++++++++----------------
 3 files changed, 44 insertions(+), 49 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index f4f7536..90f601f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -47,7 +47,7 @@ install_requires =
     cpufeature>=0.2.0; platform_machine == "x86_64"
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft==0.5.0
+    peft==0.8.2
     safetensors>=0.3.1
     Dijkstar>=2.6.0
     numpy<2
diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
index 94d3e29..9dde414 100644
--- a/src/petals/utils/convert_block.py
+++ b/src/petals/utils/convert_block.py
@@ -61,7 +61,7 @@ def convert_block(
     if adapters:
         from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
-        create_lora_adapter(block, quant_type=quant_type)
+        create_lora_adapter(block)
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(
                 adapter_name,
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index 149fda4..5d93ce6 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -1,7 +1,7 @@
 import contextlib
 import re
 import time
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import bitsandbytes as bnb
 import torch
@@ -12,7 +12,7 @@ from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
 from peft.config import PeftConfig
 from peft.tuners import lora
-from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
+from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
 from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
@@ -25,6 +25,9 @@ from petals.utils.misc import get_size_in_bytes
 
 logger = get_logger(__name__)
 
+COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]
+
+
 def check_peft_repository(repo_id: str) -> bool:
     return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
 
@@ -151,6 +154,18 @@ class AdapterContextMixin:
     def active_adapter(self, value: Optional[str]):
         assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""
 
+    @property
+    def active_adapters(self):
+        return [self._context_active_adapter]
+
+    def set_adapter(self, adapter_names) -> None:
+        """
+        In PEFT, this method makes the given adapters trainable. However, this is not currently possible in the
+        Petals environment, so this override disables that functionality.
+        Link to peft code: https://github.com/huggingface/peft/blob/98f4db2c7990ef9c879a0e1da9a28a19a04701ef/src/peft/tuners/tuners_utils.py#L463
+        """
+        pass
+
 
 using_adapter = AdapterContextMixin.using_adapter
 
@@ -158,60 +173,39 @@ using_adapter = AdapterContextMixin.using_adapter
 class LoraLinear(AdapterContextMixin, lora.Linear):
     """LoRA linear layer that uses adapter selected via using_adapter"""
 
+    def __init__(self, base_layer, adapter_name: str):
+        nn.Module.__init__(self)
+        lora.LoraLayer.__init__(self, base_layer)
+
+        self._active_adapter = adapter_name
+        self.is_target_conv_1d_layer = False
+
 
-class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
+class LoraLinear8bitLt(LoraLinear, lora.Linear8bitLt):
     """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
 
 
-class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
+class LoraLinear4bit(LoraLinear, lora.Linear4bit):
     """LoRA linear 4-bit that uses adapter selected via using_adapter"""
 
 
-def create_lora_adapter(block, quant_type: QuantType):
-    for _, module in block.named_modules():
+def create_lora_adapter(block):
+    for module_name, module in block.named_modules():
+        if isinstance(module, LoraLinear):
+            continue
         for child_name, child in module.named_children():
-            lora_wrapped_child = None
-            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
-                continue
-            if quant_type == QuantType.INT8:
-                kwargs = {
-                    "has_fp16_weights": False,
-                    "threshold": 6.0,
-                    "bias": hasattr(child, "bias") and child.bias is not None,
-                }
-                lora_wrapped_child = LoraLinear8bitLt(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    **kwargs,
-                )
-            elif quant_type == QuantType.NF4:
-                kwargs = {
-                    "compress_statistics": True,
-                    "quant_type": "nf4",
-                    "blocksize": 64,
-                    "bias": hasattr(child, "bias") and child.bias is not None,
-                }
-                lora_wrapped_child = LoraLinear4bit(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    **kwargs,
-                )
-                lora_wrapped_child.compute_dtype = child.compute_dtype
-            else:
-                bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = LoraLinear(
+            lora_class = None
+            if isinstance(child, nn.Linear):
+                lora_class = LoraLinear
+            elif isinstance(child, bnb.nn.Linear8bitLt):
+                lora_class = LoraLinear8bitLt
+            elif isinstance(child, bnb.nn.Linear4bit):
+                lora_class = LoraLinear4bit
+            if lora_class:
+                lora_wrapped_child = lora_class(
+                    child,
                     AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    bias=bias,
                 )
-            if lora_wrapped_child:
-                lora_wrapped_child.weight = child.weight
-                lora_wrapped_child.bias = child.bias
-                for p in lora_wrapped_child.parameters():
-                    p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)
 
 
@@ -240,6 +234,7 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                     adapter_name,
                     peft_config["r"],
                     peft_config["lora_alpha"],
+                    use_rslora=peft_config.get("use_rslora", False),
                     lora_dropout=peft_config["lora_dropout"],
                     init_lora_weights=peft_config["init_lora_weights"],
                 )
@@ -275,7 +270,7 @@ def estimate_adapter_memory_per_block(
     with init_empty_weights(include_buffers=True):
         block = get_model_block(block_config)
         base_block_parameters = sum(p.numel() for p in block.parameters())
-        create_lora_adapter(block, quant_type=QuantType.NONE)
+        create_lora_adapter(block)
 
         for adapter in adapters:
             peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
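For reference, a minimal sketch of how a caller is expected to wire the refactored helpers together after this change, mirroring the convert_block.py hunk above. The variables block, block_index, adapters, load_peft_kwargs, and hidden_states are assumed to come from the caller, and the add_adapter_to_block argument list is inferred from the truncated hunk header, so treat this as illustrative rather than exact:

from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft, using_adapter

# Wrap each nn.Linear / bnb.nn.Linear8bitLt / bnb.nn.Linear4bit child of `block` with the
# matching LoRA class; the old quant_type argument is gone because the wrapper is now
# chosen from the type of the already-converted base layer.
create_lora_adapter(block)

# Fetch and attach each adapter's config and weights for this block.
for adapter_name in adapters:
    peft_config, peft_state_dict = load_peft(adapter_name, block_idx=block_index, **load_peft_kwargs)
    add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict)

# At inference time an adapter is selected with the context manager; set_adapter() is a
# deliberate no-op in Petals, so switching adapters goes through using_adapter instead.
with using_adapter(adapter_name):
    outputs = block(hidden_states)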