Mirror of https://github.com/bigscience-workshop/petals, synced 2024-11-16 06:12:50 +00:00
push converted model to hub

commit 736f1d1085 (parent 15d0ea7129)
@@ -17,9 +17,10 @@ conda activate bloom-demo
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
 pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
 pip install accelerate==0.10.0 huggingface-hub==0.7.0
 pip install bitsandbytes-cuda113==0.26.0
 pip install https://github.com/learning-at-home/hivemind/archive/master.zip
-pip install https://github.com/huggingface/transformers/archive/224bde91caff4ccfd12277ab5e9bf97c61e22ee9.zip
+pip install https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
 ```
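For context, a quick sanity check that the pinned environment above resolved as intended. This is an illustrative sketch, not part of the commit; the expected version strings come from the install commands in the README hunk.

```python
# Sanity-check the pinned setup (illustrative; expected versions from the README above).
import torch
import transformers
import accelerate
import huggingface_hub

print("torch:", torch.__version__)                      # expect 1.11.0+cu113
print("cuda available:", torch.cuda.is_available())     # expect True on a CUDA 11.3 machine
print("transformers:", transformers.__version__)
print("accelerate:", accelerate.__version__)            # expect 0.10.0
print("huggingface_hub:", huggingface_hub.__version__)  # expect 0.7.0
```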
@@ -1,12 +1,13 @@
 import argparse
-import copy
 import os

 import psutil
 import torch.backends.quantized
 import transformers
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from tqdm.auto import trange
+from huggingface_hub import Repository
+import torch.nn as nn
+from tqdm.auto import tqdm

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
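The new `huggingface_hub.Repository` import carries the core idea of this commit: every transformer block is committed and pushed to its own branch of a single hub repo. Below is a minimal sketch of that pattern using the same calls as the script; the repo id and branch names are placeholders, and the target repo must exist, have git-lfs, and be pushable.

```python
# Branch-per-artifact pattern, as used by the conversion loop in this commit.
# "your-org/your-repo" is a placeholder, not the commit's default.
import torch
import torch.nn as nn
from huggingface_hub import Repository

repo = Repository("./scratch_repo", clone_from="your-org/your-repo")
repo.git_pull()

block = nn.Linear(8, 8)  # stand-in for one transformer block
repo.git_checkout("main", create_branch_ok=True)  # branch off the base each time
with repo.commit(commit_message="demo", branch="block_0", track_large_files=True):
    # the context manager chdirs into the local clone, then commits and pushes on exit
    torch.save(block.state_dict(), "./pytorch_model.bin")
```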
@@ -16,16 +17,22 @@ DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
-    parser.add_argument("--output_path", required=True, type=str, help="Save quantized layers to this folder")
-    parser.add_argument("--model", type=str, default="bigscience/bloom", help="Model name for from_pretrained")
+    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
     parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
     parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
+    parser.add_argument("--output_path", type=str, default='./converted_model', help="Track output repo to this folder")
+    parser.add_argument("--output_repo", type=str, default='bigscience/test-bloomd', help="Push to this HF hub repo")
+    parser.add_argument("--base_branch", type=str, default='main', help="Use this branch as reference point")
+    parser.add_argument("--client_branch", type=str, default='client', help="Save client version to this branch")
+    parser.add_argument("--block_branch_prefix", type=str, default='block_', help="Save blocks to branches with this prefix")
+    parser.add_argument("--commit_message", type=str, default='push-o-matic', help="Use this commit message for all parts")
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
     args = parser.parse_args()

-    free_ram_gb = psutil.virtual_memory().available / 2**30
+    free_ram_gb = psutil.virtual_memory().available / 2 ** 30
     if free_ram_gb < 400:
-        logger.warning(f"ACHTUNG! converting bloom-176b will use up 370-400GB RAM, you have {free_ram_gb:.3f} free")
+        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")

     assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
     if os.path.exists(args.output_path) and (
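With the new flags in place, a conversion run might be launched as sketched below. The script path is an assumption (the diff does not show the filename), the `--output_repo` value is a placeholder for a repo you can push to, and the token is a dummy.

```python
# Hypothetical invocation of the converter via subprocess (script name assumed).
import subprocess

subprocess.run(
    [
        "python", "convert_model.py",             # assumed script name
        "--model", "bigscience/bloom-6b3",
        "--torch_dtype", "bfloat16",
        "--output_path", "./converted_model",
        "--output_repo", "your-org/test-bloomd",  # placeholder: a repo you can push to
        "--use_auth_token", "hf_xxx",             # placeholder token
    ],
    check=True,
)
```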
@@ -33,21 +40,28 @@ if __name__ == "__main__":
     ):
         raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")

-    model = transformers.BloomForCausalLM.from_pretrained(
+    model = transformers.AutoModelForCausalLM.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )

-    qconfig = torch.quantization.get_default_qconfig("fbgemm")
-    torch.backends.quantized.engine = "fbgemm"
-
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        args.model, use_auth_token=args.use_auth_token, revision=args.revision
+    )
     os.makedirs(args.output_path, exist_ok=True)
-
-    for i in trange(len(model.transformer.h)):
-        layer_fp32 = copy.deepcopy(model.transformer.h[i]).float()
-        layer_quantized = torch.quantization.quantize_dynamic(
-            layer_fp32, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
-        )
-        torch.save(layer_quantized.state_dict(), os.path.join(args.output_path, f"block_{i}_qint8.pth"))
+    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
+    repo.git_pull()

-    model.transformer.h = torch.nn.ModuleList()
-    torch.save(model.state_dict(), os.path.join(args.output_path, f"client.pth"))
+    transformer_blocks = model.transformer.h
+    for i, block in enumerate(tqdm(transformer_blocks)):
+        repo.git_checkout(args.base_branch, create_branch_ok=True)
+        with repo.commit(
+            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
+        ):
+            print(block.self_attention.layer_number)
+            torch.save(block.state_dict(), "./pytorch_model.bin")
+
+    repo.git_checkout(args.base_branch, create_branch_ok=True)
+    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
+        model.transformer.h = nn.ModuleList()
+        model.save_pretrained(".")
+    logger.info(f"Converted {args.model} and saved to {args.output_repo}")