diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py
index 1320737..8c12d51 100644
--- a/src/petals/server/block_utils.py
+++ b/src/petals/server/block_utils.py
@@ -2,9 +2,10 @@ from typing import Optional, Union
 
 import torch
 from accelerate import init_empty_weights
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
 
-from petals.utils.convert_block import QuantType, get_model_block
+from petals.models.mixtral.block import WrappedMixtralBlock
+from petals.utils.convert_block import QuantType
 from petals.utils.misc import get_size_in_bytes
 
 
@@ -50,3 +51,15 @@ def get_block_size(
         bytes_per_value = get_size_in_bytes(dtype)
 
     return round(n_params * bytes_per_value * (1 + eps))
+
+
+def get_model_block(config, **kwargs):
+    """
+    Create a model block based on the config's block class.
+    The kwargs argument is necessary **only** for specific classes, like Mixtral;
+    it will not be passed to other block constructors.
+    """
+    if config.block_class == WrappedMixtralBlock:
+        PreTrainedModel._autoset_attn_implementation(config)
+        return config.block_class(config, kwargs.get("layer_idx", 0))
+    return config.block_class(config)
diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py
index 61a296c..e7c1658 100644
--- a/src/petals/server/from_pretrained.py
+++ b/src/petals/server/from_pretrained.py
@@ -24,9 +24,8 @@ from transformers.utils import get_file_from_repo
 
 from petals.constants import DTYPE_MAP
 from petals.models.mixtral import WrappedMixtralBlock
-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import resolve_block_dtype, get_model_block
 from petals.utils.auto_config import AutoDistributedConfig
-from petals.utils.convert_block import get_model_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.hf_auth import always_needs_auth
 
diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py
index 27af68f..eab392f 100644
--- a/src/petals/server/throughput.py
+++ b/src/petals/server/throughput.py
@@ -13,8 +13,8 @@ import torch.mps
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
-from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_block import QuantType, convert_block, get_model_block
+from petals.server.block_utils import resolve_block_dtype, get_model_block
+from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__name__)
diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
index db96ab8..94d3e29 100644
--- a/src/petals/utils/convert_block.py
+++ b/src/petals/utils/convert_block.py
@@ -12,8 +12,6 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
-from petals.models.mixtral import WrappedMixtralBlock
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
@@ -156,15 +154,3 @@ def check_device_balance(devices: Sequence[torch.device]):
             f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
             f"Consider running high-memory GPUs in a separate server."
         )
-
-
-def get_model_block(config, **kwargs):
-    """
-    The function to create a model block based on the block class
-    kwargs argument **only** is necessary for specific classes, like Mixtral.
-    They will not be passed to other block constructors.
-    """
-    if config.block_class == WrappedMixtralBlock:
-        PreTrainedModel._autoset_attn_implementation(config)
-        return config.block_class(config, kwargs.get("layer_idx", 0))
-    return config.block_class(config)
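
Usage note (not part of the patch): the sketch below illustrates how callers such as get_block_size() in block_utils.py, from_pretrained.py, and throughput.py are expected to use the relocated helper, instantiating a block on the meta device and passing layer_idx, which only Mixtral-style blocks consume. The model name is a hypothetical example, not something taken from this diff; any checkpoint understood by AutoDistributedConfig works the same way.

    # Illustrative sketch only; the model name below is an assumption, not part of this diff.
    from accelerate import init_empty_weights

    from petals.server.block_utils import get_model_block
    from petals.utils.auto_config import AutoDistributedConfig

    config = AutoDistributedConfig.from_pretrained("bigscience/bloom-560m")  # any supported model
    with init_empty_weights(include_buffers=True):
        # layer_idx is read only by WrappedMixtralBlock; other block classes ignore the kwarg
        block = get_model_block(config, layer_idx=0)
    n_params = sum(param.numel() for param in block.parameters())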