fix imports

pull/570/head
Artem Chumachenko 1 month ago
parent aecf074f25
commit f06cfd2b97

@@ -2,9 +2,10 @@ from typing import Optional, Union
 
 import torch
 from accelerate import init_empty_weights
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
 
-from petals.utils.convert_block import QuantType, get_model_block
+from petals.models.mixtral.block import WrappedMixtralBlock
+from petals.utils.convert_block import QuantType
 from petals.utils.misc import get_size_in_bytes
@@ -50,3 +51,15 @@ def get_block_size(
     bytes_per_value = get_size_in_bytes(dtype)
     return round(n_params * bytes_per_value * (1 + eps))
+
+
+def get_model_block(config, **kwargs):
+    """
+    Create a model block based on the config's block class.
+    The kwargs are only necessary for specific classes, like Mixtral;
+    they will not be passed to other block constructors.
+    """
+    if config.block_class == WrappedMixtralBlock:
+        PreTrainedModel._autoset_attn_implementation(config)
+        return config.block_class(config, kwargs.get("layer_idx", 0))
+    return config.block_class(config)
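
For reference, a minimal usage sketch of the relocated helper (not part of the commit). It assumes AutoDistributedConfig.from_pretrained resolves a config whose block_class is set by petals; the repo name is purely illustrative:

    from accelerate import init_empty_weights

    from petals.server.block_utils import get_model_block
    from petals.utils.auto_config import AutoDistributedConfig

    # Illustrative repo name; any model family supported by petals works here.
    config = AutoDistributedConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")

    # Mixtral blocks take a layer_idx (and need an attention implementation set on
    # the config); other block classes are constructed from the config alone.
    with init_empty_weights():
        block = get_model_block(config, layer_idx=0)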

@@ -24,9 +24,8 @@ from transformers.utils import get_file_from_repo
 
 from petals.constants import DTYPE_MAP
 from petals.models.mixtral import WrappedMixtralBlock
-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import resolve_block_dtype, get_model_block
 from petals.utils.auto_config import AutoDistributedConfig
-from petals.utils.convert_block import get_model_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.hf_auth import always_needs_auth

@@ -13,8 +13,8 @@ import torch.mps
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
-from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_block import QuantType, convert_block, get_model_block
+from petals.server.block_utils import resolve_block_dtype, get_model_block
+from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__name__)
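
The throughput path combines the two relocated helpers in the same way: materialize an empty block, then cast it to the serving dtype. A rough sketch under assumptions (make_benchmark_block is a hypothetical helper, and the resolve_block_dtype argument order and its "auto" sentinel are my reading of that function, not something this diff shows):

    from accelerate import init_empty_weights

    from petals.server.block_utils import get_model_block, resolve_block_dtype

    def make_benchmark_block(config, layer_idx: int = 0):
        # Build one block without allocating real weights, then convert it to
        # the dtype the server would actually run it in.
        dtype = resolve_block_dtype(config, "auto")
        with init_empty_weights():
            block = get_model_block(config, layer_idx=layer_idx)
        return block.to(dtype)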

@@ -12,8 +12,6 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
-from petals.models.mixtral import WrappedMixtralBlocks
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -156,15 +154,3 @@ def check_device_balance(devices: Sequence[torch.device]):
             f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
             f"Consider running high-memory GPUs in a separate server."
         )
-
-
-def get_model_block(config, **kwargs):
-    """
-    The function to create a model block based on the block class
-    kwargs argument **only** is necessary for specific classes, like Mixtral.
-    They will not be passed to other block constructors.
-    """
-    if config.block_class == WrappedMixtralBlocks:
-        PreTrainedModel._autoset_attn_implementation(config)
-        return config.block_class(config, kwargs.get("layer_idx", 0))
-    return config.block_class(config)
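
In the Mixtral branch above, the call to the private transformers hook PreTrainedModel._autoset_attn_implementation(config) is what fills in config._attn_implementation before the block is constructed directly (i.e., outside the usual PreTrainedModel loading path). A hedged illustration; since this is a private API, the exact behavior depends on the transformers version:

    from transformers import PreTrainedModel

    from petals.utils.auto_config import AutoDistributedConfig

    config = AutoDistributedConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")  # illustrative
    PreTrainedModel._autoset_attn_implementation(config)
    print(config._attn_implementation)  # e.g. "sdpa" or "eager", depending on the environment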
