|
|
|
@ -11,18 +11,18 @@ from __future__ import annotations
|
|
|
|
|
from typing import Optional, OrderedDict, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind.utils.logging import use_hivemind_log_handler, get_logger
|
|
|
|
|
from transformers.utils.hub import hf_bucket_url, cached_path
|
|
|
|
|
|
|
|
|
|
from src.bloom import BloomForCausalLM, DistributedBloomConfig, BloomBlock
|
|
|
|
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
|
|
|
from transformers.modeling_utils import WEIGHTS_NAME
|
|
|
|
|
from transformers.utils.hub import cached_path, hf_bucket_url
|
|
|
|
|
|
|
|
|
|
from src.bloom import BloomBlock, BloomForCausalLM, DistributedBloomConfig
|
|
|
|
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
|
CLIENT_BRANCH = "client"
|
|
|
|
|
BLOCK_BRANCH_PREFIX = "block_"
|
|
|
|
|
USER_AGENT = {'file_type': 'model', 'framework': 'pytorch', 'from_auto_class': False}
|
|
|
|
|
USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
|
|
|
|
|
cls = BloomForCausalLM
|
|
|
|
|
FORCE_DOWNLOAD = False
|
|
|
|
|
RESUME_DOWNLOAD = False
|
|
|
|
@ -30,8 +30,11 @@ LOCAL_FILES_ONLY = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_pretrained_block(
|
|
|
|
|
converted_model_name_or_path: str, block_index: int,
|
|
|
|
|
config: Optional[DistributedBloomConfig] = None, torch_dtype: Union[torch.dtype, str] = 'auto') -> BloomBlock:
|
|
|
|
|
converted_model_name_or_path: str,
|
|
|
|
|
block_index: int,
|
|
|
|
|
config: Optional[DistributedBloomConfig] = None,
|
|
|
|
|
torch_dtype: Union[torch.dtype, str] = "auto",
|
|
|
|
|
) -> BloomBlock:
|
|
|
|
|
"""Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
|
|
|
|
|
if config is None:
|
|
|
|
|
config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path)
|
|
|
|
@ -39,7 +42,7 @@ def load_pretrained_block(
|
|
|
|
|
state_dict = _load_state_dict(converted_model_name_or_path, block_index)
|
|
|
|
|
block.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
|
if torch_dtype == 'auto':
|
|
|
|
|
if torch_dtype == "auto":
|
|
|
|
|
with torch.no_grad():
|
|
|
|
|
for name, param in block.named_parameters():
|
|
|
|
|
assert name in state_dict, f"{name} not in state dict"
|
|
|
|
@ -54,7 +57,8 @@ def load_pretrained_block(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_state_dict(
|
|
|
|
|
pretrained_model_name_or_path: str, block_index: Optional[int] = None) -> OrderedDict[str, torch.Tensor]:
|
|
|
|
|
pretrained_model_name_or_path: str, block_index: Optional[int] = None
|
|
|
|
|
) -> OrderedDict[str, torch.Tensor]:
|
|
|
|
|
revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
|
|
|
|
|
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
|
|
|
|
|
|
|
|
|
@ -69,7 +73,7 @@ def _load_state_dict(
|
|
|
|
|
use_auth_token=True,
|
|
|
|
|
user_agent=USER_AGENT,
|
|
|
|
|
)
|
|
|
|
|
state_dict = torch.load(resolved_archive_file, map_location='cpu')
|
|
|
|
|
state_dict = torch.load(resolved_archive_file, map_location="cpu")
|
|
|
|
|
return state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|