diff --git a/cli/convert_model.py b/cli/convert_model.py index 98556cf..a7ffd63 100644 --- a/cli/convert_model.py +++ b/cli/convert_model.py @@ -10,8 +10,8 @@ from huggingface_hub import Repository from tqdm.auto import tqdm from src import BloomModel -from src.client.remote_model import DistributedBloomConfig - +from src.client import DistributedBloomConfig +from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX use_hivemind_log_handler("in_root_logger") logger = get_logger(__file__) @@ -26,9 +26,9 @@ if __name__ == "__main__": parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype") parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder") parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo") - parser.add_argument("--client_branch", type=str, default="client", help="Save client version to this branch") + parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch") parser.add_argument( - "--block_branch_prefix", type=str, default="block_", help="Save blocks to branches with this prefix" + "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix" ) parser.add_argument( "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts" @@ -50,7 +50,7 @@ if __name__ == "__main__": config = DistributedBloomConfig.from_pretrained( args.model, use_auth_token=args.use_auth_token, revision=args.revision ) - config.dht_prefix = args.model + config.dht_prefix = args.output_repo model = BloomModel.from_pretrained( args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype] diff --git a/src/server/server.py b/src/server/server.py index 41f63d3..7eb0335 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -6,14 +6,13 @@ from typing import Dict, Optional, Sequence, Union import torch from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time -from hivemind.moe.server.dht_handler import DHTHandlerThread from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.runtime import Runtime from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.logging import get_logger, use_hivemind_log_handler -from src import declare_active_modules -from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block +from src import declare_active_modules, BloomConfig +from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER from src.server.backend import TransformerBackend from src.server.cache import MemoryCache @@ -140,7 +139,7 @@ class Server(threading.Thread): assert num_blocks is not None block_indices = range(num_blocks) # TODO replace with proper load balancing - block_config = DistributedBloomConfig.from_pretrained( + block_config = BloomConfig.from_pretrained( converted_model_name_or_path, use_auth_token=use_auth_token )