@@ -1,7 +1,7 @@
 import contextlib
 import re
 import time
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import bitsandbytes as bnb
 import torch
@@ -12,7 +12,7 @@ from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
 from peft.config import PeftConfig
 from peft.tuners import lora
-from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
+from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
 from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
@@ -25,6 +25,9 @@ from petals.utils.misc import get_size_in_bytes
 logger = get_logger(__name__)
 
 
+COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]
+
+
 def check_peft_repository(repo_id: str) -> bool:
     return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
 
@@ -151,6 +154,18 @@ class AdapterContextMixin:
     def active_adapter(self, value: Optional[str]):
         assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""
 
+    @property
+    def active_adapters(self):
+        return [self._context_active_adapter]
+
+    def set_adapter(self, adapter_names) -> None:
+        """
+        In PEFT, this method makes the given adapter trainable. That is currently not possible in the
+        Petals environment, so this override removes the functionality.
+        Link to the peft code: https://github.com/huggingface/peft/blob/98f4db2c7990ef9c879a0e1da9a28a19a04701ef/src/peft/tuners/tuners_utils.py#L463
+        """
+        pass
+
 
 using_adapter = AdapterContextMixin.using_adapter
 
@@ -158,60 +173,39 @@ using_adapter = AdapterContextMixin.using_adapter
 class LoraLinear(AdapterContextMixin, lora.Linear):
     """LoRA linear layer that uses adapter selected via using_adapter"""
 
+    def __init__(self, base_layer, adapter_name: str):
+        nn.Module.__init__(self)
+        lora.LoraLayer.__init__(self, base_layer)
+
+        self._active_adapter = adapter_name
+        self.is_target_conv_1d_layer = False
+
 
-class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
+class LoraLinear8bitLt(LoraLinear, lora.Linear8bitLt):
     """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
 
 
-class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
+class LoraLinear4bit(LoraLinear, lora.Linear4bit):
     """LoRA linear 4-bit that uses adapter selected via using_adapter"""
 
 
-def create_lora_adapter(block, quant_type: QuantType):
-    for _, module in block.named_modules():
+def create_lora_adapter(block):
+    for module_name, module in block.named_modules():
+        if isinstance(module, LoraLinear):
+            continue
         for child_name, child in module.named_children():
-            lora_wrapped_child = None
-            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
-                continue
-            if quant_type == QuantType.INT8:
-                kwargs = {
-                    "has_fp16_weights": False,
-                    "threshold": 6.0,
-                    "bias": hasattr(child, "bias") and child.bias is not None,
-                }
-                lora_wrapped_child = LoraLinear8bitLt(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    **kwargs,
-                )
-            elif quant_type == QuantType.NF4:
-                kwargs = {
-                    "compress_statistics": True,
-                    "quant_type": "nf4",
-                    "blocksize": 64,
-                    "bias": hasattr(child, "bias") and child.bias is not None,
-                }
-                lora_wrapped_child = LoraLinear4bit(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    **kwargs,
-                )
-                lora_wrapped_child.compute_dtype = child.compute_dtype
-            else:
-                bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = LoraLinear(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    bias=bias,
-                )
-            if lora_wrapped_child:
-                lora_wrapped_child.weight = child.weight
-                lora_wrapped_child.bias = child.bias
+            lora_class = None
+            if isinstance(child, nn.Linear):
+                lora_class = LoraLinear
+            elif isinstance(child, bnb.nn.Linear8bitLt):
+                lora_class = LoraLinear8bitLt
+            elif isinstance(child, bnb.nn.Linear4bit):
+                lora_class = LoraLinear4bit
+            if lora_class:
+                lora_wrapped_child = lora_class(
+                    child,
+                    AdapterContextMixin.ADAPTER_NOT_SET,
+                )
                 for p in lora_wrapped_child.parameters():
                     p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)
@@ -240,6 +234,7 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                     adapter_name,
                     peft_config["r"],
                     peft_config["lora_alpha"],
+                    use_rslora=peft_config.get("use_rslora", False),
                     lora_dropout=peft_config["lora_dropout"],
                     init_lora_weights=peft_config["init_lora_weights"],
                 )
@@ -275,7 +270,7 @@ def estimate_adapter_memory_per_block(
     with init_empty_weights(include_buffers=True):
         block = get_model_block(block_config)
         base_block_parameters = sum(p.numel() for p in block.parameters())
-        create_lora_adapter(block, quant_type=QuantType.NONE)
+        create_lora_adapter(block)
 
         for adapter in adapters:
            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
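
A minimal usage sketch of the reworked API above, not taken from this diff: a block is wrapped once with create_lora_adapter, an adapter's weights are loaded with add_adapter_to_block, and the adapter is applied only inside the using_adapter context. The variables block, peft_config, and peft_state_dict are assumed to come from get_model_block(...) and load_peft(...) as elsewhere in this file; the adapter name "my-adapter" and the input tensor are illustrative only.

    # Replace each nn.Linear / Linear8bitLt / Linear4bit child with the matching LoraLinear* wrapper
    create_lora_adapter(block)

    # Load one adapter's LoRA weights into the wrapped layers
    # (positional signature from add_adapter_to_block above; block index 0 is arbitrary here)
    add_adapter_to_block(block, 0, "my-adapter", peft_config, peft_state_dict)

    # The adapter is active only within this context; outside it, the block runs without LoRA
    with using_adapter("my-adapter"):
        outputs = block(hidden_states)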