diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index ac43759..2b9c0fe 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -13,7 +13,7 @@ from transformers import BloomConfig
 
 from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_block import convert_block
+from petals.utils.convert_block import convert_block, replace_8bit_linear
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__name__)
@@ -149,7 +149,9 @@ def measure_compute_rps(
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
         block = WrappedBloomBlock(config).to(dtype)
-        block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
+        if load_in_8bit:
+            block = replace_8bit_linear(block)
+        block = convert_block(block, config, tensor_parallel_devices, device, freeze=True)
 
         cache = None
         elapsed = 0