@@ -7,6 +7,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
+from petals.server.block_utils import get_model_block
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
 from test_utils import MODEL_NAME
@@ -195,8 +196,9 @@ def test_optimized_block(device):
     dtype = torch.bfloat16
     quant_type = QuantType.NONE
 
-    block = config.block_class(config).to(dtype)
-    block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+    block_idx = 1
+    block = get_model_block(config, layer_idx=block_idx).to(dtype)
+    block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
     if config.model_type == "falcon":
         unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
@@ -206,7 +208,7 @@ def test_optimized_block(device):
         pytest.skip(f"This test is not applicable to {config.model_type} models")
 
     unopt_block = convert_block(
-        unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
+        unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
    )
 
     unopt_block.load_state_dict(block.state_dict())
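Note on the new helper: the import added in the first hunk pulls `get_model_block` from `petals/server/block_utils.py`, and the call sites now thread an explicit `block_idx` through instead of hard-coding `1`. A minimal sketch of what such a factory could look like follows; only the helper's name, module path, and `layer_idx` keyword are confirmed by this diff, and the body is an assumption:

```python
import inspect


def get_model_block(config, layer_idx: int = 0):
    """Build a single transformer block from a model config.

    Hypothetical sketch: only the name, module path, and the layer_idx
    keyword are confirmed by the diff above.
    """
    # Some block classes are position-aware (their constructor takes the
    # layer index); pass layer_idx only when the signature accepts it.
    params = inspect.signature(config.block_class.__init__).parameters
    if "layer_idx" in params:
        return config.block_class(config, layer_idx)
    return config.block_class(config)
```

Dispatching on the constructor signature would keep existing block classes working unchanged while letting position-aware blocks receive their index, which is consistent with why the test now passes `layer_idx=block_idx` rather than constructing `config.block_class(config)` directly.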