@@ -4,7 +4,6 @@ Tools for converting transformer blocks, applying quantization and/or tensor parallelism
 import re
 from typing import Sequence
 
-import bitsandbytes as bnb
 import tensor_parallel as tp
 import torch
 import torch.nn as nn
@@ -14,7 +13,6 @@ from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention
 
 from petals.bloom.block import WrappedBloomBlock
-from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -75,6 +73,12 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0):
         `int8_threshold` for outlier detection as described in the aforementioned paper. This parameter is set to
         `6.0` as described by the paper.
     """
+
+    # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
+    import bitsandbytes as bnb
+
+    from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
+
     for n, module in model.named_children():
         if len(list(module.children())) > 0:
             replace_8bit_linear(module, threshold)
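
The hunk above moves both bitsandbytes-related imports from module scope into the function body, so importing this module no longer fails on platforms where bitsandbytes cannot be installed. A minimal sketch of the lazy-import pattern being adopted (the `convert` function and `use_8bit` flag are made-up illustration, not Petals code):

def convert(model, use_8bit: bool = False):
    if use_8bit:
        # The optional dependency is imported only when this branch actually runs,
        # so the enclosing module stays importable without bitsandbytes.
        import bitsandbytes as bnb  # noqa: F401
    return model

After the first successful import, Python caches the module in sys.modules, so the per-call cost of the function-local import is negligible.
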
@@ -98,7 +102,6 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0):
 def make_tensor_parallel(
     block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
 ):
-    assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt))
     tp_config = get_bloom_config(model_config, devices)
     del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
     tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
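
The last hunk drops the isinstance check that referenced CustomLinear8bitLt, since that name is no longer available at module scope; make_tensor_parallel now takes the WrappedBloomBlock as-is and relies on get_bloom_config plus tp.TensorParallel to shard it across devices. A hedged usage sketch (the checkpoint name, device list, and the block-loading step are assumptions for illustration):

import torch
from transformers import BloomConfig

from petals.utils.convert_block import make_tensor_parallel  # assumed import path

def shard_block(block, devices):
    # `block` is assumed to be a WrappedBloomBlock that was loaded elsewhere.
    config = BloomConfig.from_pretrained("bigscience/bloom-560m")
    return make_tensor_parallel(block, config, devices, output_device=devices[0])

# e.g. tp_block = shard_block(block, [torch.device("cuda:0"), torch.device("cuda:1")])

Note that deleting the word_embeddings entry from tp_config.state_rules removes the sharding rule for a weight that a single transformer block does not contain, since get_bloom_config describes the full model rather than one block.
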