Enable rebalancing by default (#84)

pull/85/head
Alexander Borzunov authored 2 years ago, committed by GitHub
parent 2cb82dd648
commit ee4e69c254

@@ -1,3 +1,5 @@
+import argparse
+
 import configargparse
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.limits import increase_file_limit
@@ -12,7 +14,8 @@ logger = get_logger(__file__)
 
 def main():
     # fmt:off
-    parser = configargparse.ArgParser(default_config_files=["config.yml"])
+    parser = configargparse.ArgParser(default_config_files=["config.yml"],
+                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
 
     group = parser.add_mutually_exclusive_group(required=True)
@@ -80,10 +83,11 @@ def main():
                              help='Path of a file with custom nn.modules, wrapped into special decorator')
     parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
-    parser.add_argument("--min_balance_quality", type=float, default=0.0,
-                        help="Rebalance the swarm if its balance quality (a number in [0.0, 1.0]) "
-                             "goes below this threshold. Default: rebalancing is disabled")
-    parser.add_argument("--mean_balance_check_period", type=float, default=150,
+    parser.add_argument("--balance_quality", type=float, default=0.75,
+                        help="Rebalance the swarm if its throughput is worse than this share of the optimal "
+                             "throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing "
+                             "on each check for debugging purposes.")
+    parser.add_argument("--mean_balance_check_period", type=float, default=60,
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")

@@ -61,7 +61,7 @@ class TransformerBackend(ModuleBackend):
             if not is_dummy(hypo_ids):
                 cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
             layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-            print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
+            logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")
             hidden_states, (new_k, new_v) = self.module.forward(
                 hidden_states, layer_past=layer_past, use_cache=True
             )

@@ -62,9 +62,9 @@ def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]
 def should_choose_other_blocks(
-    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
+    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
 ) -> bool:
-    if min_balance_quality > 1.0:
+    if balance_quality > 1.0:
         return True  # Forces rebalancing on each check (may be used for debugging purposes)
 
     spans, throughputs = _compute_spans(module_infos)
@@ -99,8 +99,8 @@ def should_choose_other_blocks(
         throughputs[span.start : span.end] += span.throughput
 
     new_throughput = throughputs.min()
-    balance_quality = initial_throughput / new_throughput
-    logger.info(f"Swarm balance quality: {balance_quality * 100:.1f}%")
+    actual_quality = initial_throughput / new_throughput
+    logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")
 
     eps = 1e-6
-    return balance_quality < min_balance_quality - eps
+    return actual_quality < balance_quality - eps
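The decision rule itself is unchanged; only the names are swapped: the measured ratio is now called actual_quality, while balance_quality is the configured threshold. A standalone sketch of that rule with made-up throughput numbers (needs_rebalancing is a hypothetical helper, not a function from this repo):

def needs_rebalancing(initial_throughput: float, new_throughput: float, balance_quality: float = 0.75) -> bool:
    # initial_throughput: the swarm's current bottleneck throughput
    # new_throughput: the bottleneck throughput if this server moved to the optimal blocks
    if balance_quality > 1.0:
        return True  # forces rebalancing on each check (debugging)
    actual_quality = initial_throughput / new_throughput
    eps = 1e-6
    return actual_quality < balance_quality - eps

# 40 units now vs. 100 units after moving: quality 0.40 < 0.75, so rebalance.
print(needs_rebalancing(initial_throughput=40.0, new_throughput=100.0))  # True
# 80 units now vs. 100 units after moving: quality 0.80 >= 0.75, so keep the current blocks.
print(needs_rebalancing(initial_throughput=80.0, new_throughput=100.0))  # False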

@@ -16,6 +16,7 @@ from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
+from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
 from src.data_structures import CHAIN_DELIMITER, ModuleUID
@@ -24,6 +25,8 @@ from src.server.task_pool import PrioritizedTaskPool
 from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
 from src.utils.misc import DUMMY, is_dummy
 
+logger = get_logger(__file__)
+
 
 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
@@ -73,7 +76,7 @@
     ) -> AsyncIterator[runtime_pb2.ExpertRequest]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
         try:
-            print("OPENED RPC_INFERENCE")
+            logger.debug("Opened rpc_inference()")
             request = await anext(requests)
             requested_uids = self._check_uids(request.uid)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
@@ -164,7 +167,7 @@
                 prefix_length += hidden_states.shape[1]
                 request = await (anext(requests))
         finally:
-            print("CLOSED RPC_INFERENCE")
+            logger.debug("Closed rpc_inference()")
 
     async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         # Parse request and prepare backends
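The noisy print() calls in rpc_inference() become logger.debug() messages, so they are silent unless debug logging is enabled. One generic way to surface them is to raise logger levels through the standard logging module; this is a sketch assuming get_logger() returns ordinary logging.Logger objects, not a documented hivemind switch:

import logging

# Raise the level of the root logger and of every logger created so far (including the
# ones returned by get_logger(__file__)). Whether the records actually reach the console
# also depends on the handlers hivemind attaches, which are outside this diff.
logging.basicConfig(level=logging.DEBUG)
for name in logging.root.manager.loggerDict:
    logging.getLogger(name).setLevel(logging.DEBUG)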

@@ -61,8 +61,8 @@ class Server(threading.Thread):
         expiration: Optional[float] = None,
         prefetch_batches: int = 1,
         sender_threads: int = 1,
-        min_balance_quality: float = 0.0,
-        mean_balance_check_period: float = 150,
+        balance_quality: float = 0.75,
+        mean_balance_check_period: float = 60,
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
@@ -138,7 +138,7 @@
                 raise
             block_indices = range(first_block_index, last_block_index)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
-        self.min_balance_quality = min_balance_quality
+        self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
@@ -215,7 +215,7 @@
             return False
 
         module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
-        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.min_balance_quality)
+        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
 
     def shutdown(self):
         self.stop.set()
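Taken together, the new defaults mean a server re-evaluates its block assignment roughly once a minute (mean_balance_check_period=60) and moves if the swarm runs below 75% of its estimated optimal throughput (balance_quality=0.75). A rough illustration of how these two knobs could drive such a loop; this is not the actual Server.run() logic (which is outside this diff), and _should_choose_other_blocks() / restart() are assumed method names:

import random
import threading
import time


def balance_check_loop(server, stop: threading.Event) -> None:
    """Periodically ask block_selection whether this server should move to other blocks."""
    while not stop.is_set():
        # Randomized wait with the configured mean so that servers do not all probe the
        # swarm in lockstep (the exponential distribution is an illustrative choice).
        time.sleep(random.expovariate(1.0 / server.mean_balance_check_period))
        if server._should_choose_other_blocks():  # wraps should_choose_other_blocks(..., server.balance_quality)
            server.restart()  # assumed: shut down, pick better blocks, announce them again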
