|
|
|
@ -23,7 +23,6 @@ if __name__ == "__main__":
|
|
|
|
|
parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
|
|
|
|
|
parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
|
|
|
|
|
parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
|
|
|
|
|
parser.add_argument("--base_branch", type=str, default="main", help="Use this branch as reference point")
|
|
|
|
|
parser.add_argument("--client_branch", type=str, default="client", help="Save client version to this branch")
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--block_branch_prefix", type=str, default="block_", help="Save blocks to branches with this prefix"
|
|
|
|
@ -65,23 +64,21 @@ if __name__ == "__main__":
|
|
|
|
|
f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
|
|
|
|
|
)
|
|
|
|
|
for i, block in enumerate(tqdm(transformer_blocks)):
|
|
|
|
|
repo.git_checkout(args.base_branch, create_branch_ok=True)
|
|
|
|
|
with repo.commit(
|
|
|
|
|
commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
|
|
|
|
|
):
|
|
|
|
|
torch.save(block.state_dict(), "./pytorch_model.bin")
|
|
|
|
|
|
|
|
|
|
logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
|
|
|
|
|
repo.git_checkout(args.base_branch, create_branch_ok=True)
|
|
|
|
|
repo.git_checkout(args.client_branch, create_branch_ok=True)
|
|
|
|
|
with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
|
|
|
|
|
model.h = nn.ModuleList()
|
|
|
|
|
model.save_pretrained(".")
|
|
|
|
|
|
|
|
|
|
logger.info(f"Saving config and tokenizer to {args.output_repo}@{args.base_branch}")
|
|
|
|
|
|
|
|
|
|
repo.git_checkout(args.base_branch, create_branch_ok=True)
|
|
|
|
|
with repo.commit(commit_message=args.commit_message, branch=args.base_branch, track_large_files=True):
|
|
|
|
|
tokenizer.save_pretrained(".")
|
|
|
|
|
config.save_pretrained(".")
|
|
|
|
|
|
|
|
|
|
repo.git_checkout(args.client_branch, create_branch_ok=True)
|
|
|
|
|
with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
|
|
|
|
|