@@ -1,9 +1,16 @@
 import os
 import re
 import time
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
 os.environ["BITSANDBYTES_NOWELCOME"] = "1"
 
 import bitsandbytes as bnb
 import peft
+import torch
+import torch.nn as nn
+import transformers
+from accelerate import init_empty_weights
+from hivemind.utils.logging import get_logger
+from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
+from peft.tuners import lora
@@ -12,6 +19,8 @@ from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
 
+from petals.client.ptune import force_non_empty_weights
+from petals.server.block_utils import resolve_block_dtype
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.misc import QuantType
@@ -194,15 +203,35 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                 p.requires_grad = False
 
                 if peft_key.endswith(".lora_A.weight"):
-                    child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
+                    child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]
                     is_lora_a_loaded = True
                 elif peft_key.endswith(".lora_A.bias"):
                     raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
                 elif peft_key.endswith(".lora_B.weight"):
-                    child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
+                    child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]
                     is_lora_b_loaded = True
                 elif peft_key.endswith(".lora_B.bias"):
                     raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
 
             if is_lora_a_loaded and is_lora_b_loaded:
                 logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")
+
+
+def estimate_adapter_memory_per_block(
+    block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **kwargs
+) -> int:
+    """Get the number of extra bytes used to store a set of adapters per given block"""
+    with init_empty_weights(include_buffers=True):
+        block = block_config.block_class(block_config)
+        base_block_parameters = sum(p.numel() for p in block.parameters())
+        create_lora_adapter(block, quant_type=QuantType.NONE)
+
+        for adapter in adapters:
+            peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **kwargs)
+            assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
+            add_adapter_to_block(
+                block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
+            )
+        adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
+    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    return adapter_parameters * bytes_per_parameter
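
For reviewers, a minimal usage sketch of the new helper (not part of the patch; it assumes this diff is petals/utils/peft.py, and the config class, adapter repo name, and kwargs below are illustrative assumptions, not taken from this diff):

    # Illustration only: estimate the extra bytes a set of LoRA adapters adds per block.
    # DistributedBloomConfig is assumed to expose `block_class`, which the helper requires;
    # the adapter repo name and cache_dir are placeholders, forwarded to load_peft() via **kwargs.
    import torch
    from petals.models.bloom import DistributedBloomConfig
    from petals.utils.peft import estimate_adapter_memory_per_block

    config = DistributedBloomConfig.from_pretrained("bigscience/bloom-560m")
    extra_bytes = estimate_adapter_memory_per_block(
        config,
        torch_dtype=torch.float16,            # dtype used by resolve_block_dtype() for the size estimate
        adapters=["some-user/bloom-560m-lora"],  # hypothetical adapter repo
        cache_dir="~/.cache/petals",             # passed through to load_peft()
    )
    print(f"Adapters add ~{extra_bytes / 1024**2:.1f} MiB per block")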