Add local tensor-parallel fwd/bwd (#143)
This pull request adds an option to run a Petals server on multiple local GPUs. It uses https://github.com/BlackSamorez/tensor_parallel

- 8-bit approximation error is the same as in main (mean ~= 2%, q0.9 ~= 5%)
- TP=1, 2, 3 (see screenshots above)
- forward, grad w.r.t. input, and inference exactly match main with TP=1
- `>=` 80% GPU utilization with 3x 1080 Ti, batch = 8 tokens
- throughput measured with and without TP
- TP on 1080 Tis has near-linear speedup, comparable to the benchmarks (see first message)

Co-authored-by: Iaroslav Lisniak <yalisnyak@nes.ru>
Co-authored-by: Andrei Panferov <andrei@blacksamorez.ru>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
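For orientation, the core mechanism is the `tensor_parallel` library wrapping a single transformer block so that its attention heads and MLP columns are sliced across several local devices, while the wrapper keeps the original block's call signature. Below is a minimal, hedged sketch of that idea using the same calls that appear in the diff; the checkpoint name and the CPU device list are illustrative placeholders, not what the server uses by default.

```python
import torch
import transformers
import tensor_parallel as tp
from tensor_parallel.slicing_configs import get_bloom_config

from petals.bloom.from_pretrained import load_pretrained_block

MODEL_NAME = "bigscience/bloom-560m"                   # placeholder checkpoint for illustration
devices = [torch.device("cpu"), torch.device("cpu")]   # e.g. [cuda:0, cuda:1] on a multi-GPU host

config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
block = load_pretrained_block(MODEL_NAME, block_index=0, torch_dtype=torch.float32)

# BLOOM-specific slicing rules: attention heads and MLP columns are split between devices
tp_config = get_bloom_config(config, devices)
tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=devices[0])

# The wrapped block keeps the original interface: hidden states in, hidden states + KV cache out
hidden_states = torch.randn(1, 8, config.hidden_size, device=devices[0])
outputs, present_key_values = tp_block(hidden_states, use_cache=True)
```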
parent 779959bc70
commit ae9e71fe8e
@@ -1,39 +0,0 @@
import bitsandbytes as bnb
import torch

from petals.utils.linear8bitlt_patch import CustomLinear8bitLt


def replace_8bit_linear(model, threshold=6.0):
    """
    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
    be kept as a `torch.nn.Linear` module.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        threshold (`float`, *optional*):
            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
            `6.0` as described by the paper.
    """
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_8bit_linear(module, threshold)

        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
            model._modules[n] = CustomLinear8bitLt(
                module.in_features,
                module.out_features,
                module.bias is not None,
                has_fp16_weights=False,
                threshold=threshold,
            )
            model._modules[n].weight = bnb.nn.Int8Params(
                module.weight.data, requires_grad=False, has_fp16_weights=False
            ).to(module.weight.dtype)
            model._modules[n].bias = module.bias
    return model
@@ -0,0 +1,132 @@
"""
Tools for converting transformer blocks, applying quantization and/or tensor parallelism
"""
import re
from typing import Sequence

import bitsandbytes as bnb
import tensor_parallel as tp
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tensor_parallel.slicing_configs import get_bloom_config
from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomAttention

from petals.bloom.block import WrappedBloomBlock
from petals.utils.linear8bitlt_patch import CustomLinear8bitLt

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)


def convert_block(
    block: WrappedBloomBlock,
    config: BloomConfig,
    tensor_parallel_devices: Sequence[torch.device],
    output_device: torch.device,
    load_in_8bit: bool,
    threshold: float = 6.0,
    freeze: bool = True,
) -> tp.TensorParallel:
    """
    Optimize a transformer block for use in a Petals server: apply tensor parallelism and/or LLM.int8() quantization

    :note: some optimizations will modify the input block in-place!
    :param block: a single transformer block, either pre-trained or newly initialized
    :param config: HF transformers config for the full model
    :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
    :note: if there is only a single device, the block will still be wrapped with TensorParallel (for uniformity)
    :param output_device: if tensor_parallel_devices is used, the device on which the block's outputs are gathered
    :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint
    :param threshold: a quantization threshold from the LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
    :param freeze: if True (default), make all module parameters non-trainable
    :return: a module that acts like the original block, but runs with all specified optimizations
    """
    if freeze:
        for param in block.parameters():
            param.requires_grad = False

    block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)

    if load_in_8bit:
        block = replace_8bit_linear(block, threshold=threshold)

    for shard, device in zip(block.module_shards, block.devices):
        shard.to(device)

    return block


def replace_8bit_linear(model: nn.Module, threshold=6.0):
    """
    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
    library. This enables running models in mixed int8 precision as described by the paper `GPT3.int8():
    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
    bitsandbytes-cudaXXX`, where `XXX` is your CUDA version (e.g., 11.6 = 116)
    The function is run recursively and replaces all `torch.nn.Linear` modules except for `lm_head` and `score`,
    which should be kept as `torch.nn.Linear` modules.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        threshold (`float`, *optional*):
            `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
            `6.0` as described by the paper.
    """
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_8bit_linear(module, threshold)

        if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
            assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
            model._modules[n] = CustomLinear8bitLt(
                module.in_features,
                module.out_features,
                module.bias is not None,
                has_fp16_weights=False,
                threshold=threshold,
            )
            model._modules[n].weight = bnb.nn.Int8Params(
                module.weight.data, requires_grad=False, has_fp16_weights=False
            ).to(module.weight.dtype)
            model._modules[n].bias = module.bias
    return model


def make_tensor_parallel(
    block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
):
    assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt))
    tp_config = get_bloom_config(model_config, devices)
    del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
    tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
    total_heads = 0
    for tp_shard in tp_block.module_shards:
        for submodule in tp_shard.modules():
            if isinstance(submodule, BloomAttention):
                total_heads += submodule.num_heads
    assert total_heads == model_config.n_head
    return tp_block


def check_device_balance(devices: Sequence[torch.device]):
    if not all(device.type == "cuda" for device in devices):
        logger.warning("Running tensor parallelism on non-GPU devices; proceed at your own risk")
        return
    unique_device_capabilities = set(map(torch.cuda.get_device_capability, devices))
    if len(unique_device_capabilities) > 1:
        logger.warning(
            f"Found GPUs with uneven capabilities: {unique_device_capabilities}. "
            f"Using GPUs with different performance will cause the server to wait for the slowest GPU."
        )

    memory_per_device = tuple(torch.cuda.get_device_properties(device).total_memory for device in devices)
    used_memory = min(memory_per_device) * len(memory_per_device)
    wasted_memory_rate = (sum(memory_per_device) - used_memory) / sum(memory_per_device)
    if wasted_memory_rate > 0.05:
        logger.warning(
            f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
            f"Consider running high-memory GPUs in a separate server."
        )
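To tie the pieces together, here is a hedged sketch of how a server-side caller might use `convert_block`: freeze the block, wrap it with TensorParallel, optionally quantize its linear layers to 8-bit, and move the shards to their devices. The call signature matches the code above; the import path of the new module, the checkpoint name, and the device list are assumptions made for illustration only.

```python
import torch
import transformers

from petals.bloom.from_pretrained import load_pretrained_block
# module path assumed for this sketch; the diff does not show the new file's location
from petals.utils.convert_block import check_device_balance, convert_block

MODEL_NAME = "bigscience/bloom-560m"                        # illustrative checkpoint
devices = [torch.device("cuda:0"), torch.device("cuda:1")]
check_device_balance(devices)                               # warns about mismatched or uneven GPUs

config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
# keep the block on CPU here: replace_8bit_linear() asserts that linear layers start on CPU
block = load_pretrained_block(MODEL_NAME, block_index=0, torch_dtype=torch.float16)

# freezes parameters, wraps with TensorParallel, quantizes linears to 8-bit, then moves shards to the GPUs
block = convert_block(block, config, devices, output_device=devices[0], load_in_8bit=True)
```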
@@ -0,0 +1,46 @@
import random

import pytest
import torch
import transformers
from tensor_parallel import TensorParallel
from tensor_parallel.slicing_configs import get_bloom_config
from test_utils import MODEL_NAME

from petals.bloom.from_pretrained import load_pretrained_block


@pytest.mark.forked
@pytest.mark.parametrize("custom_config", [True, False])
@pytest.mark.parametrize("devices", [("cpu",) * 2, ("cpu",) * 3, ("cpu",) * 4])
def test_tp_block(devices, custom_config):
    block_index = random.randint(0, 10)
    model_config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
    block = load_pretrained_block(MODEL_NAME, block_index=block_index, torch_dtype=torch.float32).to(devices[0])

    tp_config = None
    if custom_config:
        tp_config = get_bloom_config(model_config, devices)

    batch_size = 2
    prefix_length = 5

    test_inputs1 = torch.randn(batch_size, 3, 1024, requires_grad=True, device=devices[0])
    test_inputs2 = test_inputs1.detach().clone().requires_grad_(True)
    test_prefix1 = torch.randn(batch_size, prefix_length, 1024, requires_grad=True, device=devices[0])
    test_prefix2 = test_prefix1.detach().clone().requires_grad_(True)
    grad_proj = torch.rand_like(test_inputs1)

    y_prefix_ref, layer_past = block(test_prefix1, use_cache=True)
    y_ref, cache_ref = block(test_inputs1, use_cache=True, layer_past=layer_past)
    y_ref.backward(grad_proj)

    block_tp = TensorParallel(block, devices, config=tp_config)
    y_prefix, layer_past = block_tp(test_prefix2, use_cache=True)
    y_ours, cache_ours = block_tp(test_inputs2, use_cache=True, layer_past=layer_past)
    y_ours.backward(grad_proj)

    assert torch.allclose(y_prefix, y_prefix_ref, atol=1e-6)
    assert torch.allclose(y_ours, y_ref, atol=1e-6)
    assert torch.allclose(test_inputs1.grad, test_inputs2.grad, atol=1e-4)
    assert torch.allclose(test_prefix1.grad, test_prefix2.grad, atol=1e-4)