Fix get_model_block

pull/570/head
Artem Chumachenko 2 months ago
parent 2ca531623c
commit ba271dc626

@@ -135,7 +135,6 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         return self.norm

 class DistributedMixtralForCausalLM(
     DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
 ):

@@ -17,7 +17,7 @@ from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.misc import get_size_in_bytes
@@ -273,7 +273,7 @@ def estimate_adapter_memory_per_block(
 ) -> int:
     """Get the number of extra bytes used to store a set of adapters per given block"""
     with init_empty_weights(include_buffers=True):
-        block = block_config.block_class(block_config)
+        block = get_model_block(block_config)
         base_block_parameters = sum(p.numel() for p in block.parameters())
         create_lora_adapter(block, quant_type=QuantType.NONE)
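
For context, the hunk above swaps the direct block_config.block_class(block_config) call for the shared get_model_block helper imported from petals.server.block_utils. A minimal sketch of what such a helper can look like, assuming its only job is to hide that some decoder-layer classes (Mixtral's, for instance) take a layer index at construction time while the rest take just the config; the dispatch condition and exact signature shown here are assumptions, not the actual block_utils.py code:

# Sketch only: assumed shape of a get_model_block helper, not the exact
# petals.server.block_utils implementation.
from torch import nn
from transformers import PretrainedConfig

def get_model_block(config: PretrainedConfig, layer_idx: int = 0) -> nn.Module:
    # Mixtral-style blocks expect (config, layer_idx); most block classes
    # expect only (config). The model_type check is an assumed dispatch rule.
    if config.model_type == "mixtral":
        return config.block_class(config, layer_idx)
    return config.block_class(config)

The test hunk further down then builds blocks with get_model_block(config, layer_idx=block_idx) and passes the same block_idx to convert_block, so the optimized and unoptimized paths are constructed with a consistent layer index.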

@@ -7,6 +7,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_m
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+from petals.server.block_utils import get_model_block
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
 from test_utils import MODEL_NAME
@@ -195,8 +196,9 @@ def test_optimized_block(device):
     dtype = torch.bfloat16
     quant_type = QuantType.NONE
-    block = config.block_class(config).to(dtype)
-    block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+    block_idx = 1
+    block = get_model_block(config, layer_idx=block_idx).to(dtype)
+    block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
     if config.model_type == "falcon":
         unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
@@ -206,7 +208,7 @@
         pytest.skip(f"This test is not applicable to {config.model_type} models")
     unopt_block = convert_block(
-        unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
+        unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
     )
     unopt_block.load_state_dict(block.state_dict())
