|
|
|
@ -53,7 +53,7 @@ def get_block_size(
|
|
|
|
|
return round(n_params * bytes_per_value * (1 + eps))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_model_block(config, **kwargs):
|
|
|
|
|
def get_model_block(config, layer_idx: int = 0):
|
|
|
|
|
"""
|
|
|
|
|
The function to create a model block based on the block class
|
|
|
|
|
kwargs argument **only** is necessary for specific classes, like Mixtral.
|
|
|
|
@ -61,5 +61,5 @@ def get_model_block(config, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
if config.block_class == WrappedMixtralBlock:
|
|
|
|
|
config = PreTrainedModel._autoset_attn_implementation(config)
|
|
|
|
|
return config.block_class(config, kwargs.get("layer_idx", 0))
|
|
|
|
|
return config.block_class(config, layer_idx)
|
|
|
|
|
return config.block_class(config)
|
|
|
|
|