fix block init

pull/570/head
Artem Chumachenko 1 month ago
parent d2fcbbc72e
commit aecf074f25

src/petals/server/block_utils.py

@@ -4,7 +4,7 @@ import torch
 from accelerate import init_empty_weights
 from transformers import PretrainedConfig
 
-from petals.utils.convert_block import QuantType
+from petals.utils.convert_block import QuantType, get_model_block
 from petals.utils.misc import get_size_in_bytes

@@ -32,7 +32,7 @@ def get_block_size(
     ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
 
     with init_empty_weights(include_buffers=True):
-        block = config.block_class(config)
+        block = get_model_block(config)
         n_params = sum(param.numel() for param in block.parameters())
 
     if location == "memory":

src/petals/server/from_pretrained.py

@@ -26,6 +26,7 @@ from petals.constants import DTYPE_MAP
 from petals.models.mixtral import WrappedMixtralBlock
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.convert_block import get_model_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.hf_auth import always_needs_auth

@@ -52,11 +53,7 @@ def load_pretrained_block(
     torch_dtype = resolve_block_dtype(config, torch_dtype)
 
     with init_empty_weights():
-        if config.block_class == WrappedMixtralBlock:
-            config = PreTrainedModel._autoset_attn_implementation(config)
-            block = config.block_class(config, block_index)
-        else:
-            block = config.block_class(config)
+        block = get_model_block(config, layer_idx=block_index)
 
     block_prefix = f"{config.block_prefix}.{block_index}."
     state_dict = _load_state_dict_from_repo(

src/petals/server/throughput.py

@@ -14,7 +14,7 @@ from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_block import QuantType, convert_block
+from petals.utils.convert_block import QuantType, convert_block, get_model_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__name__)

@@ -201,7 +201,8 @@ def measure_compute_rps(
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
-        block = config.block_class(config).to(dtype)
+        block = get_model_block(config)
+        block = block.to(dtype)
         block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
         cache = None

src/petals/utils/convert_block.py

@@ -12,6 +12,8 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
 
+from petals.models.mixtral import WrappedMixtralBlock
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -154,3 +156,15 @@ def check_device_balance(devices: Sequence[torch.device]):
             f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
             f"Consider running high-memory GPUs in a separate server."
         )
+
+
+def get_model_block(config, **kwargs):
+    """
+    Create a model block based on the config's block class.
+    The kwargs are **only** necessary for specific classes, like Mixtral;
+    they are not passed to other block constructors.
+    """
+    if config.block_class == WrappedMixtralBlock:
+        PreTrainedModel._autoset_attn_implementation(config)
+        return config.block_class(config, kwargs.get("layer_idx", 0))
+    return config.block_class(config)
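For context, get_model_block() centralizes block construction for every call site touched above. A minimal usage sketch, assuming a config loaded via AutoDistributedConfig (the model name is only an illustrative placeholder):

from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import get_model_block

# Placeholder model name; any model supported by Petals works the same way.
config = AutoDistributedConfig.from_pretrained("bigscience/bloom-560m")

# Size estimation and throughput benchmarks build a block without a layer index.
block = get_model_block(config)

# Loading real weights passes the layer index; only Mixtral-style blocks consume it,
# other block constructors never receive the extra kwarg.
block = get_model_block(config, layer_idx=3)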
