|
|
|
@ -9,6 +9,9 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
|
|
|
from huggingface_hub import Repository
|
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
|
|
|
|
from src import BloomModel
|
|
|
|
|
from src.client.remote_model import DistributedBloomConfig
|
|
|
|
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
@ -44,10 +47,12 @@ if __name__ == "__main__":
|
|
|
|
|
raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
|
|
|
|
|
|
|
|
|
|
logger.info(f"Loading source model {args.model} (this may take a few minutes)")
|
|
|
|
|
config = transformers.AutoConfig.from_pretrained(
|
|
|
|
|
config = DistributedBloomConfig.from_pretrained(
|
|
|
|
|
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
|
|
|
|
)
|
|
|
|
|
model = transformers.AutoModel.from_pretrained(
|
|
|
|
|
config.dht_prefix = args.model
|
|
|
|
|
|
|
|
|
|
model = BloomModel.from_pretrained(
|
|
|
|
|
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
|
|
|
|
|
)
|
|
|
|
|
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
|
|
|
|