Merge branch 'fix-auth-token' into main

pull/18/head
justheuristic 2 years ago
commit 894cd5d586

@@ -55,6 +55,7 @@ def main():
parser.add_argument('--custom_module_path', type=str, required=False,
help='Path of a file with custom nn.modules, wrapped into special decorator')
parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
# fmt:on
args = vars(parser.parse_args())
@@ -66,6 +67,9 @@ def main():
compression_type = args.pop("compression")
compression = getattr(CompressionType, compression_type)
use_auth_token = args.pop("use_auth_token")
args['use_auth_token'] = True if use_auth_token in ('True', 'true', '') else use_auth_token
server = Server.create(**args, start=True, compression=compression)
try:

@@ -34,12 +34,13 @@ def load_pretrained_block(
block_index: int,
config: Optional[DistributedBloomConfig] = None,
torch_dtype: Union[torch.dtype, str] = "auto",
use_auth_token: Optional[str]=None
) -> BloomBlock:
"""Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
if config is None:
config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path)
config = DistributedBloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
block = BloomBlock(config, layer_number=block_index)
state_dict = _load_state_dict(converted_model_name_or_path, block_index)
state_dict = _load_state_dict(converted_model_name_or_path, block_index, use_auth_token=use_auth_token)
block.load_state_dict(state_dict)
if torch_dtype == "auto":
@@ -57,7 +58,7 @@ def load_pretrained_block(
def _load_state_dict(
pretrained_model_name_or_path: str, block_index: Optional[int] = None
pretrained_model_name_or_path: str, block_index: Optional[int] = None, use_auth_token: Optional[str] = None
) -> OrderedDict[str, torch.Tensor]:
revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
@@ -70,7 +71,7 @@ def _load_state_dict(
proxies=None,
resume_download=RESUME_DOWNLOAD,
local_files_only=LOCAL_FILES_ONLY,
use_auth_token=True,
use_auth_token=use_auth_token,
user_agent=USER_AGENT,
)
state_dict = torch.load(resolved_archive_file, map_location="cpu")

@@ -100,6 +100,7 @@ class Server(threading.Thread):
custom_module_path=None,
update_period: float = 30,
expiration: Optional[float] = None,
use_auth_token: Optional[str] = None,
*,
start: bool,
**kwargs,
@@ -121,12 +122,12 @@ class Server(threading.Thread):
if block_indices is not None:
try:
start, end = block_indices.split(":")
start, end = map(int, map(str.strip, (start, end)))
first_block_index, last_block_index = block_indices.split(":")
first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
except Exception as e:
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:33)")
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
raise
block_indices = range(start, end)
block_indices = range(first_block_index, last_block_index)
else:
assert num_blocks is not None
block_indices = range(num_blocks) # TODO replace with proper load balancing
@@ -137,7 +138,13 @@ class Server(threading.Thread):
blocks = {}
for block_index in block_indices:
module_uid = f"{prefix}.{block_index}"
block = load_pretrained_block(converted_model_name_or_path, block_index, block_config, torch_dtype)
block = load_pretrained_block(
converted_model_name_or_path,
block_index,
block_config,
torch_dtype=torch_dtype,
use_auth_token=use_auth_token
)
for param in block.parameters():
param.requires_grad = False

Loading…
Cancel
Save