|
|
|
@ -13,7 +13,7 @@ from transformers import BloomConfig
|
|
|
|
|
|
|
|
|
|
from petals.bloom.block import WrappedBloomBlock
|
|
|
|
|
from petals.server.block_utils import resolve_block_dtype
|
|
|
|
|
from petals.utils.convert_block import convert_block
|
|
|
|
|
from petals.utils.convert_block import convert_block, replace_8bit_linear
|
|
|
|
|
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
@ -149,7 +149,9 @@ def measure_compute_rps(
|
|
|
|
|
tensor_parallel_devices = (device,)
|
|
|
|
|
with torch.inference_mode():
|
|
|
|
|
block = WrappedBloomBlock(config).to(dtype)
|
|
|
|
|
block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
|
|
|
|
|
if load_in_8bit:
|
|
|
|
|
block = replace_8bit_linear(block)
|
|
|
|
|
block = convert_block(block, config, tensor_parallel_devices, device, freeze=True)
|
|
|
|
|
|
|
|
|
|
cache = None
|
|
|
|
|
elapsed = 0
|
|
|
|
|