Make ServerState announcements work better (#93)

- Before this PR, `ServerState.JOINING` was announced only once. This announcement quickly expires in the case of the full-size BLOOM, since loading its blocks takes several minutes. This PR fixes it, so `ServerState.JOINING` is announced periodically in a thread until the blocks are loaded.

- This PR also makes the `Server` class a non-thread, so it runs in the main thread and can catch `KeyboardInterrupt`. This is important, since if we are downloading blocks right now, we need to stop the download and send the `ServerState.OFFLINE` message. Note that `ModuleContainer` is still a thread.

- (minor) For the sake of readability, I moved the `ModuleContainer.create()` definition, so it is now defined before `Server.__init__()` (this is because `.create()` is invoked first).
This commit is contained in:
Alexander Borzunov 2022-11-28 07:44:03 +04:00 committed by GitHub
parent dc71574a63
commit 8a73b41a42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 133 additions and 119 deletions

View File

@@ -124,10 +124,9 @@ def main():
use_auth_token = args.pop("use_auth_token") use_auth_token = args.pop("use_auth_token")
args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True) server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
try: try:
server.join() server.run()
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Caught KeyboardInterrupt, shutting down") logger.info("Caught KeyboardInterrupt, shutting down")
finally: finally:

View File

@@ -32,7 +32,7 @@ use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__) logger = get_logger(__file__)
class Server(threading.Thread): class Server:
""" """
Runs ModuleContainer, periodically checks that the network is balanced, Runs ModuleContainer, periodically checks that the network is balanced,
restarts the ModuleContainer with other layers if the imbalance is significant restarts the ModuleContainer with other layers if the imbalance is significant
@@ -68,13 +68,10 @@ class Server(threading.Thread):
mean_block_selection_delay: float = 0.5, mean_block_selection_delay: float = 0.5,
use_auth_token: Optional[str] = None, use_auth_token: Optional[str] = None,
load_in_8bit: bool = False, load_in_8bit: bool = False,
start: bool,
**kwargs, **kwargs,
): ):
"""Create a server with one or more bloom blocks. See run_server.py for documentation.""" """Create a server with one or more bloom blocks. See run_server.py for documentation."""
super().__init__()
self.converted_model_name_or_path = converted_model_name_or_path self.converted_model_name_or_path = converted_model_name_or_path
self.num_handlers = num_handlers self.num_handlers = num_handlers
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
@@ -147,8 +144,6 @@ class Server(threading.Thread):
self.mean_block_selection_delay = mean_block_selection_delay self.mean_block_selection_delay = mean_block_selection_delay
self.stop = threading.Event() self.stop = threading.Event()
if start:
self.start()
def run(self): def run(self):
while True: while True:
@@ -231,65 +226,6 @@ class Server(threading.Thread):
class ModuleContainer(threading.Thread): class ModuleContainer(threading.Thread):
"""Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT.""" """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
def __init__(
self,
dht: DHT,
module_backends: Dict[str, TransformerBackend],
*,
inference_max_length: int,
num_connection_handlers: int,
throughput: float,
update_period: float,
expiration: Optional[float] = None,
start: bool,
**kwargs,
):
super().__init__()
self.dht, self.module_backends = dht, module_backends
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
self.conn_handlers = [
TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
for _ in range(num_connection_handlers)
]
self.runtime = Runtime(self.module_backends, **kwargs)
self.dht_handler_thread = ModuleAnnouncerThread(
self.module_backends,
dht,
throughput=throughput,
update_period=update_period,
expiration=expiration,
daemon=True,
)
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
if start:
self.run_in_background(await_ready=True)
def run(self):
"""
Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
runs Runtime (self.runtime) to process incoming requests.
"""
logger.info(f"Serving {len(self.module_backends)} blocks:")
for expert_name, backend in self.module_backends.items():
num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
if not self.dht.is_alive():
self.dht.run_in_background(await_ready=True)
if self.module_backends:
self.dht_handler_thread.start()
if self.checkpoint_saver is not None:
self.checkpoint_saver.start()
for handler in self.conn_handlers:
handler.run_in_background()
self.runtime.run()
# noinspection PyMethodOverriding # noinspection PyMethodOverriding
@classmethod @classmethod
def create( def create(
@@ -320,53 +256,72 @@ class ModuleContainer(threading.Thread):
start: bool, start: bool,
) -> ModuleContainer: ) -> ModuleContainer:
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices] module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
declare_active_modules( joining_announcer = ModuleAnnouncerThread(
dht,
module_uids, module_uids,
expiration_time=get_dht_time() + expiration, dht,
state=ServerState.JOINING, ServerState.JOINING,
throughput=throughput, throughput=throughput,
update_period=update_period,
expiration=expiration,
daemon=True,
) )
joining_announcer.start()
logger.info(f"Announced that blocks {block_indices} are joining") logger.info(f"Announced that blocks {block_indices} are joining")
blocks = {} try:
for module_uid, block_index in zip(module_uids, block_indices): blocks = {}
block = load_pretrained_block( for module_uid, block_index in zip(module_uids, block_indices):
converted_model_name_or_path, block = load_pretrained_block(
block_index, converted_model_name_or_path,
block_config, block_index,
torch_dtype=torch_dtype, block_config,
use_auth_token=use_auth_token, torch_dtype=torch_dtype,
cache_dir=cache_dir, use_auth_token=use_auth_token,
) cache_dir=cache_dir,
)
if load_in_8bit: if load_in_8bit:
dtype = block.input_layernorm.weight.dtype dtype = block.input_layernorm.weight.dtype
block = replace_8bit_linear(block) block = replace_8bit_linear(block)
block = block.to(device) block = block.to(device)
for param in block.parameters(): for param in block.parameters():
param.requires_grad = False param.requires_grad = False
blocks[module_uid] = TransformerBackend( blocks[module_uid] = TransformerBackend(
module_uid, module_uid,
block, block,
memory_cache=memory_cache, memory_cache=memory_cache,
backend_dtype=None if torch_dtype == "auto" else torch_dtype, backend_dtype=None if torch_dtype == "auto" else torch_dtype,
args_schema=( args_schema=(
BatchTensorDescriptor( BatchTensorDescriptor(
1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
),
), ),
), kwargs_schema={},
kwargs_schema={}, outputs_schema=(
outputs_schema=( BatchTensorDescriptor(
BatchTensorDescriptor( 1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression ),
), ),
), min_batch_size=min_batch_size,
min_batch_size=min_batch_size, max_batch_size=max_batch_size,
max_batch_size=max_batch_size, )
except:
joining_announcer.stop.set()
joining_announcer.join()
declare_active_modules(
dht,
module_uids,
expiration_time=get_dht_time() + expiration,
state=ServerState.OFFLINE,
throughput=throughput,
) )
logger.info(f"Announced that blocks {module_uids} are offline")
raise
else:
joining_announcer.stop.set()
joining_announcer.join()
return cls( return cls(
dht, dht,
@@ -383,6 +338,65 @@ class ModuleContainer(threading.Thread):
start=start, start=start,
) )
def __init__(
self,
dht: DHT,
module_backends: Dict[str, TransformerBackend],
*,
inference_max_length: int,
num_connection_handlers: int,
throughput: float,
update_period: float,
expiration: Optional[float] = None,
start: bool,
**kwargs,
):
super().__init__()
self.dht, self.module_backends = dht, module_backends
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
self.conn_handlers = [
TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
for _ in range(num_connection_handlers)
]
self.runtime = Runtime(self.module_backends, **kwargs)
self.online_announcer = ModuleAnnouncerThread(
list(self.module_backends.keys()),
dht,
ServerState.ONLINE,
throughput=throughput,
update_period=update_period,
expiration=expiration,
daemon=True,
)
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
if start:
self.run_in_background(await_ready=True)
def run(self):
"""
Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
runs Runtime (self.runtime) to process incoming requests.
"""
logger.info(f"Serving {len(self.module_backends)} blocks:")
for expert_name, backend in self.module_backends.items():
num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
if not self.dht.is_alive():
self.dht.run_in_background(await_ready=True)
self.online_announcer.start()
if self.checkpoint_saver is not None:
self.checkpoint_saver.start()
for handler in self.conn_handlers:
handler.run_in_background()
self.runtime.run()
def run_in_background(self, await_ready=True, timeout=None): def run_in_background(self, await_ready=True, timeout=None):
""" """
Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
@@ -411,18 +425,17 @@ class ModuleContainer(threading.Thread):
Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes. Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL). If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
""" """
if self.module_backends: self.online_announcer.stop.set()
self.dht_handler_thread.stop.set() self.online_announcer.join()
self.dht_handler_thread.join()
declare_active_modules( declare_active_modules(
self.dht, self.dht,
self.module_backends.keys(), self.module_backends.keys(),
expiration_time=get_dht_time() + self.expiration, expiration_time=get_dht_time() + self.expiration,
state=ServerState.OFFLINE, state=ServerState.OFFLINE,
throughput=self.throughput, throughput=self.throughput,
) )
logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline") logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
self.ready.clear() self.ready.clear()
@@ -450,8 +463,9 @@ class ModuleAnnouncerThread(threading.Thread):
def __init__( def __init__(
self, self,
module_backends: Dict[str, TransformerBackend], module_uids: List[str],
dht: DHT, dht: DHT,
state: ServerState,
*, *,
throughput: float, throughput: float,
update_period: float = 30, update_period: float = 30,
@@ -459,8 +473,9 @@ class ModuleAnnouncerThread(threading.Thread):
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
self.module_backends = module_backends self.module_uids = module_uids
self.dht = dht self.dht = dht
self.state = state
self.throughput = throughput self.throughput = throughput
self.update_period = update_period self.update_period = update_period
self.expiration = expiration self.expiration = expiration
@@ -470,9 +485,9 @@ class ModuleAnnouncerThread(threading.Thread):
while True: while True:
declare_active_modules( declare_active_modules(
self.dht, self.dht,
self.module_backends.keys(), self.module_uids,
expiration_time=get_dht_time() + self.expiration, expiration_time=get_dht_time() + self.expiration,
state=ServerState.ONLINE, state=self.state,
throughput=self.throughput, throughput=self.throughput,
) )
if self.stop.wait(self.update_period): if self.stop.wait(self.update_period):