Announce JOINING periodically

fix-joining-announce
Aleksandr Borzunov 2 years ago
parent dc71574a63
commit 06a8246ae9
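This commit replaces the one-shot JOINING announcement in ModuleContainer.create() with a background ModuleAnnouncerThread: the server now re-declares ServerState.JOINING every update_period seconds while its blocks are still loading, so the DHT record cannot expire before loading finishes. The same thread class is generalized to announce any state, which is why its __init__ now takes module_uids and a state argument and the old dht_handler_thread becomes online_announcer.

The core pattern is a thread that alternates declare-and-wait on a threading.Event, where Event.wait() doubles as both the sleep and the shutdown signal. A minimal, self-contained sketch of that pattern (declare_fn here is a hypothetical stand-in for hivemind's declare_active_modules, not the project's actual API):

import threading
from typing import Callable, List


class PeriodicAnnouncer(threading.Thread):
    """Re-declare `uids` with a fixed `state` every `update_period` seconds until stopped."""

    def __init__(self, uids: List[str], state: str, declare_fn: Callable[[List[str], str], None],
                 update_period: float = 30.0, **kwargs):
        super().__init__(**kwargs)
        self.uids, self.state, self.declare_fn = uids, state, declare_fn
        self.update_period = update_period
        self.stop = threading.Event()

    def run(self):
        while True:
            # Refresh the record before the previous announcement expires.
            self.declare_fn(self.uids, self.state)
            # wait() returns True as soon as stop is set, otherwise after update_period seconds.
            if self.stop.wait(self.update_period):
                break


# Usage mirrors the diff below: start before a long operation, stop in `finally`.
# announcer = PeriodicAnnouncer(uids, "JOINING", declare_fn, daemon=True)
# announcer.start()
# try:
#     ...  # load model blocks
# finally:
#     announcer.stop.set()
#     announcer.join()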

@@ -231,6 +231,106 @@ class Server(threading.Thread):
 class ModuleContainer(threading.Thread):
     """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""

+    # noinspection PyMethodOverriding
+    @classmethod
+    def create(
+        cls,
+        *,
+        dht: DHT,
+        prefix: str,
+        converted_model_name_or_path: str,
+        block_config: BloomConfig,
+        memory_cache: MemoryCache,
+        throughput: float,
+        block_indices: List[int],
+        num_handlers: Optional[int],
+        min_batch_size: int,
+        max_batch_size: int,
+        inference_max_length: int,
+        torch_dtype: torch.dtype,
+        cache_dir: Optional[str],
+        device: Union[str, torch.device],
+        compression: CompressionType,
+        stats_report_interval: Optional[int],
+        update_period: float,
+        expiration: Optional[float],
+        prefetch_batches: int,
+        sender_threads: int,
+        use_auth_token: Optional[str],
+        load_in_8bit: bool,
+        start: bool,
+    ) -> ModuleContainer:
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        joining_announcer = ModuleAnnouncerThread(
+            module_uids,
+            dht,
+            ServerState.JOINING,
+            throughput=throughput,
+            update_period=update_period,
+            expiration=expiration,
+            daemon=True,
+        )
+        joining_announcer.start()
+        logger.info(f"Announced that blocks {block_indices} are joining")
+
+        try:
+            blocks = {}
+            for module_uid, block_index in zip(module_uids, block_indices):
+                block = load_pretrained_block(
+                    converted_model_name_or_path,
+                    block_index,
+                    block_config,
+                    torch_dtype=torch_dtype,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                )
+
+                if load_in_8bit:
+                    dtype = block.input_layernorm.weight.dtype
+                    block = replace_8bit_linear(block)
+
+                block = block.to(device)
+                for param in block.parameters():
+                    param.requires_grad = False
+
+                blocks[module_uid] = TransformerBackend(
+                    module_uid,
+                    block,
+                    memory_cache=memory_cache,
+                    backend_dtype=None if torch_dtype == "auto" else torch_dtype,
+                    args_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                        ),
+                    ),
+                    kwargs_schema={},
+                    outputs_schema=(
+                        BatchTensorDescriptor(
+                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
+                        ),
+                    ),
+                    min_batch_size=min_batch_size,
+                    max_batch_size=max_batch_size,
+                )
+        finally:
+            joining_announcer.stop.set()
+            joining_announcer.join()
+
+        return cls(
+            dht,
+            blocks,
+            throughput=throughput,
+            num_connection_handlers=num_handlers,
+            inference_max_length=inference_max_length,
+            device=device,
+            stats_report_interval=stats_report_interval,
+            update_period=update_period,
+            expiration=expiration,
+            prefetch_batches=prefetch_batches,
+            sender_threads=sender_threads,
+            start=start,
+        )
+
     def __init__(
         self,
         dht: DHT,
@@ -253,9 +353,10 @@ class ModuleContainer(threading.Thread):
             for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, **kwargs)
-        self.dht_handler_thread = ModuleAnnouncerThread(
-            self.module_backends,
+        self.online_announcer = ModuleAnnouncerThread(
+            list(self.module_backends.keys()),
             dht,
+            ServerState.ONLINE,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -279,8 +380,7 @@
         if not self.dht.is_alive():
             self.dht.run_in_background(await_ready=True)

         if self.module_backends:
-            self.dht_handler_thread.start()
+            self.online_announcer.start()

         if self.checkpoint_saver is not None:
             self.checkpoint_saver.start()
@@ -290,99 +390,6 @@
             self.runtime.run()

-    # noinspection PyMethodOverriding
-    @classmethod
-    def create(
-        cls,
-        *,
-        dht: DHT,
-        prefix: str,
-        converted_model_name_or_path: str,
-        block_config: BloomConfig,
-        memory_cache: MemoryCache,
-        throughput: float,
-        block_indices: List[int],
-        num_handlers: Optional[int],
-        min_batch_size: int,
-        max_batch_size: int,
-        inference_max_length: int,
-        torch_dtype: torch.dtype,
-        cache_dir: Optional[str],
-        device: Union[str, torch.device],
-        compression: CompressionType,
-        stats_report_interval: Optional[int],
-        update_period: float,
-        expiration: Optional[float],
-        prefetch_batches: int,
-        sender_threads: int,
-        use_auth_token: Optional[str],
-        load_in_8bit: bool,
-        start: bool,
-    ) -> ModuleContainer:
-        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
-        declare_active_modules(
-            dht,
-            module_uids,
-            expiration_time=get_dht_time() + expiration,
-            state=ServerState.JOINING,
-            throughput=throughput,
-        )
-        logger.info(f"Announced that blocks {block_indices} are joining")
-
-        blocks = {}
-        for module_uid, block_index in zip(module_uids, block_indices):
-            block = load_pretrained_block(
-                converted_model_name_or_path,
-                block_index,
-                block_config,
-                torch_dtype=torch_dtype,
-                use_auth_token=use_auth_token,
-                cache_dir=cache_dir,
-            )
-
-            if load_in_8bit:
-                dtype = block.input_layernorm.weight.dtype
-                block = replace_8bit_linear(block)
-
-            block = block.to(device)
-            for param in block.parameters():
-                param.requires_grad = False
-
-            blocks[module_uid] = TransformerBackend(
-                module_uid,
-                block,
-                memory_cache=memory_cache,
-                backend_dtype=None if torch_dtype == "auto" else torch_dtype,
-                args_schema=(
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
-                ),
-                kwargs_schema={},
-                outputs_schema=(
-                    BatchTensorDescriptor(
-                        1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
-                    ),
-                ),
-                min_batch_size=min_batch_size,
-                max_batch_size=max_batch_size,
-            )
-
-        return cls(
-            dht,
-            blocks,
-            throughput=throughput,
-            num_connection_handlers=num_handlers,
-            inference_max_length=inference_max_length,
-            device=device,
-            stats_report_interval=stats_report_interval,
-            update_period=update_period,
-            expiration=expiration,
-            prefetch_batches=prefetch_batches,
-            sender_threads=sender_threads,
-            start=start,
-        )
-
     def run_in_background(self, await_ready=True, timeout=None):
         """
         Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
@@ -411,18 +418,17 @@ class ModuleContainer(threading.Thread):
         Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
         if self.module_backends:
-            self.dht_handler_thread.stop.set()
-            self.dht_handler_thread.join()
-
-        declare_active_modules(
-            self.dht,
-            self.module_backends.keys(),
-            expiration_time=get_dht_time() + self.expiration,
-            state=ServerState.OFFLINE,
-            throughput=self.throughput,
-        )
-        logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
+            self.online_announcer.stop.set()
+            self.online_announcer.join()
+            declare_active_modules(
+                self.dht,
+                self.module_backends.keys(),
+                expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.OFFLINE,
+                throughput=self.throughput,
+            )
+            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")

         self.ready.clear()
@@ -450,8 +456,9 @@ class ModuleAnnouncerThread(threading.Thread):
     def __init__(
         self,
-        module_backends: Dict[str, TransformerBackend],
+        module_uids: List[str],
         dht: DHT,
+        state: ServerState,
         *,
         throughput: float,
         update_period: float = 30,
@@ -459,8 +466,9 @@
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.module_backends = module_backends
+        self.module_uids = module_uids
         self.dht = dht
+        self.state = state
         self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
@@ -470,9 +478,9 @@
         while True:
             declare_active_modules(
                 self.dht,
-                self.module_backends.keys(),
+                self.module_uids,
                 expiration_time=get_dht_time() + self.expiration,
-                state=ServerState.ONLINE,
+                state=self.state,
                 throughput=self.throughput,
             )
             if self.stop.wait(self.update_period):
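Taken together, the server's DHT presence now moves through three states. A condensed lifecycle sketch (simplified pseudocode assembled from the hunks above, not an exact excerpt; load_all_blocks stands in for the per-block loading loop in create()):

# ModuleContainer.create(): announce JOINING periodically while blocks load
joining_announcer = ModuleAnnouncerThread(module_uids, dht, ServerState.JOINING, throughput=throughput,
                                          update_period=update_period, expiration=expiration, daemon=True)
joining_announcer.start()
try:
    blocks = load_all_blocks(...)  # hypothetical stand-in for the load_pretrained_block loop
finally:
    joining_announcer.stop.set()   # stop JOINING announcements even if loading fails
    joining_announcer.join()

# ModuleContainer.__init__() / run(): announce ONLINE periodically while serving
container = ModuleContainer(dht, blocks, ...)  # __init__ builds online_announcer with ServerState.ONLINE
container.run()                                # starts online_announcer, then the Runtime

# ModuleContainer.shutdown(): stop the ONLINE announcer, then declare OFFLINE once
container.shutdown()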
