"""
Tools for converting transformer blocks, applying quantization and/or tensor parallelism
"""
import os
import re
from typing import List, Optional, Sequence

import tensor_parallel as tp
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tensor_parallel.slicing_configs import get_bloom_config
from transformers import PretrainedConfig

from petals.utils.misc import QuantType

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


def convert_block(
    block: nn.Module,
    block_index: int,
    config: PretrainedConfig,
    tensor_parallel_devices: Sequence[torch.device],
    output_device: torch.device,
    quant_type: QuantType,
    freeze: bool = True,
    adapters: Optional[List[str]] = None,
    **kwargs,
) -> tp.TensorParallel:
    """
    Optimize a transformer block for use in a Petals server: apply tensor parallelism and/or quantization (LLM.int8() or NF4)

    :note: some optimizations will modify the input block in-place!
    :param block: a single transformer block, either pre-trained or newly initialized
    :param block_index: index of this block in the full model (used when loading adapters)
    :param config: HF transformers config for the full model
    :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
    :note: if there is only a single device, the model will still be wrapped with TensorParallel (for uniformity)
    :param output_device: if tensor parallelism is used, gather the block's outputs to this device
    :param quant_type: quantization type
    :param freeze: if True (default), make all module parameters non-trainable
    :param adapters: if specified, load these LoRA adapters into the block
    :return: a module that acts like the original block, but runs with all specified optimizations
    """
    if freeze:
        block.requires_grad_(False)

    block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)

    if quant_type != QuantType.NONE:
        block = quantize_module(block, quant_type=quant_type)

    for shard, device in zip(block.module_shards, block.devices):
        shard.to(device)

    if adapters:
        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft

        create_lora_adapter(block, quant_type=quant_type)
        for adapter_name in adapters:
            adapter_config, adapter_state_dict = load_peft(
                adapter_name,
                block_idx=block_index,
                **kwargs,
            )
            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)

    return block
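
# Usage sketch: how a server-side caller might apply convert_block() to one freshly
# initialized transformer block on a single GPU. This is an illustrative assumption, not
# the server's actual loading path: it assumes `AutoDistributedConfig` is importable from
# `petals` and that the config exposes a `block_class` attribute, as in `petals.models`.
#
#     from petals import AutoDistributedConfig
#
#     config = AutoDistributedConfig.from_pretrained("bigscience/bloom")
#     block = config.block_class(config)  # one transformer block, weights not yet loaded
#     device = torch.device("cuda:0")
#     block = convert_block(
#         block,
#         block_index=0,
#         config=config,
#         tensor_parallel_devices=[device],
#         output_device=device,
#         quant_type=QuantType.INT8,  # or QuantType.NF4 / QuantType.NONE
#     )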


def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module:
    # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    import bitsandbytes as bnb

    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            quantize_module(module, quant_type=quant_type)

        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
            assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
            if quant_type == QuantType.INT8:
                model._modules[n] = bnb.nn.Linear8bitLt(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    has_fp16_weights=False,
                    threshold=6.0,  # Default from the LLM.int8() paper
                )
                model._modules[n].weight = bnb.nn.Int8Params(
                    module.weight.data, requires_grad=False, has_fp16_weights=False
                ).to(module.weight.dtype)
            elif quant_type == QuantType.NF4:
                compress_statistics = True
                model._modules[n] = bnb.nn.LinearNF4(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    compress_statistics=compress_statistics,
                )
                model._modules[n].weight = bnb.nn.Params4bit(
                    module.weight.data,
                    requires_grad=False,
                    quant_type="nf4",
                    blocksize=64,
                    compress_statistics=compress_statistics,
                ).to(module.weight.dtype)
            else:
                raise ValueError(f"Unsupported quant_type='{quant_type}'")
            model._modules[n].bias = module.bias
    return model
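
# Illustrative sketch of quantize_module() on a toy module tree (requires bitsandbytes;
# the layer sizes below are made up for the example). Every nn.Linear child except ones
# named "lm_head"/"score" is swapped for a bitsandbytes layer with frozen quantized weights:
#
#     toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).half()
#     toy = quantize_module(toy, quant_type=QuantType.NF4)
#     # toy[0] and toy[2] are now bnb.nn.LinearNF4; toy[1] (ReLU) is left untouched.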


def make_tensor_parallel(
    block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device
) -> nn.Module:
    if model_config.model_type == "bloom":
        tp_config = get_bloom_config(model_config, devices)
        del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
    else:
        if len(devices) > 1:
            logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution")
        tp_config = None
    tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
    total_heads = 0
    for tp_shard in tp_block.module_shards:
        for submodule in tp_shard.modules():
            if isinstance(submodule, model_config.attn_class):
                total_heads += submodule.num_heads
    assert total_heads == model_config.num_attention_heads
    return tp_block
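
# Sketch: splitting a single BLOOM block across two GPUs (devices are illustrative; `block`
# and `config` are assumed to be a loaded block and its model config, as in the
# convert_block() sketch above). The per-shard attention heads must add up to
# model_config.num_attention_heads, which the assertion above verifies.
#
#     devices = [torch.device("cuda:0"), torch.device("cuda:1")]
#     tp_block = make_tensor_parallel(block, config, devices, output_device=devices[0])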


def check_device_balance(devices: Sequence[torch.device]):
    if not all(device.type == "cuda" for device in devices):
        logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk")
        return

    unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices))
    if len(unique_device_capabilities) > 1:
        logger.warning(
            f"Found GPUs with uneven capabilities: {unique_device_capabilities}. "
            f"Using GPUs with different performance will cause the server to wait for the slowest GPU."
        )

    memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)
    used_memory = min(memory_per_device) * len(memory_per_device)
    wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)
    if wasted_memory_rate > 0.05:
        logger.warning(
            f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
            f"Consider running high-memory GPUs in a separate server."
        )
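
# Worked example of the wasted-memory heuristic above (numbers are illustrative): with one
# 24 GB GPU and one 16 GB GPU, used_memory = 2 * 16 GB = 32 GB out of 40 GB total, so
# wasted_memory_rate = (40 - 32) / 40 = 0.20, and the 5% threshold triggers the warning.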