Import bitsandbytes only if it's going to be used (#180)

pull/181/head
Alexander Borzunov 1 year ago committed by GitHub
parent e27706358c
commit 6dd9a938bd

@@ -4,7 +4,6 @@ Tools for converting transformer blocks, applying quantization and/or tensor par
 import re
 from typing import Sequence

-import bitsandbytes as bnb
 import tensor_parallel as tp
 import torch
 import torch.nn as nn
@@ -14,7 +13,6 @@ from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomAttention

 from petals.bloom.block import WrappedBloomBlock
-from petals.utils.linear8bitlt_patch import CustomLinear8bitLt

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -75,6 +73,12 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0):
         `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
         `6.0` as described by the paper.
     """
+    # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
+    import bitsandbytes as bnb
+    from petals.utils.linear8bitlt_patch import CustomLinear8bitLt
+
     for n, module in model.named_children():
         if len(list(module.children())) > 0:
             replace_8bit_linear(module, threshold)
@@ -98,7 +102,6 @@
 def make_tensor_parallel(
     block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
 ):
-    assert isinstance(block, (WrappedBloomBlock, CustomLinear8bitLt))
     tp_config = get_bloom_config(model_config, devices)
     del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
     tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)

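For readers outside the diff context, here is a minimal, self-contained sketch of the deferred-import pattern this commit introduces: the platform-specific dependency is imported inside the function that needs it, so importing the conversion module succeeds on machines where bitsandbytes cannot be installed. The quantization step below uses plain `bnb.nn.Linear8bitLt` as an illustrative stand-in; the actual Petals code swaps in `CustomLinear8bitLt` from `petals.utils.linear8bitlt_patch`, and everything past the recursion is an assumption rather than part of this diff.

# Minimal sketch of the deferred-import pattern, assuming a plain
# bitsandbytes Linear8bitLt as a stand-in for Petals' CustomLinear8bitLt.
import torch.nn as nn


def replace_8bit_linear(model: nn.Module, threshold: float = 6.0) -> nn.Module:
    # Deferred import: only raises if 8-bit quantization is actually requested,
    # so importing the enclosing module works without bitsandbytes installed.
    import bitsandbytes as bnb

    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_8bit_linear(module, threshold)  # recurse into submodules

        if isinstance(module, nn.Linear):
            # Hypothetical replacement step: swap the dense layer for an 8-bit
            # one (the real Petals code uses CustomLinear8bitLt here).
            quantized = bnb.nn.Linear8bitLt(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
                has_fp16_weights=False,
                threshold=threshold,
            )
            setattr(model, name, quantized)
    return model

The trade-off is that a missing bitsandbytes now surfaces as an ImportError at call time rather than at module import, which is the intended behavior: servers that never request 8-bit quantization never touch bitsandbytes at all.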