pull/557/merge
commit 94445cd0c7 by Artem Chumachenko, committed via GitHub 2 months ago

@@ -47,7 +47,7 @@ install_requires =
cpufeature>=0.2.0; platform_machine == "x86_64"
packaging>=20.9
sentencepiece>=0.1.99
peft==0.5.0
peft==0.8.2
safetensors>=0.3.1
Dijkstar>=2.6.0
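
Hedged aside, not part of the diff: because the peft pin is exact, a runtime check like the sketch below can surface stale environments early. The version string mirrors the pin above; the helper name is made up.

```python
# Illustrative only: fail fast if the installed peft drifts from the exact pin above.
# `packaging` is already listed in install_requires, so reusing it here is safe.
import peft
from packaging.version import Version

def check_peft_pin(expected: str = "0.8.2") -> None:
    installed = Version(peft.__version__)
    if installed != Version(expected):
        raise RuntimeError(f"petals pins peft=={expected}, but peft {installed} is installed")

check_peft_pin()
```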

@@ -61,7 +61,7 @@ def convert_block(
if adapters:
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
create_lora_adapter(block, quant_type=quant_type)
create_lora_adapter(block)
for adapter_name in adapters:
adapter_config, adapter_state_dict = load_peft(
adapter_name,
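
Reviewer note: the call site now matches the new create_lora_adapter signature, which no longer needs quant_type because the wrapper class is inferred from each child layer's type (see the peft.py hunk below). A minimal sketch of the new call shape, assuming petals and its dependencies import cleanly in this environment; the toy block and the commented-out adapter repo id are illustrative.

```python
# Minimal sketch of the new call shape (no quant_type argument).
import torch.nn as nn
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft

block = nn.ModuleDict({"mlp": nn.Linear(16, 16)})  # stand-in for a transformer block
create_lora_adapter(block)  # wrapper type is now inferred from each child layer's class

# Loading an adapter afterwards keeps the same flow as before; the repo id is a placeholder:
# peft_config, peft_state_dict = load_peft("some-org/some-lora-adapter", block_idx=0)
# add_adapter_to_block(block, 0, "some-org/some-lora-adapter", peft_config, peft_state_dict)
```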

@@ -1,7 +1,7 @@
import contextlib
import re
import time
from typing import Optional, Sequence, Union
from typing import List, Optional, Sequence, Union
import bitsandbytes as bnb
import torch
@@ -12,19 +12,21 @@ from hivemind.utils.logging import get_logger
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.config import PeftConfig
from peft.tuners import lora
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import get_size_in_bytes
logger = get_logger(__name__)
COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]
def check_peft_repository(repo_id: str) -> bool:
return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
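
COMMON_LAYERS_PATTERN is now defined locally because newer peft releases stopped exporting it from peft.utils. A hedged sketch of the kind of layer-index matching this pattern list is conventionally used for; the helper function and the module names below are made up for illustration.

```python
# Illustrative use of COMMON_LAYERS_PATTERN: recover the transformer-block index
# from a dotted module name, similar to what peft did internally.
import re
from typing import Optional

COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]

def find_layer_index(module_name: str) -> Optional[int]:
    """Return the integer that follows a known layer-container name, if any."""
    for pattern in COMMON_LAYERS_PATTERN:
        match = re.search(rf"\b{pattern}\.(\d+)\b", module_name)
        if match:
            return int(match.group(1))
    return None

assert find_layer_index("transformer.h.11.self_attention.query_key_value") == 11
assert find_layer_index("model.layers.3.mlp.down_proj") == 3
assert find_layer_index("lm_head") is None
```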
@@ -151,6 +153,18 @@ class AdapterContextMixin:
def active_adapter(self, value: Optional[str]):
assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""
@property
def active_adapters(self):
return [self._context_active_adapter]
def set_adapter(self, adapter_names) -> None:
"""
In PEFT, this method makes the given adapters trainable. That is not currently possible in the Petals environment,
so this override intentionally does nothing.
Link to peft code: https://github.com/huggingface/peft/blob/98f4db2c7990ef9c879a0e1da9a28a19a04701ef/src/peft/tuners/tuners_utils.py#L463
"""
pass
using_adapter = AdapterContextMixin.using_adapter
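
Context for the no-op set_adapter above: in Petals the active adapter is chosen per call via the using_adapter context manager rather than by mutating layer state. The sketch below is a self-contained illustration of that contract, not the real mixin; the class name and the contextvar storage are illustrative.

```python
# Self-contained sketch of the pattern: a context manager swaps the "active adapter"
# per request, so set_adapter() never needs to mutate layer state.
import contextlib
from contextvars import ContextVar

ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
_context_adapter = ContextVar("active_adapter", default=ADAPTER_NOT_SET)

class AdapterContextDemo:
    @property
    def active_adapter(self):
        return _context_adapter.get()

    @property
    def active_adapters(self):
        return [self.active_adapter]

    def set_adapter(self, adapter_names) -> None:
        pass  # intentionally a no-op, mirroring the diff above

    @staticmethod
    @contextlib.contextmanager
    def using_adapter(adapter_name):
        token = _context_adapter.set(adapter_name)
        try:
            yield
        finally:
            _context_adapter.reset(token)

layer = AdapterContextDemo()
assert layer.active_adapter == ADAPTER_NOT_SET
with AdapterContextDemo.using_adapter("my-lora"):
    assert layer.active_adapters == ["my-lora"]
```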
@@ -158,60 +172,39 @@ using_adapter = AdapterContextMixin.using_adapter
class LoraLinear(AdapterContextMixin, lora.Linear):
"""LoRA linear layer that uses adapter selected via using_adapter"""
def __init__(self, base_layer, adapter_name: str):
nn.Module.__init__(self)
lora.LoraLayer.__init__(self, base_layer)
self._active_adapter = adapter_name
self.is_target_conv_1d_layer = False
class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
class LoraLinear8bitLt(LoraLinear, lora.Linear8bitLt):
"""LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
class LoraLinear4bit(LoraLinear, lora.Linear4bit):
"""LoRA linear 4-bit that uses adapter selected via using_adapter"""
def create_lora_adapter(block, quant_type: QuantType):
for _, module in block.named_modules():
def create_lora_adapter(block):
for module_name, module in block.named_modules():
if isinstance(module, LoraLinear):
continue
for child_name, child in module.named_children():
lora_wrapped_child = None
if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
continue
if quant_type == QuantType.INT8:
kwargs = {
"has_fp16_weights": False,
"threshold": 6.0,
"bias": hasattr(child, "bias") and child.bias is not None,
}
lora_wrapped_child = LoraLinear8bitLt(
AdapterContextMixin.ADAPTER_NOT_SET,
child.in_features,
child.out_features,
**kwargs,
)
elif quant_type == QuantType.NF4:
kwargs = {
"compress_statistics": True,
"quant_type": "nf4",
"blocksize": 64,
"bias": hasattr(child, "bias") and child.bias is not None,
}
lora_wrapped_child = LoraLinear4bit(
AdapterContextMixin.ADAPTER_NOT_SET,
child.in_features,
child.out_features,
**kwargs,
)
lora_wrapped_child.compute_dtype = child.compute_dtype
else:
bias = hasattr(child, "bias") and child.bias is not None
lora_wrapped_child = LoraLinear(
lora_class = None
if isinstance(child, nn.Linear):
lora_class = LoraLinear
elif isinstance(child, bnb.nn.Linear8bitLt):
lora_class = LoraLinear8bitLt
elif isinstance(child, bnb.nn.Linear4bit):
lora_class = LoraLinear4bit
if lora_class:
lora_wrapped_child = lora_class(
child,
AdapterContextMixin.ADAPTER_NOT_SET,
child.in_features,
child.out_features,
bias=bias,
)
if lora_wrapped_child:
lora_wrapped_child.weight = child.weight
lora_wrapped_child.bias = child.bias
for p in lora_wrapped_child.parameters():
p.requires_grad = False
setattr(module, child_name, lora_wrapped_child)
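
The rewritten loop dispatches on each child's concrete class instead of on a separately tracked quant_type, and skips modules that are already LoraLinear so nothing gets wrapped twice. A standalone sketch of that dispatch pattern; the Wrap* classes are placeholders, not the petals ones, and the bitsandbytes entries stay commented out so the snippet runs CPU-only.

```python
# Standalone sketch of the dispatch above: pick the wrapper from the child's type.
import torch.nn as nn

class WrapLinear: ...
class WrapLinear8bitLt: ...
class WrapLinear4bit: ...

WRAPPER_BY_TYPE = {
    nn.Linear: WrapLinear,
    # bnb.nn.Linear8bitLt: WrapLinear8bitLt,  # enabled when bitsandbytes is available
    # bnb.nn.Linear4bit: WrapLinear4bit,
}

def pick_wrapper(child: nn.Module):
    for base_type, wrapper in WRAPPER_BY_TYPE.items():
        if isinstance(child, base_type):
            return wrapper
    return None  # unsupported layer types are simply left unwrapped

assert pick_wrapper(nn.Linear(4, 4)) is WrapLinear
assert pick_wrapper(nn.ReLU()) is None
```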
@@ -240,6 +233,7 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
adapter_name,
peft_config["r"],
peft_config["lora_alpha"],
use_rslora=peft_config.get("use_rslora", False),
lora_dropout=peft_config["lora_dropout"],
init_lora_weights=peft_config["init_lora_weights"],
)
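
use_rslora is forwarded from the adapter config; with rank-stabilized LoRA the update is scaled by lora_alpha / sqrt(r) rather than lora_alpha / r, which matters at higher ranks. A small worked comparison with made-up values:

```python
# Illustrative scaling comparison for use_rslora (values are made up):
import math

r, lora_alpha = 64, 16

classic_scaling = lora_alpha / r            # 0.25 (standard LoRA)
rslora_scaling = lora_alpha / math.sqrt(r)  # 2.0  (rank-stabilized LoRA)

print(f"classic: {classic_scaling}, rsLoRA: {rslora_scaling}")
```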
@@ -275,7 +269,7 @@ def estimate_adapter_memory_per_block(
with init_empty_weights(include_buffers=True):
block = block_config.block_class(block_config)
base_block_parameters = sum(p.numel() for p in block.parameters())
create_lora_adapter(block, quant_type=QuantType.NONE)
create_lora_adapter(block)
for adapter in adapters:
peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)
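
For reference, the init_empty_weights trick used above lets the estimator count parameters without allocating them. A standalone illustration with a toy layer (not a petals block):

```python
# init_empty_weights puts parameters on the "meta" device, so even a large layer costs
# no real memory while numel() still reports its size.
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights(include_buffers=True):
    layer = nn.Linear(8192, 8192)

assert layer.weight.device.type == "meta"
n_params = sum(p.numel() for p in layer.parameters())
print(f"{n_params:,} parameters counted without allocating them")
```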
