Test that bitsandbytes is not imported when it's not used (#351)

We avoid importing bitsandbytes when it is not used, since bitsandbytes does not always find the correct CUDA libraries and may raise exceptions as a result.
pull/354/head
Alexander Borzunov 10 months ago committed by GitHub
parent c511990236
commit 1a78638c02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -33,7 +33,7 @@ python_requires = >=3.7
install_requires =
torch>=1.12
bitsandbytes==0.40.0.post4
accelerate>=0.16.0,<1.0.0
accelerate>=0.16.0,<0.21.0
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3
transformers>=4.30.1,<5.0.0

@ -30,7 +30,6 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
from petals.utils.peft import estimate_adapter_memory_per_block
from petals.utils.version import get_compatible_model_repo
logger = get_logger(__name__)

@ -1,3 +1,6 @@
import subprocess
import sys
import pytest
import torch
@ -7,6 +10,19 @@ from petals.utils.convert_block import QuantType
from test_utils import MODEL_NAME
def test_bnb_not_imported_when_unnecessary():
    """
    Check that a plain ``import petals`` does not load bitsandbytes.

    bitsandbytes doesn't always find the correct CUDA libs and may raise exceptions because of that,
    so it should only be imported when it is actually used.

    If this test fails, move the import of bitsandbytes and/or petals.utils.peft from the top of the file
    into the function/method that actually needs it. This won't slow down the code: importing a module
    for the 2nd time doesn't rerun the module code.
    """
    # Run the check in a fresh interpreter so that modules already loaded
    # by this test session cannot affect the contents of sys.modules.
    probe_code = "import petals, sys; assert 'bitsandbytes' not in sys.modules"
    command = [sys.executable, "-c", probe_code]
    subprocess.check_call(command)
@pytest.mark.forked
@pytest.mark.parametrize("tensor_parallel", [False, True])
def test_compute_throughput(tensor_parallel: bool):

@ -25,7 +25,7 @@ def test_sequence_manager_basics(mode: str):
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
sequential = RemoteSequential(
config,
sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
sequence_manager=RemoteSequenceManagerWithChecks(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
)
sequence = sequential.sequence_manager.make_sequence(mode=mode)
@ -43,7 +43,7 @@ def test_sequence_manager_basics(mode: str):
assert shutdown_evt.is_set()
class TestSequenceManager(RemoteSequenceManager):
class RemoteSequenceManagerWithChecks(RemoteSequenceManager):
"""A sequence manager that signals if it was shut down"""
def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):

Loading…
Cancel
Save