better status logs

8bit_blocks
justheuristic 2 years ago
parent 1555d98f66
commit 84de19fb1a

@ -40,6 +40,7 @@ if __name__ == "__main__":
):
raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
logger.info(f"Loading source model {args.model}")
model = transformers.AutoModelForCausalLM.from_pretrained(
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
)
@ -52,6 +53,8 @@ if __name__ == "__main__":
repo.git_pull()
transformer_blocks = model.transformer.h
logger.info(f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}")
for i, block in enumerate(tqdm(transformer_blocks)):
repo.git_checkout(args.base_branch, create_branch_ok=True)
with repo.commit(
@ -59,6 +62,7 @@ if __name__ == "__main__":
):
torch.save(block.state_dict(), "./pytorch_model.bin")
logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
repo.git_checkout(args.base_branch, create_branch_ok=True)
with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
model.transformer.h = nn.ModuleList()

Loading…
Cancel
Save