From aecf074f254ff0230678b864da08c90ebabb9b7d Mon Sep 17 00:00:00 2001
From: Artem Chumachenko
Date: Mon, 8 Apr 2024 18:04:50 +0200
Subject: [PATCH] fix block init

---
 src/petals/server/block_utils.py     |  4 ++--
 src/petals/server/from_pretrained.py |  7 ++-----
 src/petals/server/throughput.py      |  5 +++--
 src/petals/utils/convert_block.py    | 16 +++++++++++++++-
 4 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py
index ac0995d..1320737 100644
--- a/src/petals/server/block_utils.py
+++ b/src/petals/server/block_utils.py
@@ -4,7 +4,7 @@ import torch
 from accelerate import init_empty_weights
 from transformers import PretrainedConfig
 
-from petals.utils.convert_block import QuantType
+from petals.utils.convert_block import QuantType, get_model_block
 from petals.utils.misc import get_size_in_bytes
 
 
@@ -32,7 +32,7 @@ def get_block_size(
         ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
 
     with init_empty_weights(include_buffers=True):
-        block = config.block_class(config)
+        block = get_model_block(config)
         n_params = sum(param.numel() for param in block.parameters())
 
     if location == "memory":
diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py
index 95cfbd8..61a296c 100644
--- a/src/petals/server/from_pretrained.py
+++ b/src/petals/server/from_pretrained.py
@@ -26,6 +26,7 @@ from petals.constants import DTYPE_MAP
 from petals.models.mixtral import WrappedMixtralBlock
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.convert_block import get_model_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.hf_auth import always_needs_auth
 
@@ -52,11 +53,7 @@ def load_pretrained_block(
     torch_dtype = resolve_block_dtype(config, torch_dtype)
 
     with init_empty_weights():
-        if config.block_class == WrappedMixtralBlock:
-            config = PreTrainedModel._autoset_attn_implementation(config)
-            block = config.block_class(config, block_index)
-        else:
-            block = config.block_class(config)
+        block = get_model_block(config, layer_idx=block_index)
 
     block_prefix = f"{config.block_prefix}.{block_index}."
     state_dict = _load_state_dict_from_repo(
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index c42bdb9..27af68f 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -14,7 +14,7 @@ from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_block import QuantType, convert_block
+from petals.utils.convert_block import QuantType, convert_block, get_model_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__name__)
@@ -201,7 +201,8 @@ def measure_compute_rps(
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
-        block = config.block_class(config).to(dtype)
+        block = get_model_block(config)
+        block = block.to(dtype)
         block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
         cache = None
diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
index 94d3e29..db96ab8 100644
--- a/src/petals/utils/convert_block.py
+++ b/src/petals/utils/convert_block.py
@@ -12,6 +12,8 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
 
+from petals.models.mixtral import WrappedMixtralBlock
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
@@ -154,3 +156,15 @@ def check_device_balance(devices: Sequence[torch.device]):
             f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
             f"Consider running high-memory GPUs in a separate server."
         )
+
+
+def get_model_block(config, **kwargs):
+    """
+    Create a model block based on the config's block class.
+    Extra kwargs (e.g. layer_idx) are only needed for specific block classes, such as Mixtral;
+    they are not passed to other block constructors.
+    """
+    if config.block_class == WrappedMixtralBlock:
+        config = PreTrainedModel._autoset_attn_implementation(config)
+        return config.block_class(config, kwargs.get("layer_idx", 0))
+    return config.block_class(config)
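
For reference, a minimal usage sketch of the new get_model_block helper, mirroring the call sites
updated above. This is illustrative only and not part of the applied diff; the model name is just an
example, and any config produced by AutoDistributedConfig (which defines block_class) is expected to
work the same way.

    from accelerate import init_empty_weights

    from petals.utils.auto_config import AutoDistributedConfig
    from petals.utils.convert_block import get_model_block

    # Example model name; substitute any model supported by Petals.
    config = AutoDistributedConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

    with init_empty_weights():
        # For Mixtral-style blocks, layer_idx is forwarded to the block constructor;
        # for all other block classes the kwarg is ignored.
        block = get_model_block(config, layer_idx=0)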