mirror of
https://github.com/bigscience-workshop/petals
synced 2024-10-31 09:20:41 +00:00
Make ServerState announcements work better (#93)
- Before this PR, `ServerState.JOINING` was announced only once. This announcement quickly expires in case of the full-size BLOOM, since loading blocks takes several minutes. This PR fixes it, so `ServerState.JOINING` is announced periodically in a thread until blocks are loaded. - This PR also makes the `Server` class a non-thread, so it runs in the main thread and can catch `KeyboardInterrupt`. This is important, since if we are downloading blocks right now, we need to stop it and send the `ServerState.OFFLINE` message. Note that `ModuleContainer` is still a thread. - (minor) For the sake of readability, I moved the `ModuleContainer.create()` definition, so it is now defined before `Server.__init__()` (this is because `.create()` is invoked first).
This commit is contained in:
parent
dc71574a63
commit
8a73b41a42
@ -124,10 +124,9 @@ def main():
|
|||||||
use_auth_token = args.pop("use_auth_token")
|
use_auth_token = args.pop("use_auth_token")
|
||||||
args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
|
args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
|
||||||
|
|
||||||
server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)
|
server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
server.join()
|
server.run()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Caught KeyboardInterrupt, shutting down")
|
logger.info("Caught KeyboardInterrupt, shutting down")
|
||||||
finally:
|
finally:
|
||||||
|
@ -32,7 +32,7 @@ use_hivemind_log_handler("in_root_logger")
|
|||||||
logger = get_logger(__file__)
|
logger = get_logger(__file__)
|
||||||
|
|
||||||
|
|
||||||
class Server(threading.Thread):
|
class Server:
|
||||||
"""
|
"""
|
||||||
Runs ModuleContainer, periodically checks that the network is balanced,
|
Runs ModuleContainer, periodically checks that the network is balanced,
|
||||||
restarts the ModuleContainer with other layers if the imbalance is significant
|
restarts the ModuleContainer with other layers if the imbalance is significant
|
||||||
@ -68,13 +68,10 @@ class Server(threading.Thread):
|
|||||||
mean_block_selection_delay: float = 0.5,
|
mean_block_selection_delay: float = 0.5,
|
||||||
use_auth_token: Optional[str] = None,
|
use_auth_token: Optional[str] = None,
|
||||||
load_in_8bit: bool = False,
|
load_in_8bit: bool = False,
|
||||||
start: bool,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
|
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.converted_model_name_or_path = converted_model_name_or_path
|
self.converted_model_name_or_path = converted_model_name_or_path
|
||||||
self.num_handlers = num_handlers
|
self.num_handlers = num_handlers
|
||||||
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
|
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
|
||||||
@ -147,8 +144,6 @@ class Server(threading.Thread):
|
|||||||
self.mean_block_selection_delay = mean_block_selection_delay
|
self.mean_block_selection_delay = mean_block_selection_delay
|
||||||
|
|
||||||
self.stop = threading.Event()
|
self.stop = threading.Event()
|
||||||
if start:
|
|
||||||
self.start()
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
while True:
|
while True:
|
||||||
@ -231,65 +226,6 @@ class Server(threading.Thread):
|
|||||||
class ModuleContainer(threading.Thread):
|
class ModuleContainer(threading.Thread):
|
||||||
"""Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
|
"""Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dht: DHT,
|
|
||||||
module_backends: Dict[str, TransformerBackend],
|
|
||||||
*,
|
|
||||||
inference_max_length: int,
|
|
||||||
num_connection_handlers: int,
|
|
||||||
throughput: float,
|
|
||||||
update_period: float,
|
|
||||||
expiration: Optional[float] = None,
|
|
||||||
start: bool,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dht, self.module_backends = dht, module_backends
|
|
||||||
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
|
|
||||||
self.conn_handlers = [
|
|
||||||
TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
|
|
||||||
for _ in range(num_connection_handlers)
|
|
||||||
]
|
|
||||||
self.runtime = Runtime(self.module_backends, **kwargs)
|
|
||||||
self.dht_handler_thread = ModuleAnnouncerThread(
|
|
||||||
self.module_backends,
|
|
||||||
dht,
|
|
||||||
throughput=throughput,
|
|
||||||
update_period=update_period,
|
|
||||||
expiration=expiration,
|
|
||||||
daemon=True,
|
|
||||||
)
|
|
||||||
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
|
|
||||||
|
|
||||||
if start:
|
|
||||||
self.run_in_background(await_ready=True)
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
"""
|
|
||||||
Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
|
|
||||||
runs Runtime (self.runtime) to process incoming requests.
|
|
||||||
"""
|
|
||||||
logger.info(f"Serving {len(self.module_backends)} blocks:")
|
|
||||||
for expert_name, backend in self.module_backends.items():
|
|
||||||
num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
|
|
||||||
logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
|
|
||||||
|
|
||||||
if not self.dht.is_alive():
|
|
||||||
self.dht.run_in_background(await_ready=True)
|
|
||||||
|
|
||||||
if self.module_backends:
|
|
||||||
self.dht_handler_thread.start()
|
|
||||||
|
|
||||||
if self.checkpoint_saver is not None:
|
|
||||||
self.checkpoint_saver.start()
|
|
||||||
|
|
||||||
for handler in self.conn_handlers:
|
|
||||||
handler.run_in_background()
|
|
||||||
|
|
||||||
self.runtime.run()
|
|
||||||
|
|
||||||
# noinspection PyMethodOverriding
|
# noinspection PyMethodOverriding
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@ -320,53 +256,72 @@ class ModuleContainer(threading.Thread):
|
|||||||
start: bool,
|
start: bool,
|
||||||
) -> ModuleContainer:
|
) -> ModuleContainer:
|
||||||
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
|
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
|
||||||
declare_active_modules(
|
joining_announcer = ModuleAnnouncerThread(
|
||||||
dht,
|
|
||||||
module_uids,
|
module_uids,
|
||||||
expiration_time=get_dht_time() + expiration,
|
dht,
|
||||||
state=ServerState.JOINING,
|
ServerState.JOINING,
|
||||||
throughput=throughput,
|
throughput=throughput,
|
||||||
|
update_period=update_period,
|
||||||
|
expiration=expiration,
|
||||||
|
daemon=True,
|
||||||
)
|
)
|
||||||
|
joining_announcer.start()
|
||||||
logger.info(f"Announced that blocks {block_indices} are joining")
|
logger.info(f"Announced that blocks {block_indices} are joining")
|
||||||
|
|
||||||
blocks = {}
|
try:
|
||||||
for module_uid, block_index in zip(module_uids, block_indices):
|
blocks = {}
|
||||||
block = load_pretrained_block(
|
for module_uid, block_index in zip(module_uids, block_indices):
|
||||||
converted_model_name_or_path,
|
block = load_pretrained_block(
|
||||||
block_index,
|
converted_model_name_or_path,
|
||||||
block_config,
|
block_index,
|
||||||
torch_dtype=torch_dtype,
|
block_config,
|
||||||
use_auth_token=use_auth_token,
|
torch_dtype=torch_dtype,
|
||||||
cache_dir=cache_dir,
|
use_auth_token=use_auth_token,
|
||||||
)
|
cache_dir=cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
if load_in_8bit:
|
if load_in_8bit:
|
||||||
dtype = block.input_layernorm.weight.dtype
|
dtype = block.input_layernorm.weight.dtype
|
||||||
block = replace_8bit_linear(block)
|
block = replace_8bit_linear(block)
|
||||||
|
|
||||||
block = block.to(device)
|
block = block.to(device)
|
||||||
for param in block.parameters():
|
for param in block.parameters():
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
blocks[module_uid] = TransformerBackend(
|
blocks[module_uid] = TransformerBackend(
|
||||||
module_uid,
|
module_uid,
|
||||||
block,
|
block,
|
||||||
memory_cache=memory_cache,
|
memory_cache=memory_cache,
|
||||||
backend_dtype=None if torch_dtype == "auto" else torch_dtype,
|
backend_dtype=None if torch_dtype == "auto" else torch_dtype,
|
||||||
args_schema=(
|
args_schema=(
|
||||||
BatchTensorDescriptor(
|
BatchTensorDescriptor(
|
||||||
1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
|
1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
|
||||||
|
),
|
||||||
),
|
),
|
||||||
),
|
kwargs_schema={},
|
||||||
kwargs_schema={},
|
outputs_schema=(
|
||||||
outputs_schema=(
|
BatchTensorDescriptor(
|
||||||
BatchTensorDescriptor(
|
1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
|
||||||
1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
|
),
|
||||||
),
|
),
|
||||||
),
|
min_batch_size=min_batch_size,
|
||||||
min_batch_size=min_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
max_batch_size=max_batch_size,
|
)
|
||||||
|
except:
|
||||||
|
joining_announcer.stop.set()
|
||||||
|
joining_announcer.join()
|
||||||
|
declare_active_modules(
|
||||||
|
dht,
|
||||||
|
module_uids,
|
||||||
|
expiration_time=get_dht_time() + expiration,
|
||||||
|
state=ServerState.OFFLINE,
|
||||||
|
throughput=throughput,
|
||||||
)
|
)
|
||||||
|
logger.info(f"Announced that blocks {module_uids} are offline")
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
joining_announcer.stop.set()
|
||||||
|
joining_announcer.join()
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
dht,
|
dht,
|
||||||
@ -383,6 +338,65 @@ class ModuleContainer(threading.Thread):
|
|||||||
start=start,
|
start=start,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dht: DHT,
|
||||||
|
module_backends: Dict[str, TransformerBackend],
|
||||||
|
*,
|
||||||
|
inference_max_length: int,
|
||||||
|
num_connection_handlers: int,
|
||||||
|
throughput: float,
|
||||||
|
update_period: float,
|
||||||
|
expiration: Optional[float] = None,
|
||||||
|
start: bool,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dht, self.module_backends = dht, module_backends
|
||||||
|
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
|
||||||
|
self.conn_handlers = [
|
||||||
|
TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
|
||||||
|
for _ in range(num_connection_handlers)
|
||||||
|
]
|
||||||
|
self.runtime = Runtime(self.module_backends, **kwargs)
|
||||||
|
self.online_announcer = ModuleAnnouncerThread(
|
||||||
|
list(self.module_backends.keys()),
|
||||||
|
dht,
|
||||||
|
ServerState.ONLINE,
|
||||||
|
throughput=throughput,
|
||||||
|
update_period=update_period,
|
||||||
|
expiration=expiration,
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
|
||||||
|
|
||||||
|
if start:
|
||||||
|
self.run_in_background(await_ready=True)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""
|
||||||
|
Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
|
||||||
|
runs Runtime (self.runtime) to process incoming requests.
|
||||||
|
"""
|
||||||
|
logger.info(f"Serving {len(self.module_backends)} blocks:")
|
||||||
|
for expert_name, backend in self.module_backends.items():
|
||||||
|
num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
|
||||||
|
logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")
|
||||||
|
|
||||||
|
if not self.dht.is_alive():
|
||||||
|
self.dht.run_in_background(await_ready=True)
|
||||||
|
|
||||||
|
self.online_announcer.start()
|
||||||
|
|
||||||
|
if self.checkpoint_saver is not None:
|
||||||
|
self.checkpoint_saver.start()
|
||||||
|
|
||||||
|
for handler in self.conn_handlers:
|
||||||
|
handler.run_in_background()
|
||||||
|
|
||||||
|
self.runtime.run()
|
||||||
|
|
||||||
def run_in_background(self, await_ready=True, timeout=None):
|
def run_in_background(self, await_ready=True, timeout=None):
|
||||||
"""
|
"""
|
||||||
Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
|
Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
|
||||||
@ -411,18 +425,17 @@ class ModuleContainer(threading.Thread):
|
|||||||
Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
|
Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
|
||||||
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
|
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
|
||||||
"""
|
"""
|
||||||
if self.module_backends:
|
self.online_announcer.stop.set()
|
||||||
self.dht_handler_thread.stop.set()
|
self.online_announcer.join()
|
||||||
self.dht_handler_thread.join()
|
|
||||||
|
|
||||||
declare_active_modules(
|
declare_active_modules(
|
||||||
self.dht,
|
self.dht,
|
||||||
self.module_backends.keys(),
|
self.module_backends.keys(),
|
||||||
expiration_time=get_dht_time() + self.expiration,
|
expiration_time=get_dht_time() + self.expiration,
|
||||||
state=ServerState.OFFLINE,
|
state=ServerState.OFFLINE,
|
||||||
throughput=self.throughput,
|
throughput=self.throughput,
|
||||||
)
|
)
|
||||||
logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
|
logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
|
||||||
|
|
||||||
self.ready.clear()
|
self.ready.clear()
|
||||||
|
|
||||||
@ -450,8 +463,9 @@ class ModuleAnnouncerThread(threading.Thread):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
module_backends: Dict[str, TransformerBackend],
|
module_uids: List[str],
|
||||||
dht: DHT,
|
dht: DHT,
|
||||||
|
state: ServerState,
|
||||||
*,
|
*,
|
||||||
throughput: float,
|
throughput: float,
|
||||||
update_period: float = 30,
|
update_period: float = 30,
|
||||||
@ -459,8 +473,9 @@ class ModuleAnnouncerThread(threading.Thread):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.module_backends = module_backends
|
self.module_uids = module_uids
|
||||||
self.dht = dht
|
self.dht = dht
|
||||||
|
self.state = state
|
||||||
self.throughput = throughput
|
self.throughput = throughput
|
||||||
self.update_period = update_period
|
self.update_period = update_period
|
||||||
self.expiration = expiration
|
self.expiration = expiration
|
||||||
@ -470,9 +485,9 @@ class ModuleAnnouncerThread(threading.Thread):
|
|||||||
while True:
|
while True:
|
||||||
declare_active_modules(
|
declare_active_modules(
|
||||||
self.dht,
|
self.dht,
|
||||||
self.module_backends.keys(),
|
self.module_uids,
|
||||||
expiration_time=get_dht_time() + self.expiration,
|
expiration_time=get_dht_time() + self.expiration,
|
||||||
state=ServerState.ONLINE,
|
state=self.state,
|
||||||
throughput=self.throughput,
|
throughput=self.throughput,
|
||||||
)
|
)
|
||||||
if self.stop.wait(self.update_period):
|
if self.stop.wait(self.update_period):
|
||||||
|
Loading…
Reference in New Issue
Block a user