From de930918a0743da011caeccd7131f53278a1e8ae Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Mon, 3 Jul 2023 20:13:04 +0400
Subject: [PATCH] Support loading blocks in 4-bit (QLoRA NF4 format, disabled by default) (#333)

---
 setup.cfg                         |  2 +-
 src/petals/cli/run_server.py      | 14 +++---
 src/petals/server/block_utils.py  | 27 ++++++-----
 src/petals/server/server.py       | 30 ++++++------
 src/petals/server/throughput.py   | 24 +++++-----
 src/petals/utils/convert_block.py | 78 +++++++++++++++++--------------
 tests/test_aux_functions.py       |  3 +-
 tests/test_remote_sequential.py   |  5 +-
 8 files changed, 100 insertions(+), 83 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index eacd99a..fb1fa23 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -32,7 +32,7 @@ packages = find:
 python_requires = >=3.7
 install_requires =
     torch>=1.12
-    bitsandbytes==0.38.0.post2
+    bitsandbytes==0.39.1
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index 83e35e5..3c28709 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -8,6 +8,7 @@ from humanfriendly import parse_size
 
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
 from petals.server.server import Server
+from petals.utils.convert_block import QuantType
 from petals.utils.version import validate_version
 
 logger = get_logger(__name__)
@@ -133,9 +134,10 @@ def main():
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
 
     parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
-    parser.add_argument('--load_in_8bit', type=str, default=None,
-                        help="Convert the loaded transformer blocks into mixed-8bit quantized model. "
-                             "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
+    parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
+                        help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
+                             "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
+                             "Default: 'int8' if GPU is available, 'none' otherwise")
     parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
                         help=
                         "Split each block between the specified GPUs such that each device holds a portion of every "
@@ -186,9 +188,9 @@ def main():
     if args.pop("new_swarm"):
         args["initial_peers"] = []
 
-    load_in_8bit = args.pop("load_in_8bit")
-    if load_in_8bit is not None:
-        args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
+    quant_type = args.pop("quant_type")
+    if quant_type is not None:
+        args["quant_type"] = QuantType[quant_type.upper()]
 
     validate_version()
 
diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py
index a6af3b0..eb5300e 100644
--- a/src/petals/server/block_utils.py
+++ b/src/petals/server/block_utils.py
@@ -4,6 +4,8 @@ import torch
 from accelerate import init_empty_weights
 from transformers import PretrainedConfig
 
+from petals.utils.convert_block import QuantType
+
 
 def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
@@ -19,27 +21,30 @@ def get_block_size(
     location: str,
     *,
     dtype: Optional[Union[str, torch.dtype]] = None,
-    load_in_8bit: Optional[bool] = None,
+    quant_type: QuantType = QuantType.NONE,
     eps: float = 0.01,  # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
 ) -> int:
     if location == "memory":
         assert (
-            dtype is not None and load_in_8bit is not None
-        ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
+            dtype is not None and quant_type is not None
+        ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
 
     with init_empty_weights(include_buffers=True):
         block = config.block_class(config)
         n_params = sum(param.numel() for param in block.parameters())
 
-    if location == "memory" and load_in_8bit:
-        # Note: We may need a larger eps here for models of size < 1B
-        return n_params * (1 + eps)
-
     if location == "memory":
-        dtype = resolve_block_dtype(config, dtype)
+        if quant_type == QuantType.NONE:
+            dtype = resolve_block_dtype(config, dtype)
+            bytes_per_value = torch.finfo(dtype).bits // 8
+        elif quant_type == QuantType.INT8:
+            bytes_per_value = 1
+        elif quant_type == QuantType.NF4:
+            bytes_per_value = 4.25 / 8  # Bitness of NF4 with this config (measured empirically)
+        else:
+            raise ValueError(f"Unsupported quant_type={quant_type}")
     elif location == "disk":
         dtype = resolve_block_dtype(config, "auto")
-    else:
-        raise ValueError('get_block_size() expects location to be "memory" or "disk"')
+        bytes_per_value = torch.finfo(dtype).bits // 8
 
-    return round(n_params * torch.finfo(dtype).bits // 8 * (1 + eps))
+    return round(n_params * bytes_per_value * (1 + eps))
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 39c432c..2fbaad2 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -28,7 +28,7 @@ from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
-from petals.utils.convert_block import check_device_balance, convert_block
+from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 from petals.utils.version import get_compatible_model_repo
 
@@ -75,7 +75,7 @@ class Server:
         mean_balance_check_period: float = 120,
         mean_block_selection_delay: float = 2.5,
         use_auth_token: Optional[str] = None,
-        load_in_8bit: Optional[bool] = None,
+        quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
         dht_client_mode: Optional[bool] = None,
@@ -154,8 +154,8 @@ class Server:
             device = torch.device(device.type, index=0)
         self.device = device
 
-        torch_dtype = DTYPE_MAP[torch_dtype]
-        self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)
+        torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
+        self.torch_dtype = torch_dtype
 
         if tensor_parallel_devices is None:
             tensor_parallel_devices = (device,)
@@ -164,10 +164,10 @@ class Server:
             logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
             check_device_balance(self.tensor_parallel_devices)
 
-        if load_in_8bit is None:
-            load_in_8bit = device.type == "cuda"
-        self.load_in_8bit = load_in_8bit
-        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+        if quant_type is None:
+            quant_type = QuantType.INT8 if device.type == "cuda" else QuantType.NONE
+        self.quant_type = quant_type
+        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
@@ -203,7 +203,7 @@ class Server:
                 device,
                 torch_dtype,
                 num_blocks=num_blocks,
-                load_in_8bit=load_in_8bit,
+                quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
@@ -237,11 +237,11 @@ class Server:
         else:
            total_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
 
-        # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models
         gib = 1024**3
-        autograd_memory = 2 * gib * num_devices  # GPU memory used for intermediate tensors in rpc_backward
+        # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
+        autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
 
         num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
@@ -284,7 +284,7 @@ class Server:
             sender_threads=self.sender_threads,
             revision=self.revision,
             use_auth_token=self.use_auth_token,
-            load_in_8bit=self.load_in_8bit,
+            quant_type=self.quant_type,
             tensor_parallel_devices=self.tensor_parallel_devices,
             should_validate_reachability=self.should_validate_reachability,
             start=True,
@@ -377,7 +377,7 @@ class ModuleContainer(threading.Thread):
         expiration: Optional[float],
         revision: Optional[str],
         use_auth_token: Optional[str],
-        load_in_8bit: bool,
+        quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
         **kwargs,
@@ -411,7 +411,7 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
-                block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
+                block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True)
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index 2ee1ca1..76bbc85 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -13,7 +13,7 @@ from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_block import convert_block
+from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__name__)
@@ -39,7 +39,7 @@ def get_server_throughput(
     dtype: Union[str, torch.dtype],
     *,
     num_blocks: int,
-    load_in_8bit: bool,
+    quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
@@ -60,7 +60,7 @@ def get_server_throughput(
 
     cache_key = f"model_{model_name}"
     cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
-    cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
+    cache_key += f"_dtype_{get_dtype_name(dtype, quant_type)}"
     if len(tensor_parallel_devices) > 1:
         for i, device_i in enumerate(tensor_parallel_devices):
             cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}"
f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}" @@ -77,7 +77,7 @@ def get_server_throughput( if cache_key not in cache: cache[cache_key] = measure_throughput_info( - config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices + config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices ) try: @@ -104,7 +104,7 @@ def measure_throughput_info( device: torch.device, dtype: torch.dtype, *, - load_in_8bit: bool, + quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], ) -> Dict[str, float]: """Measure network and compute throughput in forward pass tokens per second""" @@ -115,7 +115,7 @@ def measure_throughput_info( throughput_info = { "compute_rps": measure_compute_rps( - config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices + config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices ) } try: @@ -163,7 +163,7 @@ def measure_compute_rps( device: torch.device, dtype: torch.dtype, *, - load_in_8bit: bool, + quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], n_tokens: int = 16, n_steps: int = 500, @@ -172,7 +172,7 @@ def measure_compute_rps( tensor_parallel_devices = (device,) with torch.inference_mode(): block = config.block_class(config).to(dtype) - block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True) + block = convert_block(block, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) cache = None elapsed = 0 @@ -192,7 +192,7 @@ def measure_compute_rps( logger.info( f"Forward pass throughput: {device_rps:.1f} RPS per block " - f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})" + f"({devices_repr}, {get_dtype_name(dtype, quant_type)})" ) return device_rps @@ -201,8 +201,8 @@ def get_device_name(device: torch.device) -> str: return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU" -def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str: +def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str: name = str(dtype) - if load_in_8bit: - name += ", 8-bit quantized" + if quant_type != QuantType.NONE: + name += f", quantized to {quant_type.name.lower()}" return name diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 28aea56..6b129f5 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -3,6 +3,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor par """ import os import re +from enum import Enum from typing import Sequence import tensor_parallel as tp @@ -16,13 +17,18 @@ use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) +class QuantType(Enum): + NONE = 0 + INT8 = 1 # 8-bit as in the LLM.int8() paper + NF4 = 2 # 4-bit as in the QLoRA paper + + def convert_block( block: nn.Module, config: PretrainedConfig, tensor_parallel_devices: Sequence[torch.device], output_device: torch.device, - load_in_8bit: bool, - threshold: float = 6.0, + quant_type: QuantType, freeze: bool = True, ) -> tp.TensorParallel: """ @@ -34,20 +40,18 @@ def convert_block( :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices :note: if there is only a single device, model wil still be wrapped with TensorParallel (for uniformity) :param output_device: if tensor_parallel_devices is True, output - :param load_in_8bit: if True, use 
-    :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
+    :param quant_type: quantization type
     :param freeze: if True (default), make all module parameters non-trainable
     :return: a module that acts like the original block, but runs with all specified optimizations
     """
 
     if freeze:
-        for param in block.parameters():
-            param.requires_grad = False
+        block.requires_grad_(False)
 
     block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
 
-    if load_in_8bit:
-        block = replace_8bit_linear(block, threshold=threshold)
+    if quant_type != QuantType.NONE:
+        block = quantize_module(block, quant_type=quant_type)
 
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
@@ -55,43 +59,45 @@ def convert_block(
     return block
 
 
-def replace_8bit_linear(model: nn.Module, threshold=6.0) -> nn.Module:
-    """
-    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
-    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
-    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
-    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
-    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
-    be kept as a `torch.nn.Linear` module.
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        threshold (`float`, *optional*):
-            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
-            `6.0` as described by the paper.
- """ - +def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module: # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes os.environ["BITSANDBYTES_NOWELCOME"] = "1" import bitsandbytes as bnb for n, module in model.named_children(): if len(list(module.children())) > 0: - replace_8bit_linear(module, threshold) + quantize_module(module, quant_type=quant_type) if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]: assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}" - model._modules[n] = bnb.nn.Linear8bitLt( - module.in_features, - module.out_features, - module.bias is not None, - has_fp16_weights=False, - threshold=threshold, - ) - model._modules[n].weight = bnb.nn.Int8Params( - module.weight.data, requires_grad=False, has_fp16_weights=False - ).to(module.weight.dtype) + if quant_type == QuantType.INT8: + model._modules[n] = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=False, + threshold=6.0, # Default from the LLM.int8() paper + ) + model._modules[n].weight = bnb.nn.Int8Params( + module.weight.data, requires_grad=False, has_fp16_weights=False + ).to(module.weight.dtype) + elif quant_type == QuantType.NF4: + compress_statistics = True + model._modules[n] = bnb.nn.LinearNF4( + module.in_features, + module.out_features, + module.bias is not None, + compress_statistics=compress_statistics, + ) + model._modules[n].weight = bnb.nn.Params4bit( + module.weight.data, + requires_grad=False, + quant_type="nf4", + blocksize=64, + compress_statistics=compress_statistics, + ).to(module.weight.dtype) + else: + raise ValueError(f"Unsupported quant_type='{quant_type}'") model._modules[n].bias = module.bias return model diff --git a/tests/test_aux_functions.py b/tests/test_aux_functions.py index d42666b..5fa14db 100644 --- a/tests/test_aux_functions.py +++ b/tests/test_aux_functions.py @@ -3,6 +3,7 @@ import torch from petals import AutoDistributedConfig from petals.server.throughput import measure_compute_rps +from petals.utils.convert_block import QuantType from test_utils import MODEL_NAME @@ -15,7 +16,7 @@ def test_compute_throughput(tensor_parallel: bool): config, device=torch.device("cpu"), dtype=torch.bfloat16, - load_in_8bit=False, + quant_type=QuantType.NONE, tensor_parallel_devices=tensor_parallel_devices, n_steps=10, ) diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py index 734683f..3c8a48f 100644 --- a/tests/test_remote_sequential.py +++ b/tests/test_remote_sequential.py @@ -78,7 +78,10 @@ class DummyCustomSequenceManager(RemoteSequenceManager): if protocol == "rpc_forward": metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,) elif protocol == "rpc_backward": - metadata["output_compression"] = (runtime_pb2.CompressionType.BLOCKWISE_8BIT,) + metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,) + # FIXME: Initially, we used CompressionType.BLOCKWISE_8BIT for rpc_backward() here. + # This is currently broken since hivemind==1.1.8 is not compatible with bitsandbytes==0.39.1. + # Please revert to BLOCKWISE_8BIT once this is fixed: https://github.com/learning-at-home/hivemind/issues/572 return metadata