Remove load_in_8bit from convert_block

pull/273/head
Max Ryabinin 1 year ago
parent 556f0fabe0
commit a610f4d744

@@ -13,7 +13,7 @@ from transformers import BloomConfig
from petals.bloom.block import WrappedBloomBlock
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import convert_block
from petals.utils.convert_block import convert_block, replace_8bit_linear
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
logger = get_logger(__name__)
@@ -149,7 +149,9 @@ def measure_compute_rps(
tensor_parallel_devices = (device,)
with torch.inference_mode():
block = WrappedBloomBlock(config).to(dtype)
block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
if load_in_8bit:
block = replace_8bit_linear(block)
block = convert_block(block, config, tensor_parallel_devices, device, freeze=True)
cache = None
elapsed = 0

Loading…
Cancel
Save