Support loading blocks in 4-bit (QLoRA NF4 format, disabled by default) (#333)

Branch: pull/339/head
Author: Alexander Borzunov (committed via GitHub)
Parent: 66a47c763e
Commit: de930918a0

@@ -32,7 +32,7 @@ packages = find:
 python_requires = >=3.7
 install_requires =
     torch>=1.12
-    bitsandbytes==0.38.0.post2
+    bitsandbytes==0.39.1
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3

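The dependency bump is what enables the NF4 path further down: the `bnb.nn.LinearNF4` and `bnb.nn.Params4bit` classes used in `convert_block.py` are only available in recent bitsandbytes releases (4-bit support landed around 0.39.0). A minimal pre-flight check one might run before starting a server, sketched here as an assumption rather than part of the commit:

    # Hypothetical sanity check, not part of this commit.
    from packaging.version import Version

    import bitsandbytes as bnb

    # 4-bit (NF4) layers are only available in sufficiently recent bitsandbytes builds.
    assert Version(bnb.__version__) >= Version("0.39.0"), (
        f"bitsandbytes {bnb.__version__} is too old for NF4-quantized blocks"
    )
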
@@ -8,6 +8,7 @@ from humanfriendly import parse_size
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
 from petals.server.server import Server
+from petals.utils.convert_block import QuantType
 from petals.utils.version import validate_version
 
 logger = get_logger(__name__)
@@ -133,9 +134,10 @@ def main():
                        help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
     parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
-    parser.add_argument('--load_in_8bit', type=str, default=None,
-                        help="Convert the loaded transformer blocks into mixed-8bit quantized model. "
-                             "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
+    parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
+                        help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
+                             "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
+                             "Default: 'int8' if GPU is available, 'none' otherwise")
     parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
                         help=
                         "Split each block between the specified GPUs such that each device holds a portion of every "
@@ -186,9 +188,9 @@ def main():
     if args.pop("new_swarm"):
         args["initial_peers"] = []
 
-    load_in_8bit = args.pop("load_in_8bit")
-    if load_in_8bit is not None:
-        args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
+    quant_type = args.pop("quant_type")
+    if quant_type is not None:
+        args["quant_type"] = QuantType[quant_type.upper()]
 
     validate_version()

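For reference, the lowercase CLI choices map onto the `QuantType` enum exactly as in the hunk above; a small sketch of that round trip (the literal 'nf4' value is just an example):

    from petals.utils.convert_block import QuantType

    choices = [choice.name.lower() for choice in QuantType]  # ['none', 'int8', 'nf4']
    quant_type = "nf4"                                        # e.g. passed as `--quant_type nf4`
    resolved = QuantType[quant_type.upper()]                  # -> QuantType.NF4
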
@@ -4,6 +4,8 @@ import torch
 from accelerate import init_empty_weights
 from transformers import PretrainedConfig
 
+from petals.utils.convert_block import QuantType
+
 
 def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
@@ -19,27 +21,30 @@ get_block_size(
     location: str,
     *,
     dtype: Optional[Union[str, torch.dtype]] = None,
-    load_in_8bit: Optional[bool] = None,
+    quant_type: QuantType = QuantType.NONE,
     eps: float = 0.01,  # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
 ) -> int:
     if location == "memory":
         assert (
-            dtype is not None and load_in_8bit is not None
-        ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
+            dtype is not None and quant_type is not None
+        ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
 
     with init_empty_weights(include_buffers=True):
         block = config.block_class(config)
         n_params = sum(param.numel() for param in block.parameters())
 
-    if location == "memory" and load_in_8bit:
-        # Note: We may need a larger eps here for models of size < 1B
-        return n_params * (1 + eps)
-
     if location == "memory":
-        dtype = resolve_block_dtype(config, dtype)
+        if quant_type == QuantType.NONE:
+            dtype = resolve_block_dtype(config, dtype)
+            bytes_per_value = torch.finfo(dtype).bits // 8
+        elif quant_type == QuantType.INT8:
+            bytes_per_value = 1
+        elif quant_type == QuantType.NF4:
+            bytes_per_value = 4.25 / 8  # Bitness of NF4 with this config (measured empirically)
+        else:
+            raise ValueError(f"Unsupported quant_type={quant_type}")
     elif location == "disk":
         dtype = resolve_block_dtype(config, "auto")
-    else:
-        raise ValueError('get_block_size() expects location to be "memory" or "disk"')
+        bytes_per_value = torch.finfo(dtype).bits // 8
 
-    return round(n_params * torch.finfo(dtype).bits // 8 * (1 + eps))
+    return round(n_params * bytes_per_value * (1 + eps))

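To make the new `bytes_per_value` branches concrete, here is a back-of-the-envelope re-statement of the `location="memory"` estimate for a hypothetical block with 2.5 billion parameters (the parameter count is illustrative, not taken from the diff):

    # Illustrative only; mirrors get_block_size(..., location="memory") above.
    n_params = 2_500_000_000
    eps = 0.01  # ~1% overhead for metainfo and quantization tables

    for name, bytes_per_value in [("fp16/bf16", 2), ("int8", 1), ("nf4", 4.25 / 8)]:
        size_gib = n_params * bytes_per_value * (1 + eps) / 1024**3
        print(f"{name:>9}: ~{size_gib:.2f} GiB per block")
    # fp16/bf16: ~4.70 GiB, int8: ~2.35 GiB, nf4: ~1.25 GiB
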
@@ -28,7 +28,7 @@ from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
-from petals.utils.convert_block import check_device_balance, convert_block
+from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 from petals.utils.version import get_compatible_model_repo
@@ -75,7 +75,7 @@ class Server:
         mean_balance_check_period: float = 120,
         mean_block_selection_delay: float = 2.5,
         use_auth_token: Optional[str] = None,
-        load_in_8bit: Optional[bool] = None,
+        quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
         dht_client_mode: Optional[bool] = None,
@@ -154,8 +154,8 @@ class Server:
             device = torch.device(device.type, index=0)
         self.device = device
 
-        torch_dtype = DTYPE_MAP[torch_dtype]
-        self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)
+        torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
+        self.torch_dtype = torch_dtype
 
         if tensor_parallel_devices is None:
             tensor_parallel_devices = (device,)
@@ -164,10 +164,10 @@ class Server:
             logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
             check_device_balance(self.tensor_parallel_devices)
 
-        if load_in_8bit is None:
-            load_in_8bit = device.type == "cuda"
-        self.load_in_8bit = load_in_8bit
-        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+        if quant_type is None:
+            quant_type = QuantType.INT8 if device.type == "cuda" else QuantType.NONE
+        self.quant_type = quant_type
+        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
@@ -203,7 +203,7 @@ class Server:
                 device,
                 torch_dtype,
                 num_blocks=num_blocks,
-                load_in_8bit=load_in_8bit,
+                quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
@@ -237,11 +237,11 @@ class Server:
         else:
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
 
-        # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models
         gib = 1024**3
-        autograd_memory = 2 * gib * num_devices  # GPU memory used for intermediate tensors in rpc_backward
+        # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
+        autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
 
         num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
@@ -284,7 +284,7 @@ class Server:
             sender_threads=self.sender_threads,
             revision=self.revision,
             use_auth_token=self.use_auth_token,
-            load_in_8bit=self.load_in_8bit,
+            quant_type=self.quant_type,
             tensor_parallel_devices=self.tensor_parallel_devices,
             should_validate_reachability=self.should_validate_reachability,
             start=True,
@@ -377,7 +377,7 @@ class ModuleContainer(threading.Thread):
         expiration: Optional[float],
         revision: Optional[str],
         use_auth_token: Optional[str],
-        load_in_8bit: bool,
+        quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
         **kwargs,
@@ -411,7 +411,7 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
-                block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
+                block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True)
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,

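The rewritten `autograd_memory` line no longer hard-codes BLOOM: the 2 GiB figure is scaled by `hidden_size / 14336`. A rough sketch of the resulting block-count estimate, with all sizes below being made-up placeholders rather than values from the diff:

    import math

    gib = 1024**3
    total_memory = 24 * gib              # hypothetical 24 GB GPU
    num_devices = 1
    hidden_size = 8192                   # the model's hidden size (14336 for BLOOM-176B)
    block_size = 2 * gib                 # would come from get_block_size(..., "memory", quant_type=...)
    cache_bytes_per_block = 512 * 2**20  # would come from attn_cache_tokens and the block dtype

    # Mirrors the estimate above: ~2 GiB for BLOOM, scaled proportionally for other models.
    autograd_memory = 2 * gib * num_devices / 14336 * hidden_size
    num_blocks = math.floor((total_memory - autograd_memory) / (block_size + cache_bytes_per_block))
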
@@ -13,7 +13,7 @@ from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_block import convert_block
+from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__name__)
@@ -39,7 +39,7 @@ def get_server_throughput(
     dtype: Union[str, torch.dtype],
     *,
     num_blocks: int,
-    load_in_8bit: bool,
+    quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
@@ -60,7 +60,7 @@ def get_server_throughput(
     cache_key = f"model_{model_name}"
     cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
-    cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
+    cache_key += f"_dtype_{get_dtype_name(dtype, quant_type)}"
     if len(tensor_parallel_devices) > 1:
         for i, device_i in enumerate(tensor_parallel_devices):
             cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}"
@@ -77,7 +77,7 @@ def get_server_throughput(
         if cache_key not in cache:
             cache[cache_key] = measure_throughput_info(
-                config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+                config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices
             )
 
             try:
@@ -104,7 +104,7 @@ def measure_throughput_info(
     device: torch.device,
     dtype: torch.dtype,
     *,
-    load_in_8bit: bool,
+    quant_type: QuantType,
    tensor_parallel_devices: Sequence[torch.device],
 ) -> Dict[str, float]:
     """Measure network and compute throughput in forward pass tokens per second"""
@@ -115,7 +115,7 @@ def measure_throughput_info(
     throughput_info = {
         "compute_rps": measure_compute_rps(
-            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+            config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices
         )
     }
     try:
@@ -163,7 +163,7 @@ def measure_compute_rps(
     device: torch.device,
     dtype: torch.dtype,
     *,
-    load_in_8bit: bool,
+    quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
     n_tokens: int = 16,
     n_steps: int = 500,
@@ -172,7 +172,7 @@ def measure_compute_rps(
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
         block = config.block_class(config).to(dtype)
-        block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
+        block = convert_block(block, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
         cache = None
         elapsed = 0
@@ -192,7 +192,7 @@ def measure_compute_rps(
     logger.info(
         f"Forward pass throughput: {device_rps:.1f} RPS per block "
-        f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})"
+        f"({devices_repr}, {get_dtype_name(dtype, quant_type)})"
     )
     return device_rps
@@ -201,8 +201,8 @@ def get_device_name(device: torch.device) -> str:
     return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
 
 
-def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
+def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
     name = str(dtype)
-    if load_in_8bit:
-        name += ", 8-bit quantized"
+    if quant_type != QuantType.NONE:
+        name += f", quantized to {quant_type.name.lower()}"
     return name

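Since `get_dtype_name` now feeds both the throughput cache key and the log line, the quantization scheme shows up by name; a quick sketch of the behaviour of the rewritten helper:

    import torch
    from petals.utils.convert_block import QuantType

    def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
        # Same logic as the rewritten helper above.
        name = str(dtype)
        if quant_type != QuantType.NONE:
            name += f", quantized to {quant_type.name.lower()}"
        return name

    assert get_dtype_name(torch.bfloat16, QuantType.NF4) == "torch.bfloat16, quantized to nf4"
    assert get_dtype_name(torch.float16, QuantType.NONE) == "torch.float16"
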
@@ -3,6 +3,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor parallelism
 """
 import os
 import re
+from enum import Enum
 from typing import Sequence
 
 import tensor_parallel as tp
@@ -16,13 +17,18 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
+class QuantType(Enum):
+    NONE = 0
+    INT8 = 1  # 8-bit as in the LLM.int8() paper
+    NF4 = 2  # 4-bit as in the QLoRA paper
+
+
 def convert_block(
     block: nn.Module,
     config: PretrainedConfig,
     tensor_parallel_devices: Sequence[torch.device],
     output_device: torch.device,
-    load_in_8bit: bool,
-    threshold: float = 6.0,
+    quant_type: QuantType,
     freeze: bool = True,
 ) -> tp.TensorParallel:
     """
@@ -34,20 +40,18 @@ convert_block(
     :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
     :note: if there is only a single device, model wil still be wrapped with TensorParallel (for uniformity)
     :param output_device: if tensor_parallel_devices is True, output
-    :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint
-    :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
+    :param quant_type: quantization type
     :param freeze: if True (default), make all module parameters non-trainable
     :return: a module that acts like the original block, but runs with all specified optimizations
     """
     if freeze:
-        for param in block.parameters():
-            param.requires_grad = False
+        block.requires_grad_(False)
 
     block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
 
-    if load_in_8bit:
-        block = replace_8bit_linear(block, threshold=threshold)
+    if quant_type != QuantType.NONE:
+        block = quantize_module(block, quant_type=quant_type)
 
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
@@ -55,43 +59,45 @@ convert_block(
     return block
 
 
-def replace_8bit_linear(model: nn.Module, threshold=6.0) -> nn.Module:
-    """
-    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
-    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
-    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
-    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
-    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
-    be kept as a `torch.nn.Linear` module.
-
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        threshold (`float`, *optional*):
-            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
-            `6.0` as described by the paper.
-    """
+def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module:
     # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
     os.environ["BITSANDBYTES_NOWELCOME"] = "1"
     import bitsandbytes as bnb
 
     for n, module in model.named_children():
         if len(list(module.children())) > 0:
-            replace_8bit_linear(module, threshold)
+            quantize_module(module, quant_type=quant_type)
 
         if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
             assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
-            model._modules[n] = bnb.nn.Linear8bitLt(
-                module.in_features,
-                module.out_features,
-                module.bias is not None,
-                has_fp16_weights=False,
-                threshold=threshold,
-            )
-            model._modules[n].weight = bnb.nn.Int8Params(
-                module.weight.data, requires_grad=False, has_fp16_weights=False
-            ).to(module.weight.dtype)
+            if quant_type == QuantType.INT8:
+                model._modules[n] = bnb.nn.Linear8bitLt(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    has_fp16_weights=False,
+                    threshold=6.0,  # Default from the LLM.int8() paper
+                )
+                model._modules[n].weight = bnb.nn.Int8Params(
+                    module.weight.data, requires_grad=False, has_fp16_weights=False
+                ).to(module.weight.dtype)
+            elif quant_type == QuantType.NF4:
+                compress_statistics = True
+                model._modules[n] = bnb.nn.LinearNF4(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    compress_statistics=compress_statistics,
+                )
+                model._modules[n].weight = bnb.nn.Params4bit(
+                    module.weight.data,
+                    requires_grad=False,
+                    quant_type="nf4",
+                    blocksize=64,
+                    compress_statistics=compress_statistics,
+                ).to(module.weight.dtype)
+            else:
+                raise ValueError(f"Unsupported quant_type='{quant_type}'")
             model._modules[n].bias = module.bias
 
     return model

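A minimal usage sketch of the new `quantize_module` helper on a toy module (the toy layers, sizes, and the CUDA device are assumptions for illustration; as the assert above requires, the weights must start on CPU, and bitsandbytes performs the actual 4-bit quantization when the module is moved to a GPU):

    import torch
    from torch import nn

    from petals.utils.convert_block import QuantType, quantize_module

    # Toy stand-in for a transformer block; real blocks are built via config.block_class(config).
    toy = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).half()
    toy = quantize_module(toy, quant_type=QuantType.NF4)  # nn.Linear layers become bnb.nn.LinearNF4

    if torch.cuda.is_available():
        toy = toy.to("cuda")  # weights are quantized to NF4 during this transfer
        with torch.inference_mode():
            out = toy(torch.randn(1, 1024, device="cuda", dtype=torch.float16))
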
@@ -3,6 +3,7 @@ import torch
 
 from petals import AutoDistributedConfig
 from petals.server.throughput import measure_compute_rps
+from petals.utils.convert_block import QuantType
 from test_utils import MODEL_NAME
@@ -15,7 +16,7 @@ def test_compute_throughput(tensor_parallel: bool):
         config,
         device=torch.device("cpu"),
         dtype=torch.bfloat16,
-        load_in_8bit=False,
+        quant_type=QuantType.NONE,
         tensor_parallel_devices=tensor_parallel_devices,
         n_steps=10,
     )

@@ -78,7 +78,10 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
         if protocol == "rpc_forward":
             metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
         elif protocol == "rpc_backward":
-            metadata["output_compression"] = (runtime_pb2.CompressionType.BLOCKWISE_8BIT,)
+            metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
+            # FIXME: Initially, we used CompressionType.BLOCKWISE_8BIT for rpc_backward() here.
+            # This is currently broken since hivemind==1.1.8 is not compatible with bitsandbytes==0.39.1.
+            # Please revert to BLOCKWISE_8BIT once this is fixed: https://github.com/learning-at-home/hivemind/issues/572
         return metadata
