This commit is contained in:
dbaranchuk 2022-08-03 12:47:36 +03:00
parent 5200dc7029
commit e297ae606f

View File

@ -34,13 +34,15 @@ def load_pretrained_block(
config: Optional[BloomConfig] = None,
torch_dtype: Union[torch.dtype, str] = "auto",
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None
cache_dir: Optional[str] = None,
) -> BloomBlock:
"""Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
if config is None:
config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
block = BloomBlock(config, layer_number=block_index)
state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir)
state_dict = _load_state_dict(
converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
)
block.load_state_dict(state_dict)
if torch_dtype == "auto":
@ -58,10 +60,10 @@ def load_pretrained_block(
def _load_state_dict(
pretrained_model_name_or_path: str,
block_index: Optional[int] = None,
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None
pretrained_model_name_or_path: str,
block_index: Optional[int] = None,
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
) -> OrderedDict[str, torch.Tensor]:
revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)