Rebalance swarm when necessary (#34)

update-hivemind
Alexander Borzunov 2 years ago committed by GitHub
parent 640bbc38a9
commit 149f433763

@@ -43,7 +43,7 @@ def main():
help='Use this many threads to pass results/exceptions from Runtime to Pools')
parser.add_argument('--inference_max_length', type=int, default=16384,
help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
parser.add_argument('--device', type=str, default=None, required=False,
help='all blocks will use this device in torch notation; default: cuda if available else cpu')
@@ -79,6 +79,13 @@ def main():
parser.add_argument('--custom_module_path', type=str, required=False,
help='Path of a file with custom nn.modules, wrapped into special decorator')
parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
parser.add_argument("--min_balance_quality", type=float, default=0.0,
help="Rebalance the swarm if its balance quality (a number in [0.0, 1.0]) "
"goes below this threshold. Default: rebalancing is disabled")
parser.add_argument("--mean_balance_check_period", type=float, default=150,
help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
@@ -104,7 +111,7 @@ def main():
use_auth_token = args.pop("use_auth_token")
args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
server = Server.create(**args, start=True, compression=compression, attn_cache_size=attn_cache_size)
server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)
try:
server.join()

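Editor's note: the two new flags above control the rebalancing added in this PR. With the default --min_balance_quality of 0.0 the quality check can never fire (balance quality is non-negative), so rebalancing stays disabled unless the flag is raised, and balance checks are spaced at random intervals averaging --mean_balance_check_period seconds. A minimal illustrative sketch of how the two values interact (flag names match the diff above; everything else is hypothetical and not part of this commit):

# Editor's sketch: how the new flags gate rebalancing (illustrative only).
import random

min_balance_quality = 0.0          # default: `balance_quality < 0.0 - eps` is always False, so rebalancing is off
mean_balance_check_period = 150.0  # seconds

# The server sleeps Uniform(0, 2 * mean) between checks, i.e. ~150 s apart on average.
timeout = random.random() * 2 * mean_balance_check_period
print(f"next balance check in {timeout:.0f} s; rebalancing enabled: {min_balance_quality > 0.0}")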
@@ -1,18 +1,103 @@
from typing import List, Optional
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
from hivemind import PeerID, get_logger
from src.data_structures import RemoteModuleInfo, ServerState
__all__ = ["choose_best_blocks", "should_choose_other_blocks"]
logger = get_logger(__file__)
@dataclass
class Span:
start: int
end: int
throughput: float
@property
def length(self):
return self.end - self.start
def move_to(self, new_start: int) -> None:
self.start, self.end = new_start, new_start + self.length
def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
throughputs = []
for module in remote_module_infos:
def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
spans = {}
throughputs = np.zeros(len(module_infos))
for block, module in enumerate(module_infos):
if module is None:
throughputs.append(0)
continue
throughputs.append(
sum(server.throughput for server in module.servers.values() if server.state != ServerState.OFFLINE)
)
options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1)]
best_start = min(options)[1]
return list(range(best_start, best_start + num_blocks))
for peer_id, server in module.servers.items():
if server.state == ServerState.OFFLINE:
continue
if peer_id in spans:
spans[peer_id].start = min(spans[peer_id].start, block)
spans[peer_id].end = max(spans[peer_id].start, block + 1)
else:
spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)
throughputs[block] += server.throughput
return spans, throughputs
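# Editor's note, worked example (not part of this commit): with 4 blocks and two online servers,
# where A announces blocks 0..2 at throughput 2.0 and B announces blocks 2..3 at throughput 1.0,
# _compute_spans returns
#     spans = {A: Span(start=0, end=3, throughput=2.0), B: Span(start=2, end=4, throughput=1.0)}
#     throughputs = [2.0, 2.0, 3.0, 1.0]   # block 2 is served by both peers, block 3 only by B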
def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int:
options = (
(sorted(throughputs[i : i + num_blocks]), i != cur_start, i)
for i in range(0, len(throughputs) - num_blocks + 1)
)
return min(options)[-1]
def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
_, throughputs = _compute_spans(module_infos)
start = _choose_best_start(throughputs, num_blocks, None)
return list(range(start, start + num_blocks))
def should_choose_other_blocks(
local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
) -> bool:
spans, throughputs = _compute_spans(module_infos)
initial_throughput = throughputs.min()
assert local_peer_id in spans, "Span served by this server is not present in the DHT"
local_span = spans[local_peer_id]
throughputs[local_span.start : local_span.end] -= local_span.throughput
new_start = _choose_best_start(throughputs, local_span.length, local_span.start)
if local_span.start == new_start:
return False # This server is on its best place already
local_span.move_to(new_start)
throughputs[local_span.start : local_span.end] += local_span.throughput
moved = True
while moved:
servers = list(spans.keys())
np.random.shuffle(servers)
moved = False
for peer_id in servers:
span = spans[peer_id]
throughputs[span.start : span.end] -= span.throughput
new_start = _choose_best_start(throughputs, span.length, span.start)
if span.start != new_start:
span.move_to(new_start)
moved = True
throughputs[span.start : span.end] += span.throughput
new_throughput = throughputs.min()
balance_quality = initial_throughput / new_throughput
logger.info(f"Swarm balance quality: {balance_quality * 100:.1f}%")
eps = 1e-6
return balance_quality < min_balance_quality - eps
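Editor's note: a self-contained toy example of the selection rule above (illustrative, not part of this commit; the per-block throughputs are made up):

# Editor's sketch: the window-selection criterion from _choose_best_start on toy data.
import numpy as np

throughputs = np.array([4.0, 1.0, 1.0, 3.0, 5.0, 5.0])  # aggregate throughput announced per block
num_blocks = 3                                           # this server will host 3 consecutive blocks

# Candidate windows are compared by their sorted throughputs, so the window whose weakest block
# is the weakest overall wins (ties broken by the next-weakest block, and so on).
options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(len(throughputs) - num_blocks + 1)]
best_start = min(options)[1]
print(best_start)  # 1 -> the server covers blocks 1..3, i.e. the two most underserved blocks

# should_choose_other_blocks then simulates moving every span to its best position and reports
# balance_quality = initial_throughput / new_throughput: the swarm's current bottleneck relative
# to the bottleneck achievable after a greedy, randomized rebalancing.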

@@ -4,8 +4,9 @@ import multiprocessing as mp
import random
import threading
import time
from typing import Dict, Optional, Sequence, Union
from typing import Dict, List, Optional, Sequence, Union
import numpy as np
import torch
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
@@ -17,8 +18,8 @@ from src import BloomConfig, declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from src.dht_utils import get_remote_module_infos
from src.server import block_selection
from src.server.backend import TransformerBackend
from src.server.block_selection import choose_best_blocks
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler
from src.server.throughput import get_host_throughput
@@ -29,76 +30,13 @@ logger = get_logger(__file__)
class Server(threading.Thread):
"""Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
"""
Runs ModuleContainer, periodically checks that the network is balanced,
restarts the ModuleContainer with other layers if the imbalance is significant
"""
def __init__(
self,
dht: DHT,
module_backends: Dict[str, TransformerBackend],
*,
inference_max_length: int,
num_connection_handlers: int = 8,
throughput: float,
update_period: float = 30,
expiration: Optional[float] = None,
start: bool,
**kwargs,
):
threading.Thread.__init__(self)
self.dht, self.module_backends = dht, module_backends
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
self.conn_handlers = [
TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
for _ in range(num_connection_handlers)
]
self.runtime = Runtime(self.module_backends, **kwargs)
self.dht_handler_thread = ModuleAnnouncerThread(
self.module_backends,
dht,
throughput=throughput,
update_period=update_period,
expiration=expiration,
daemon=True,
)
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
if start:
self.run_in_background(await_ready=True)
def run(self):
"""
Starts Server in the current thread. Initializes dht if necessary, starts connection handlers,
runs Runtime (self.runtime) to process incoming requests.
"""
logger.info(f"Serving {len(self.module_backends)} blocks:")
for block_name, backend in self.module_backends.items():
num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
parameter_msg = f"{num_parameters} trainable parameters" if num_parameters else "frozen"
logger.info(f"{block_name}: {backend.module.__class__.__name__}, {parameter_msg}")
if not self.dht.is_alive():
self.dht.run_in_background(await_ready=True)
if self.module_backends:
self.dht_handler_thread.start()
if self.checkpoint_saver is not None:
self.checkpoint_saver.start()
for process in self.conn_handlers:
if not process.is_alive():
process.start()
process.ready.result()
try:
self.runtime.run()
finally:
self.shutdown()
# noinspection PyMethodOverriding
@classmethod
def create(
cls,
prefix: Optional[str],
converted_model_name_or_path: str,
throughput: Union[float, str],
@@ -121,16 +59,34 @@ class Server(threading.Thread):
expiration: Optional[float] = None,
prefetch_batches: int = 1,
sender_threads: int = 1,
max_block_selection_delay: float = 1,
min_balance_quality: float = 0.0,
mean_balance_check_period: float = 150,
mean_block_selection_delay: float = 0.5,
use_auth_token: Optional[str] = None,
load_in_8bit: bool = False,
*,
start: bool,
**kwargs,
) -> Server:
):
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
super().__init__()
self.converted_model_name_or_path = converted_model_name_or_path
self.num_handlers = num_handlers
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.inference_max_length = inference_max_length
self.cache_dir = cache_dir
self.attn_cache_size = attn_cache_size
self.compression = compression
self.stats_report_interval, self.update_period = stats_report_interval, update_period
self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
self.use_auth_token = use_auth_token
self.load_in_8bit = load_in_8bit
if custom_module_path is not None:
add_custom_models_from_file(custom_module_path)
if prefix is None:
prefix = converted_model_name_or_path
assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
@@ -138,29 +94,39 @@ class Server(threading.Thread):
f"Please specify --prefix manually when starting a server"
)
logger.info(f"Automatic dht prefix: {prefix}")
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
self.prefix = prefix
if expiration is None:
expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
self.expiration = expiration
dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
memory_cache = MemoryCache(device, attn_cache_size)
self.device = device
self.memory_cache = MemoryCache(device, attn_cache_size)
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
self.throughput = throughput
if isinstance(torch_dtype, str):
torch_dtype = DTYPE_MAP[torch_dtype]
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
self.torch_dtype = torch_dtype
block_config = BloomConfig.from_pretrained(
converted_model_name_or_path, use_auth_token=use_auth_token, revision=revision
self.block_config = BloomConfig.from_pretrained(
converted_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
if block_indices is not None:
try:
first_block_index, last_block_index = block_indices.split(":")
@@ -169,16 +135,174 @@ class Server(threading.Thread):
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
raise
block_indices = range(first_block_index, last_block_index)
else:
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
time.sleep(random.random() * max_block_selection_delay)
self.strict_block_indices, self.num_blocks = block_indices, num_blocks
self.min_balance_quality = min_balance_quality
self.mean_balance_check_period = mean_balance_check_period
self.mean_block_selection_delay = mean_block_selection_delay
self.stop = threading.Event()
if start:
self.start()
def run(self):
while True:
block_indices = self._choose_blocks()
self.module_container = ModuleContainer.create(
dht=self.dht,
prefix=self.prefix,
converted_model_name_or_path=self.converted_model_name_or_path,
block_config=self.block_config,
memory_cache=self.memory_cache,
throughput=self.throughput,
block_indices=block_indices,
num_handlers=self.num_handlers,
min_batch_size=self.min_batch_size,
max_batch_size=self.max_batch_size,
inference_max_length=self.inference_max_length,
torch_dtype=self.torch_dtype,
cache_dir=self.cache_dir,
device=self.device,
compression=self.compression,
stats_report_interval=self.stats_report_interval,
update_period=self.update_period,
expiration=self.expiration,
prefetch_batches=self.prefetch_batches,
sender_threads=self.sender_threads,
use_auth_token=self.use_auth_token,
load_in_8bit=self.load_in_8bit,
start=True,
)
try:
self.module_container.ready.wait()
while True:
timeout = random.random() * 2 * self.mean_balance_check_period
# TODO: Follow ModuleContainer status (to restart/stop if it crashes)
if self.stop.wait(timeout):
return
if self._should_choose_other_blocks():
logger.info("Swarm is imbalanced, server will load other blocks")
break # Stop serving this set of modules
finally:
self.module_container.shutdown()
def _choose_blocks(self) -> List[int]:
if self.strict_block_indices is not None:
return self.strict_block_indices
assert self.num_blocks is not None
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
time.sleep(random.random() * 2 * self.mean_block_selection_delay)
module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
return block_selection.choose_best_blocks(self.num_blocks, module_infos)
def _should_choose_other_blocks(self) -> bool:
if self.strict_block_indices is not None:
return False
module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.min_balance_quality)
def shutdown(self):
self.stop.set()
self.dht.shutdown()
self.dht.join()
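Editor's note: the control flow of Server.run above reduces to "pick blocks, run a ModuleContainer, wake up at random intervals, restart with other blocks if the swarm is imbalanced". A stripped-down, self-contained sketch of that loop (illustrative, not part of this commit; the container and the balance check are stubbed out):

# Editor's sketch: the rebalancing loop in isolation, with stand-ins for the real components.
import random
import threading
from dataclasses import dataclass

@dataclass
class DummyContainer:                          # stand-in for ModuleContainer
    blocks: range
    def shutdown(self) -> None:
        print(f"stopped serving blocks {list(self.blocks)}")

def serve_forever(stop: threading.Event, mean_balance_check_period: float = 1.0) -> None:
    while not stop.is_set():
        container = DummyContainer(range(4))   # real code: ModuleContainer.create(..., start=True)
        try:
            while True:
                timeout = random.random() * 2 * mean_balance_check_period
                if stop.wait(timeout):         # set by shutdown() -> exit for good
                    return
                if random.random() < 0.5:      # real code: block_selection.should_choose_other_blocks(...)
                    break                      # swarm imbalanced: stop this container, pick new blocks
        finally:
            container.shutdown()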
class ModuleContainer(threading.Thread):
"""Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
def __init__(
self,
dht: DHT,
module_backends: Dict[str, TransformerBackend],
*,
inference_max_length: int,
num_connection_handlers: int,
throughput: float,
update_period: float,
expiration: Optional[float] = None,
start: bool,
**kwargs,
):
super().__init__()
self.dht, self.module_backends = dht, module_backends
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
self.conn_handlers = [
TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
for _ in range(num_connection_handlers)
]
self.runtime = Runtime(self.module_backends, **kwargs)
self.dht_handler_thread = ModuleAnnouncerThread(
self.module_backends,
dht,
throughput=throughput,
update_period=update_period,
expiration=expiration,
daemon=True,
)
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
assert num_blocks is not None
uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
block_indices = choose_best_blocks(num_blocks, module_infos)
if start:
self.run_in_background(await_ready=True)
def run(self):
"""
Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
runs Runtime (self.runtime) to process incoming requests.
"""
logger.info(f"Serving {len(self.module_backends)} blocks:")
for expert_name, backend in self.module_backends.items():
num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
if not self.dht.is_alive():
self.dht.run_in_background(await_ready=True)
if self.module_backends:
self.dht_handler_thread.start()
if self.checkpoint_saver is not None:
self.checkpoint_saver.start()
for handler in self.conn_handlers:
handler.run_in_background()
self.runtime.run()
# noinspection PyMethodOverriding
@classmethod
def create(
cls,
*,
dht: DHT,
prefix: str,
converted_model_name_or_path: str,
block_config: BloomConfig,
memory_cache: MemoryCache,
throughput: float,
block_indices: List[int],
num_handlers: Optional[int],
min_batch_size: int,
max_batch_size: int,
inference_max_length: int,
torch_dtype: torch.dtype,
cache_dir: Optional[str],
device: Union[str, torch.device],
compression: CompressionType,
stats_report_interval: Optional[int],
update_period: float,
expiration: Optional[float],
prefetch_batches: int,
sender_threads: int,
use_auth_token: Optional[str],
load_in_8bit: bool,
start: bool,
) -> ModuleContainer:
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
declare_active_modules(
dht,
@@ -245,33 +369,36 @@ class Server(threading.Thread):
def run_in_background(self, await_ready=True, timeout=None):
"""
Starts Server in a background thread. if await_ready, this method will wait until background server
Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
is ready to process incoming requests or for :timeout: seconds max.
"""
self.start()
if await_ready and not self.ready.wait(timeout=timeout):
raise TimeoutError("Server didn't notify .ready in {timeout} seconds")
raise TimeoutError("ModuleContainer didn't notify .ready in {timeout} seconds")
@property
def ready(self) -> mp.synchronize.Event:
"""
An event (multiprocessing.Event) that is set when the server is ready to process requests.
An event (multiprocessing.Event) that is set when the container is ready to process requests.
Example
=======
>>> server.start()
>>> server.ready.wait(timeout=10)
>>> print("Server ready" if server.ready.is_set() else "Server didn't start in 10 seconds")
>>> container.start()
>>> container.ready.wait(timeout=10)
>>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds")
"""
return self.runtime.ready # mp.Event that is true if self is ready to process batches
def shutdown(self):
"""
Gracefully terminate the server, process-safe.
Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
Gracefully terminate the container, process-safe.
Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
"""
if self.module_backends:
self.dht_handler_thread.stop.set()
self.dht_handler_thread.join()
declare_active_modules(
self.dht,
self.module_backends.keys(),
@@ -283,30 +410,22 @@ class Server(threading.Thread):
self.ready.clear()
for process in self.conn_handlers:
process.terminate()
process.join()
for handler in self.conn_handlers:
handler.shutdown()
logger.debug("Connection handlers terminated")
if self.module_backends:
self.dht_handler_thread.stop.set()
self.dht_handler_thread.join()
if self.checkpoint_saver is not None:
self.checkpoint_saver.stop.set()
self.checkpoint_saver.join()
self.dht.shutdown()
self.dht.join()
logger.debug(f"Shutting down runtime")
self.runtime.shutdown()
logger.info("Server shut down succesfully")
logger.info("Module container shut down succesfully")
class ModuleAnnouncerThread(threading.Thread):
"""Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
"""Periodically announces that this container hosts the specified modules, visible to all DHT peers"""
def __init__(
self,
