set default DHT prefix

standardize
justheuristic 2 years ago
parent 41e5a95e8e
commit 90d65e58aa

@ -9,6 +9,9 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from huggingface_hub import Repository
from tqdm.auto import tqdm
from src import BloomModel
from src.client.remote_model import DistributedBloomConfig
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@ -44,10 +47,12 @@ if __name__ == "__main__":
raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
logger.info(f"Loading source model {args.model} (this may take a few minutes)")
config = transformers.AutoConfig.from_pretrained(
config = DistributedBloomConfig.from_pretrained(
args.model, use_auth_token=args.use_auth_token, revision=args.revision
)
model = transformers.AutoModel.from_pretrained(
config.dht_prefix = args.model
model = BloomModel.from_pretrained(
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
)
tokenizer = transformers.AutoTokenizer.from_pretrained(

Loading…
Cancel
Save