From ee4e69c25492e6b73b830ff379b3f05ebf506a9b Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Wed, 2 Nov 2022 00:50:01 +0400 Subject: [PATCH] Enable rebalancing by default (#84) --- cli/run_server.py | 14 +++++++++----- src/server/backend.py | 2 +- src/server/block_selection.py | 10 +++++----- src/server/handler.py | 7 +++++-- src/server/server.py | 8 ++++---- 5 files changed, 24 insertions(+), 17 deletions(-) diff --git a/cli/run_server.py b/cli/run_server.py index e94d2e3..fcef351 100644 --- a/cli/run_server.py +++ b/cli/run_server.py @@ -1,3 +1,5 @@ +import argparse + import configargparse from hivemind.proto.runtime_pb2 import CompressionType from hivemind.utils.limits import increase_file_limit @@ -12,7 +14,8 @@ logger = get_logger(__file__) def main(): # fmt:off - parser = configargparse.ArgParser(default_config_files=["config.yml"]) + parser = configargparse.ArgParser(default_config_files=["config.yml"], + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add('-c', '--config', required=False, is_config_file=True, help='config file path') group = parser.add_mutually_exclusive_group(required=True) @@ -80,10 +83,11 @@ def main(): help='Path of a file with custom nn.modules, wrapped into special decorator') parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P') - parser.add_argument("--min_balance_quality", type=float, default=0.0, - help="Rebalance the swarm if its balance quality (a number in [0.0, 1.0]) " - "goes below this threshold. Default: rebalancing is disabled") - parser.add_argument("--mean_balance_check_period", type=float, default=150, + parser.add_argument("--balance_quality", type=float, default=0.75, + help="Rebalance the swarm if its throughput is worse than this share of the optimal " + "throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing " + "on each check for debugging purposes.") + parser.add_argument("--mean_balance_check_period", type=float, default=60, help="Check the swarm's balance every N seconds (and rebalance it if necessary)") parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained") diff --git a/src/server/backend.py b/src/server/backend.py index ed8273d..00f55dd 100644 --- a/src/server/backend.py +++ b/src/server/backend.py @@ -61,7 +61,7 @@ class TransformerBackend(ModuleBackend): if not is_dummy(hypo_ids): cache[:, :] = cache[:, hypo_ids] # in-place reorder cache by hypo ids layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length] - print("METADATA:", cache_metadata, past_k.shape, past_v.shape) + logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}") hidden_states, (new_k, new_v) = self.module.forward( hidden_states, layer_past=layer_past, use_cache=True ) diff --git a/src/server/block_selection.py b/src/server/block_selection.py index af875c2..c6352b5 100644 --- a/src/server/block_selection.py +++ b/src/server/block_selection.py @@ -62,9 +62,9 @@ def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModule def should_choose_other_blocks( - local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float + local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float ) -> bool: - if min_balance_quality > 1.0: + if balance_quality > 1.0: return True # Forces rebalancing on each check (may be used for debugging purposes) spans, throughputs = _compute_spans(module_infos) @@ -99,8 +99,8 @@ def should_choose_other_blocks( throughputs[span.start : span.end] += span.throughput new_throughput = throughputs.min() - balance_quality = initial_throughput / new_throughput - logger.info(f"Swarm balance quality: {balance_quality * 100:.1f}%") + actual_quality = initial_throughput / new_throughput + logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%") eps = 1e-6 - return balance_quality < min_balance_quality - eps + return actual_quality < balance_quality - eps diff --git a/src/server/handler.py b/src/server/handler.py index 7d7f76b..3c366e3 100644 --- a/src/server/handler.py +++ b/src/server/handler.py @@ -16,6 +16,7 @@ from hivemind.moe.server.connection_handler import ConnectionHandler from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE from hivemind.proto import runtime_pb2 from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter +from hivemind.utils.logging import get_logger from hivemind.utils.streaming import split_for_streaming from src.data_structures import CHAIN_DELIMITER, ModuleUID @@ -24,6 +25,8 @@ from src.server.task_pool import PrioritizedTaskPool from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase from src.utils.misc import DUMMY, is_dummy +logger = get_logger(__file__) + class TransformerConnectionHandler(ConnectionHandler): """Handles three request types: forward, backward and forward-incremental (inference)""" @@ -73,7 +76,7 @@ class TransformerConnectionHandler(ConnectionHandler): ) -> AsyncIterator[runtime_pb2.ExpertRequest]: """Compute a single step of inference using attention cache; update attention cache accordingly.""" try: - print("OPENED RPC_INFERENCE") + logger.debug("Opened rpc_inference()") request = await anext(requests) requested_uids = self._check_uids(request.uid) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} @@ -164,7 +167,7 @@ class TransformerConnectionHandler(ConnectionHandler): prefix_length += hidden_states.shape[1] request = await (anext(requests)) finally: - print("CLOSED RPC_INFERENCE") + logger.debug("Closed rpc_inference()") async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: # Parse request and prepare backends diff --git a/src/server/server.py b/src/server/server.py index 1ed11ca..174762d 100644 --- a/src/server/server.py +++ b/src/server/server.py @@ -61,8 +61,8 @@ class Server(threading.Thread): expiration: Optional[float] = None, prefetch_batches: int = 1, sender_threads: int = 1, - min_balance_quality: float = 0.0, - mean_balance_check_period: float = 150, + balance_quality: float = 0.75, + mean_balance_check_period: float = 60, mean_block_selection_delay: float = 0.5, use_auth_token: Optional[str] = None, load_in_8bit: bool = False, @@ -138,7 +138,7 @@ class Server(threading.Thread): raise block_indices = range(first_block_index, last_block_index) self.strict_block_indices, self.num_blocks = block_indices, num_blocks - self.min_balance_quality = min_balance_quality + self.balance_quality = balance_quality self.mean_balance_check_period = mean_balance_check_period self.mean_block_selection_delay = mean_block_selection_delay @@ -215,7 +215,7 @@ class Server(threading.Thread): return False module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf) - return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.min_balance_quality) + return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality) def shutdown(self): self.stop.set()