Estimate adapter memory overhead in choose_num_blocks() (#346)

* estimate adapter memory overhead
* reduce the number of served blocks based on that

---------

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic committed via GitHub, 10 months ago
commit 010857a834 (parent f605f093f7)
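In short, the server used to size its share of the model as floor((total_memory - autograd_memory) / (block_size + cache_bytes_per_block)); this commit adds the per-block footprint of any requested LoRA adapters to the denominator, so servers that load adapters claim fewer blocks instead of over-committing GPU memory. A minimal sketch of the updated estimate (names are illustrative, not the PR's code):

```python
import math

def estimate_num_blocks(total_memory: int, autograd_memory: int, block_size: int,
                        cache_bytes_per_block: int, adapter_memory_per_block: int = 0) -> int:
    # All sizes in bytes; adapter_memory_per_block is the new term added by this commit.
    total_memory_per_block = block_size + adapter_memory_per_block + cache_bytes_per_block
    return math.floor((total_memory - autograd_memory) / total_memory_per_block)
```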

@@ -30,6 +30,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.utils.peft import estimate_adapter_memory_per_block
from petals.utils.version import get_compatible_model_repo
logger = get_logger(__name__)
@@ -176,6 +177,8 @@ class Server:
cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
self.cache_dir = cache_dir
self.adapters = adapters
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
if num_blocks is None and block_indices is None:
@@ -197,7 +200,6 @@ class Server:
self.alloc_timeout = alloc_timeout
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
self.cache_dir = cache_dir
self.max_disk_space = max_disk_space
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
@@ -219,8 +221,6 @@ class Server:
self.mean_balance_check_period = mean_balance_check_period
self.mean_block_selection_delay = mean_block_selection_delay
self.adapters = adapters
self.stop = threading.Event()
def _choose_num_blocks(self) -> int:
@@ -250,7 +250,16 @@ class Server:
# Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
if adapters:
# Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
from petals.utils.peft import estimate_adapter_memory_per_block
adapter_memory_per_block = estimate_adapter_memory_per_block(
self.block_config, self.torch_dtype, self.adapters, self.cache_dir
)
total_memory_per_block = block_size + adapter_memory_per_block + self._cache_bytes_per_block
num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
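Plugging invented numbers into that formula shows the effect of the new term: even ~10 MiB of adapter weights per block can cost a block on a tight budget (all figures below are made up for illustration):

```python
import math

gib = 1024 ** 3
total_memory = 24 * gib                     # free GPU memory (hypothetical)
autograd_memory = 2 * gib                   # reserved for rpc_backward
block_size = int(0.5 * gib)                 # weights of one transformer block
cache_bytes_per_block = int(0.05 * gib)     # attention cache budget per block
adapter_memory_per_block = int(0.01 * gib)  # LoRA weights per block (~10 MiB)

budget = total_memory - autograd_memory
print(math.floor(budget / (block_size + cache_bytes_per_block)))                             # -> 40
print(math.floor(budget / (block_size + adapter_memory_per_block + cache_bytes_per_block)))  # -> 39
```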

@@ -55,8 +55,6 @@ def convert_block(
shard.to(device)
if adapters:
# Import petals.utils.peft only when necessary to avoid importing bitsandbytes
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
create_lora_adapter(block, quant_type=quant_type)

@@ -1,9 +1,16 @@
import os
import re
import time
from typing import List, Optional
from typing import List, Optional, Sequence
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
import bitsandbytes as bnb
import peft
import torch
import torch.nn as nn
import transformers
from accelerate import init_empty_weights
from hivemind.utils.logging import get_logger
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.tuners import lora
@@ -12,6 +19,8 @@ from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo
from petals.client.ptune import force_non_empty_weights
from petals.server.block_utils import resolve_block_dtype
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import QuantType
@@ -194,15 +203,35 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state
p.requires_grad = False
if peft_key.endswith(".lora_A.weight"):
child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
child.lora_A[adapter_name].weight[...] = peft_state_dict[peft_key]
is_lora_a_loaded = True
elif peft_key.endswith(".lora_A.bias"):
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
elif peft_key.endswith(".lora_B.weight"):
child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
child.lora_B[adapter_name].weight[...] = peft_state_dict[peft_key]
is_lora_b_loaded = True
elif peft_key.endswith(".lora_B.bias"):
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
if is_lora_a_loaded and is_lora_b_loaded:
logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")
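The switch from assigning to .weight.data to writing through .weight[...] changes how the checkpoint tensor lands in the module: the former rebinds the parameter to a brand-new tensor with no shape check, while the latter copies values into the storage the LoRA layer was created with, so mismatched shapes fail immediately and the parameter keeps its device and dtype. A small standalone illustration (not Petals code; the in-place write is legal here only because requires_grad is False, just as the loop above sets it):

```python
import torch
import torch.nn as nn

param = nn.Parameter(torch.zeros(4, 8), requires_grad=False)
checkpoint_tensor = torch.randn(4, 8)

# Rebinding: the parameter now wraps a brand-new tensor. No shape check is
# performed, so a wrongly shaped checkpoint would slip through silently.
param.data = checkpoint_tensor

# In-place copy: values are written into the tensor the module was built with,
# keeping its storage, device and dtype; a shape mismatch raises immediately.
param[...] = checkpoint_tensor
```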
def estimate_adapter_memory_per_block(
block_config: transformers.PretrainedConfig, torch_dtype: Optional[torch.dtype], adapters: Sequence[str], **kwargs
) -> int:
"""Get the number of extra bytes used to store a set of adapters per given block"""
with init_empty_weights(include_buffers=True):
block = block_config.block_class(block_config)
base_block_parameters = sum(p.numel() for p in block.parameters())
create_lora_adapter(block, quant_type=QuantType.NONE)
for adapter in adapters:
peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **kwargs)
assert peft_config["peft_type"].upper() == "LORA", "only LoRA adapters are supported for now"
add_adapter_to_block(
block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
)
adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
return adapter_parameters * bytes_per_parameter
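For intuition about the number this helper returns: a LoRA adapter adds, for every nn.Linear it targets, a pair of matrices A of shape (rank, in_features) and B of shape (out_features, rank), so the per-block overhead is roughly the sum of rank * (in_features + out_features) over the wrapped layers, times the byte width of the serving dtype (which resolve_block_dtype determines above). A back-of-the-envelope helper with made-up dimensions:

```python
def lora_bytes_per_layer(in_features: int, out_features: int, rank: int, bytes_per_param: int = 2) -> int:
    """Rough LoRA footprint of one wrapped linear layer: A is (rank, in), B is (out, rank)."""
    return (rank * in_features + out_features * rank) * bytes_per_param

# Made-up example: a 4096 x 4096 projection with a rank-16 adapter served in fp16
print(lora_bytes_per_layer(4096, 4096, rank=16))  # 262144 bytes, i.e. 256 KiB per wrapped layer
```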
