|
|
|
@ -10,8 +10,8 @@ from huggingface_hub import Repository
|
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
|
|
|
|
from src import BloomModel
|
|
|
|
|
from src.client.remote_model import DistributedBloomConfig
|
|
|
|
|
|
|
|
|
|
from src.client import DistributedBloomConfig
|
|
|
|
|
from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
@ -26,9 +26,9 @@ if __name__ == "__main__":
|
|
|
|
|
parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
|
|
|
|
|
parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
|
|
|
|
|
parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
|
|
|
|
|
parser.add_argument("--client_branch", type=str, default="client", help="Save client version to this branch")
|
|
|
|
|
parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--block_branch_prefix", type=str, default="block_", help="Save blocks to branches with this prefix"
|
|
|
|
|
"--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
|
|
|
|
|
)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
"--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
|
|
|
|
@ -50,7 +50,7 @@ if __name__ == "__main__":
|
|
|
|
|
config = DistributedBloomConfig.from_pretrained(
|
|
|
|
|
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
|
|
|
|
)
|
|
|
|
|
config.dht_prefix = args.model
|
|
|
|
|
config.dht_prefix = args.output_repo
|
|
|
|
|
|
|
|
|
|
model = BloomModel.from_pretrained(
|
|
|
|
|
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
|
|
|
|
|