2022-06-12 02:59:11 +00:00
|
|
|
import argparse
|
|
|
|
import os
|
|
|
|
|
|
|
|
import psutil
|
|
|
|
import torch.backends.quantized
|
2022-06-20 13:50:22 +00:00
|
|
|
import torch.nn as nn
|
2022-06-12 02:59:11 +00:00
|
|
|
import transformers
|
2022-06-14 05:25:06 +00:00
|
|
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
2022-06-19 16:06:35 +00:00
|
|
|
from huggingface_hub import Repository
|
|
|
|
from tqdm.auto import tqdm
|
2022-06-12 02:59:11 +00:00
|
|
|
|
2022-07-07 00:34:58 +00:00
|
|
|
from src import BloomModel
|
2022-07-07 01:13:20 +00:00
|
|
|
from src.client import DistributedBloomConfig
|
|
|
|
from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX
|
2022-06-14 05:25:06 +00:00
|
|
|
# Configure hivemind's log handler in "in_root_logger" mode before creating this script's logger.
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

# Translates the --torch_dtype command-line value into the object handed to from_pretrained.
# "auto" is passed through unchanged as the literal string.
DTYPE_MAP = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
    "auto": "auto",
}
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_empty_output_dir(output_path):
    """Fail fast unless ``output_path`` is absent or is an empty directory.

    Args:
        output_path: filesystem path that will hold the local clone of the output repo.

    Raises:
        FileExistsError: if the path exists and is either not a directory or a non-empty directory.
    """
    if not os.path.exists(output_path):
        return
    # Check isdir first: os.listdir raises NotADirectoryError on a regular file,
    # which would hide the intended FileExistsError.
    if not os.path.isdir(output_path) or os.listdir(output_path):
        raise FileExistsError(f"Output path {output_path} already exists and is not an empty directory")


if __name__ == "__main__":
    # NOTE(review): the description mentions 8-bit quantization, but nothing in this script quantizes
    # weights; it only splits the model into per-block and client-side checkpoints. Confirm and reword.
    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")

    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
    # `choices` rejects unknown dtypes at parse time; unlike an assert, it is not stripped under `python -O`.
    parser.add_argument(
        "--torch_dtype",
        type=str,
        default="auto",
        choices=tuple(DTYPE_MAP),
        help="Load initial model in this dtype",
    )
    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
    parser.add_argument(
        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
    )
    parser.add_argument(
        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
    )
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    args = parser.parse_args()

    # Warn (but do not abort) when the full-size model is requested on a machine with little free RAM.
    free_ram_gb = psutil.virtual_memory().available / 2**30
    if args.model == "bigscience/bloom" and free_ram_gb < 400:
        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")

    # Validate the destination before the (slow) model download starts.
    ensure_empty_output_dir(args.output_path)

    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
    config = DistributedBloomConfig.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    config.dht_prefix = args.output_repo

    model = BloomModel.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision
    )
    os.makedirs(args.output_path, exist_ok=True)

    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
    repo.git_pull()

    transformer_blocks = model.h
    # Branches are numbered 0 .. len - 1, so the last branch name uses len - 1.
    logger.info(
        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks) - 1}"
    )
    for i, block in enumerate(tqdm(transformer_blocks)):
        # Return to the client branch before every commit; presumably so each new block branch is
        # created from the client branch rather than from the previous block's branch - TODO confirm.
        repo.git_checkout(args.client_branch, create_branch_ok=True)
        with repo.commit(
            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
        ):
            # The relative path is written from inside the commit context; this assumes the context
            # manager runs the body within the local clone - verify against huggingface_hub docs.
            torch.save(block.state_dict(), "./pytorch_model.bin")

    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
    repo.git_checkout(args.client_branch, create_branch_ok=True)
    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
        # Drop the transformer blocks so only the remaining (client-side) modules are saved.
        model.h = nn.ModuleList()
        model.save_pretrained(".")
        tokenizer.save_pretrained(".")
        # Saved last, after model.save_pretrained, so the config carrying dht_prefix is the one written.
        config.save_pretrained(".")

    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
|