|
|
@ -26,6 +26,7 @@ from petals.constants import DTYPE_MAP
|
|
|
|
from petals.models.mixtral import WrappedMixtralBlock
|
|
|
|
from petals.models.mixtral import WrappedMixtralBlock
|
|
|
|
from petals.server.block_utils import resolve_block_dtype
|
|
|
|
from petals.server.block_utils import resolve_block_dtype
|
|
|
|
from petals.utils.auto_config import AutoDistributedConfig
|
|
|
|
from petals.utils.auto_config import AutoDistributedConfig
|
|
|
|
|
|
|
|
from petals.utils.convert_block import get_model_block
|
|
|
|
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
|
|
|
|
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
|
|
|
|
from petals.utils.hf_auth import always_needs_auth
|
|
|
|
from petals.utils.hf_auth import always_needs_auth
|
|
|
|
|
|
|
|
|
|
|
@ -52,11 +53,7 @@ def load_pretrained_block(
|
|
|
|
torch_dtype = resolve_block_dtype(config, torch_dtype)
|
|
|
|
torch_dtype = resolve_block_dtype(config, torch_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
with init_empty_weights():
|
|
|
|
with init_empty_weights():
|
|
|
|
if config.block_class == WrappedMixtralBlock:
|
|
|
|
block = get_model_block(config, layer_idx=block_index)
|
|
|
|
config = PreTrainedModel._autoset_attn_implementation(config)
|
|
|
|
|
|
|
|
block = config.block_class(config, block_index)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
block = config.block_class(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
block_prefix = f"{config.block_prefix}.{block_index}."
|
|
|
|
block_prefix = f"{config.block_prefix}.{block_index}."
|
|
|
|
state_dict = _load_state_dict_from_repo(
|
|
|
|
state_dict = _load_state_dict_from_repo(
|
|
|
|