@@ -13,6 +13,8 @@ import time
 from typing import Optional, OrderedDict, Union
 
 import torch
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
 from hivemind.utils.logging import get_logger
 from transformers.modeling_utils import WEIGHTS_NAME
 from transformers.models.bloom.configuration_bloom import BloomConfig
@@ -38,13 +40,16 @@ def load_pretrained_block(
     max_disk_space: Optional[int] = None,
 ) -> WrappedBloomBlock:
     """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
+    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
 
-    block = WrappedBloomBlock(config)
+    with init_empty_weights():
+        block = WrappedBloomBlock(config)
+
     state_dict = _load_state_dict(
         converted_model_name_or_path,
         block_index,
@@ -54,16 +59,17 @@ def load_pretrained_block(
         max_disk_space=max_disk_space,
     )
 
-    if torch_dtype == "auto":
-        with torch.no_grad():
-            for name, param in block.named_parameters():
-                assert name in state_dict, f"{name} not in state dict"
-                param.data = param.data.to(state_dict[name].dtype)
-    else:
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        block = block.to(dtype=torch_dtype)
-
+    # dummy load, check that keys match
     report = block.load_state_dict(state_dict, strict=True)
     assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
+
+    for param_name, _ in block.named_parameters():
+        assert param_name in state_dict, f"{param_name} not in state dict"
+        param = state_dict[param_name]
+        if torch_dtype != "auto" and not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+            param = param.to(torch_dtype)
+        set_module_tensor_to_device(block, param_name, "cpu", value=param)
+
     logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
     return block
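
Note: the change above builds the block inside accelerate's init_empty_weights() context (parameters live on the meta device, so no real storage is allocated) and then materializes each parameter straight from the checkpoint tensors with set_module_tensor_to_device, instead of randomly initializing the block and overwriting it. The sketch below is a minimal, self-contained illustration of that pattern, not the project's own code; the `layer` module and the in-memory `state_dict` are made-up stand-ins, while init_empty_weights and set_module_tensor_to_device are the same accelerate utilities imported in the diff.

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Construct the module on the "meta" device: shapes are known,
# but no real storage is allocated for the parameters yet.
with init_empty_weights():
    layer = torch.nn.Linear(1024, 1024)

# Stand-in for tensors loaded from a checkpoint shard on disk.
state_dict = {"weight": torch.randn(1024, 1024), "bias": torch.zeros(1024)}

# Materialize each parameter directly from the loaded tensor,
# so the weights are allocated only once (cast dtypes here if needed).
for name, _ in layer.named_parameters():
    set_module_tensor_to_device(layer, name, "cpu", value=state_dict[name])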