You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
petals/src/petals/server/block_utils.py

66 lines
2.5 KiB
Python

from typing import Optional, Union
import torch
from accelerate import init_empty_weights
from transformers import PretrainedConfig, PreTrainedModel
from petals.models.mixtral.block import WrappedMixtralBlock
from petals.utils.convert_block import QuantType
from petals.utils.misc import get_size_in_bytes
Add LLaMA support (#323) This PR: 1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms. - BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ. - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accommodate blocks from different repos (there are a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name). 2. **Refactors the client to generalize it for multiple models.** Now, we have `petals.models` packages that contain model-specific code (e.g. `petals.models.bloom`, `petals.models.llama`). General code (e.g. CPU-efficient LM head, p-tuning) is kept in `petals.client`. 3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.). 4. **Introduces** `AutoDistributedConfig` that automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info for both clients and servers. Upgrade instructions: - Remove disk caches for blocks in old (converted) format to save disk space. That is, remove `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present).
11 months ago
def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Resolve the dtype that a model block should use.

    If ``dtype`` is anything other than ``"auto"`` or ``None``, it is returned unchanged (including
    strings other than ``"auto"``). Otherwise the dtype is taken from ``config.torch_dtype``. If that is
    also unspecified (``"auto"``, ``None``, or missing entirely) or is ``torch.float32``, the function
    falls back to ``torch.bfloat16``.

    :param config: model config; only its ``torch_dtype`` attribute is read.
    :param dtype: requested dtype, or ``"auto"`` / ``None`` to infer it from ``config``.
    :returns: the resolved dtype.
    """
    if dtype not in ("auto", None):
        return dtype

    # Configs without a torch_dtype attribute are treated the same as torch_dtype=None
    config_dtype = getattr(config, "torch_dtype", None)
    # If the config specifies float32, we override it with the bfloat16 default below
    if config_dtype not in ("auto", None, torch.float32):
        return config_dtype
    return torch.bfloat16
def get_block_size(
    config: PretrainedConfig,
    location: str,
    *,
    dtype: Optional[Union[str, torch.dtype]] = None,
    quant_type: QuantType = QuantType.NONE,
    eps: float = 0.01,  # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
) -> int:
    """Estimate the size in bytes of a single model block, either in memory or on disk.

    The parameter count comes from a block built by ``get_model_block(config)`` inside accelerate's
    ``init_empty_weights(include_buffers=True)``, so no real weight storage is allocated for the estimate.
    Only ``block.parameters()`` are counted.

    :param config: model config used to build the block and, where needed, to resolve its dtype.
    :param location: ``"memory"`` to estimate the size of a loaded (possibly quantized) block, or
        ``"disk"`` to estimate the size of its stored weights.
    :param dtype: dtype of the loaded block; required when ``location="memory"``. ``"auto"`` is resolved
        with ``resolve_block_dtype``. It is only used when ``quant_type`` is ``QuantType.NONE``.
    :param quant_type: quantization used in memory; required when ``location="memory"``. It is ignored
        for ``location="disk"``, which always uses the dtype resolved from ``config``.
    :param eps: relative overhead added on top of the raw parameter size.
    :returns: the estimated size in bytes, rounded to the nearest integer.
    :raises AssertionError: (via ``assert``) if ``location="memory"`` and ``dtype`` or ``quant_type`` is None.
    :raises ValueError: if ``location`` or ``quant_type`` is not supported.
    """
    # Determine the per-parameter cost first so that invalid arguments fail before the block is built
    if location == "memory":
        assert (
            dtype is not None and quant_type is not None
        ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
        if quant_type == QuantType.NONE:
            bytes_per_value = get_size_in_bytes(resolve_block_dtype(config, dtype))
        elif quant_type == QuantType.INT8:
            bytes_per_value = 1
        elif quant_type == QuantType.NF4:
            bytes_per_value = 4.25 / 8  # Bitness of NF4 with this config (measured empirically)
        else:
            raise ValueError(f"Unsupported quant_type={quant_type}")
    elif location == "disk":
        # The disk estimate ignores quant_type and uses the dtype resolved from the config
        # (presumably because weights are stored unquantized -- TODO confirm)
        bytes_per_value = get_size_in_bytes(resolve_block_dtype(config, "auto"))
    else:
        # Previously this fell through and raised UnboundLocalError on `bytes_per_value`
        raise ValueError(f'Unsupported location={location!r}, expected "memory" or "disk"')

    with init_empty_weights(include_buffers=True):
        block = get_model_block(config)
        n_params = sum(param.numel() for param in block.parameters())

    return round(n_params * bytes_per_value * (1 + eps))
def get_model_block(config, layer_idx: int = 0):
    """
    Create a single model block by instantiating ``config.block_class``.

    ``layer_idx`` is passed only to block classes that require it, currently ``WrappedMixtralBlock``.
    All other block classes are constructed from ``config`` alone, so ``layer_idx`` is ignored for them.

    :param config: model config whose ``block_class`` attribute is the block class to instantiate.
    :param layer_idx: index of the layer being created; forwarded to ``WrappedMixtralBlock`` only.
    :returns: a new instance of ``config.block_class``.
    """
    if config.block_class is WrappedMixtralBlock:
        # The config is first passed through transformers' `_autoset_attn_implementation`; judging by the
        # name, this fills in the attention implementation before the Mixtral block is built.
        # NOTE(review): this is a private transformers helper, so its behavior/signature may change
        # between transformers versions.
        config = PreTrainedModel._autoset_attn_implementation(config)
        return config.block_class(config, layer_idx)
    return config.block_class(config)