|
|
|
@ -34,12 +34,12 @@ def load_pretrained_block(
|
|
|
|
|
config: Optional[PretrainedConfig] = None,
|
|
|
|
|
torch_dtype: Union[torch.dtype, str] = "auto",
|
|
|
|
|
revision: Optional[str] = None,
|
|
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
|
|
token: Optional[str] = None,
|
|
|
|
|
cache_dir: Optional[str] = None,
|
|
|
|
|
max_disk_space: Optional[int] = None,
|
|
|
|
|
) -> nn.Module:
|
|
|
|
|
if config is None:
|
|
|
|
|
config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token)
|
|
|
|
|
config = AutoDistributedConfig.from_pretrained(model_name, token=token)
|
|
|
|
|
if cache_dir is None:
|
|
|
|
|
cache_dir = DEFAULT_CACHE_DIR
|
|
|
|
|
|
|
|
|
@ -54,7 +54,7 @@ def load_pretrained_block(
|
|
|
|
|
model_name,
|
|
|
|
|
block_prefix,
|
|
|
|
|
revision=revision,
|
|
|
|
|
use_auth_token=use_auth_token,
|
|
|
|
|
token=token,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
max_disk_space=max_disk_space,
|
|
|
|
|
)
|
|
|
|
@ -82,12 +82,12 @@ def _load_state_dict_from_repo(
|
|
|
|
|
block_prefix: str,
|
|
|
|
|
*,
|
|
|
|
|
revision: Optional[str] = None,
|
|
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
|
|
token: Optional[str] = None,
|
|
|
|
|
cache_dir: str,
|
|
|
|
|
max_disk_space: Optional[int] = None,
|
|
|
|
|
) -> StateDict:
|
|
|
|
|
index_file = get_file_from_repo(
|
|
|
|
|
model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir
|
|
|
|
|
model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
|
|
|
|
|
)
|
|
|
|
|
if index_file is not None: # Sharded model
|
|
|
|
|
with open(index_file) as f:
|
|
|
|
@ -107,7 +107,7 @@ def _load_state_dict_from_repo(
|
|
|
|
|
model_name,
|
|
|
|
|
filename,
|
|
|
|
|
revision=revision,
|
|
|
|
|
use_auth_token=use_auth_token,
|
|
|
|
|
token=token,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
max_disk_space=max_disk_space,
|
|
|
|
|
)
|
|
|
|
@ -125,7 +125,7 @@ def _load_state_dict_from_file(
|
|
|
|
|
filename: str,
|
|
|
|
|
*,
|
|
|
|
|
revision: Optional[str] = None,
|
|
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
|
|
token: Optional[str] = None,
|
|
|
|
|
cache_dir: str,
|
|
|
|
|
max_disk_space: Optional[int] = None,
|
|
|
|
|
delay: float = 30,
|
|
|
|
@ -137,7 +137,7 @@ def _load_state_dict_from_file(
|
|
|
|
|
model_name,
|
|
|
|
|
filename,
|
|
|
|
|
revision=revision,
|
|
|
|
|
use_auth_token=use_auth_token,
|
|
|
|
|
use_auth_token=token,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
local_files_only=True,
|
|
|
|
|
)
|
|
|
|
@ -151,7 +151,7 @@ def _load_state_dict_from_file(
|
|
|
|
|
try:
|
|
|
|
|
with allow_cache_writes(cache_dir):
|
|
|
|
|
url = hf_hub_url(model_name, filename, revision=revision)
|
|
|
|
|
file_size = get_hf_file_metadata(url, token=use_auth_token).size
|
|
|
|
|
file_size = get_hf_file_metadata(url, token=token).size
|
|
|
|
|
if file_size is not None:
|
|
|
|
|
free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
|
|
|
|
|
else:
|
|
|
|
@ -161,7 +161,7 @@ def _load_state_dict_from_file(
|
|
|
|
|
model_name,
|
|
|
|
|
filename,
|
|
|
|
|
revision=revision,
|
|
|
|
|
use_auth_token=use_auth_token,
|
|
|
|
|
use_auth_token=token,
|
|
|
|
|
cache_dir=cache_dir,
|
|
|
|
|
local_files_only=False,
|
|
|
|
|
)
|
|
|
|
|