|
|
|
@ -2,9 +2,10 @@ from typing import Optional, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from accelerate import init_empty_weights
|
|
|
|
|
from transformers import PretrainedConfig
|
|
|
|
|
from transformers import PretrainedConfig, PreTrainedModel
|
|
|
|
|
|
|
|
|
|
from petals.utils.convert_block import QuantType, get_model_block
|
|
|
|
|
from petals.models.mixtral.block import WrappedMixtralBlock
|
|
|
|
|
from petals.utils.convert_block import QuantType
|
|
|
|
|
from petals.utils.misc import get_size_in_bytes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -50,3 +51,15 @@ def get_block_size(
|
|
|
|
|
bytes_per_value = get_size_in_bytes(dtype)
|
|
|
|
|
|
|
|
|
|
return round(n_params * bytes_per_value * (1 + eps))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model_block(config, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
The function to create a model block based on the block class
|
|
|
|
|
kwargs argument **only** is necessary for specific classes, like Mixtral.
|
|
|
|
|
They will not be passed to other block constructors.
|
|
|
|
|
"""
|
|
|
|
|
if config.block_class == WrappedMixtralBlock:
|
|
|
|
|
PreTrainedModel._autoset_attn_implementation(config)
|
|
|
|
|
return config.block_class(config, kwargs.get("layer_idx", 0))
|
|
|
|
|
return config.block_class(config)
|
|
|
|
|