fix imports

standardize
justheuristic 2 years ago
parent 88c1bf9896
commit 5695897620

@ -10,8 +10,8 @@ from huggingface_hub import Repository
from tqdm.auto import tqdm
from src import BloomModel
from src.client.remote_model import DistributedBloomConfig
from src.client import DistributedBloomConfig
from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@ -26,9 +26,9 @@ if __name__ == "__main__":
parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
parser.add_argument("--client_branch", type=str, default="client", help="Save client version to this branch")
parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
parser.add_argument(
"--block_branch_prefix", type=str, default="block_", help="Save blocks to branches with this prefix"
"--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
)
parser.add_argument(
"--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
@ -50,7 +50,7 @@ if __name__ == "__main__":
config = DistributedBloomConfig.from_pretrained(
args.model, use_auth_token=args.use_auth_token, revision=args.revision
)
config.dht_prefix = args.model
config.dht_prefix = args.output_repo
model = BloomModel.from_pretrained(
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]

@ -6,14 +6,13 @@ from typing import Dict, Optional, Sequence, Union
import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.dht_handler import DHTHandlerThread
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
from src import declare_active_modules, BloomConfig
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
from src.server.backend import TransformerBackend
from src.server.cache import MemoryCache
@ -140,7 +139,7 @@ class Server(threading.Thread):
assert num_blocks is not None
block_indices = range(num_blocks) # TODO replace with proper load balancing
block_config = DistributedBloomConfig.from_pretrained(
block_config = BloomConfig.from_pretrained(
converted_model_name_or_path, use_auth_token=use_auth_token
)

Loading…
Cancel
Save