From 899cefe5883953ac5699951b1c656166a4e5ce74 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Thu, 7 Jul 2022 03:16:47 +0300 Subject: [PATCH] set client branch to main by default; remove the concept of base branch (redundant) --- cli/convert_model.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/cli/convert_model.py b/cli/convert_model.py index 667c199..a80a3d7 100644 --- a/cli/convert_model.py +++ b/cli/convert_model.py @@ -23,7 +23,6 @@ if __name__ == "__main__": parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype") parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder") parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo") - parser.add_argument("--base_branch", type=str, default="main", help="Use this branch as reference point") parser.add_argument("--client_branch", type=str, default="client", help="Save client version to this branch") parser.add_argument( "--block_branch_prefix", type=str, default="block_", help="Save blocks to branches with this prefix" @@ -65,23 +64,21 @@ if __name__ == "__main__": f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}" ) for i, block in enumerate(tqdm(transformer_blocks)): - repo.git_checkout(args.base_branch, create_branch_ok=True) with repo.commit( commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True ): torch.save(block.state_dict(), "./pytorch_model.bin") logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}") - repo.git_checkout(args.base_branch, create_branch_ok=True) + repo.git_checkout(args.client_branch, create_branch_ok=True) with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True): model.h = nn.ModuleList() model.save_pretrained(".") - - logger.info(f"Saving config and tokenizer to {args.output_repo}@{args.base_branch}") - - repo.git_checkout(args.base_branch, create_branch_ok=True) - with repo.commit(commit_message=args.commit_message, branch=args.base_branch, track_large_files=True): tokenizer.save_pretrained(".") config.save_pretrained(".") + repo.git_checkout(args.client_branch, create_branch_ok=True) + with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True): + + logger.info(f"Converted {args.model} and pushed to {args.output_repo}")