@@ -12,6 +12,8 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
+
+from petals.models.mixtral import WrappedMixtralBlocks
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
@@ -154,3 +156,15 @@ def check_device_balance(devices: Sequence[torch.device]):
             f"GPU devices have highly uneven memory, {wasted_memory_rate * 100:.2f}% memory is wasted. "
             f"Consider running high-memory GPUs in a separate server."
         )
+
+
+def get_model_block(config, **kwargs):
+    """
+    Create a model block based on the block class specified in the config.
+    The **kwargs are consumed only by specific block classes, such as Mixtral;
+    they are not forwarded to other block constructors.
+    """
+    if config.block_class == WrappedMixtralBlocks:
+        PreTrainedModel._autoset_attn_implementation(config)
+        return config.block_class(config, kwargs.get("layer_idx", 0))
+    return config.block_class(config)
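
For intuition on the wasted-memory warning in the context lines above: assuming tensor parallelism shards each block evenly, so every GPU can only contribute as much memory as the smallest device, a hypothetical 24 GB + 16 GB pair would use only 2 x 16 = 32 GB of its 40 GB total, i.e. roughly 20% of the memory would sit idle.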
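
To illustrate how the new get_model_block dispatches on config.block_class, here is a minimal, hypothetical sketch; DummyBlock and DummyConfig are illustrative stand-ins rather than Petals classes, and the import path is a guess, not something the diff confirms:

    # Hypothetical sketch; assumes get_model_block is importable from the
    # module this diff touches, e.g.:
    # from petals.utils.convert_block import get_model_block

    class DummyBlock:
        def __init__(self, config):
            self.config = config

    class DummyConfig:
        block_class = DummyBlock  # not WrappedMixtralBlocks, so kwargs are ignored

    block = get_model_block(DummyConfig())
    # Passing layer_idx is harmless here; it is only consumed when
    # config.block_class is the Mixtral wrapper:
    block = get_model_block(DummyConfig(), layer_idx=3)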