Fix OOMs during server rebalancing (#150)

The cause of the OOMs was the cyclic reference `TransformerBackend <-> PrioritizedTaskPool`, which prevented these objects from being garbage-collected properly. Still, I've added explicit tensor removal just in case.
Alexander Borzunov 1 year ago committed by GitHub
parent 83d9493b6c
commit e4dc938dfe
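For background, here is a minimal standalone sketch (not part of this commit, with hypothetical `Backend`/`Pool` classes standing in for `TransformerBackend`/`PrioritizedTaskPool`) of why such a cycle delays memory release: objects in a reference cycle are reclaimed only by Python's cyclic garbage collector, not by reference counting, so tensors they hold can survive until the next collection pass.

import gc
import torch


class Pool:
    def __init__(self, backend):
        self.backend = backend  # Pool -> Backend reference


class Backend:
    def __init__(self):
        self.weights = torch.zeros(1024, 1024)  # stand-in for module parameters
        self.pool = Pool(self)  # Backend -> Pool reference, closing the cycle

    def shutdown(self):
        self.pool = None  # break the cycle so reference counting can free the object promptly


backend = Backend()
backend.shutdown()
del backend  # with the cycle broken, the tensor is released here instead of at the next collection
gc.collect()  # only needed to reclaim objects that are still stuck in cycles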

@@ -85,3 +85,13 @@ class TransformerBackend(ModuleBackend):
     def get_info(self) -> Dict[str, Any]:
         """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
         return dict(super().get_info(), inference_schema=self.inference_schema)
+
+    def shutdown(self):
+        # Break the cyclic references, otherwise TransformerBackend may not be garbage-collected
+        self.forward_pool = self.backward_pool = self.inference_pool = None
+
+        # Explicitly free the GPU memory. This is not necessary at the time this code is written,
+        # but may help to avoid future issues when the module is not garbage-collected for some reason
+        dummy = torch.tensor([])
+        for p in self.module.parameters():
+            p.data = dummy
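As an aside, here is a quick sketch (not from this commit; assumes a CUDA device and uses a made-up `module` variable) of why the explicit `p.data = dummy` trick helps: replacing `p.data` drops the last reference to each parameter's original CUDA storage, so the allocated VRAM falls immediately even if the backend object itself lingers.

import torch

module = torch.nn.Linear(4096, 4096).cuda()  # ~64 MiB of parameters
before = torch.cuda.memory_allocated()

dummy = torch.tensor([])
for p in module.parameters():
    p.data = dummy  # the old CUDA storage is freed once nothing else references it

after = torch.cuda.memory_allocated()
print(f"Freed {(before - after) / 1024 ** 2:.0f} MiB of allocated VRAM")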

@@ -235,8 +235,8 @@ class Server:
             if self.stop.wait(timeout):
                 return

-            if not self.module_container.handlers_alive:
-                logger.warning("One of connection handlers crashed, restarting the server")
+            if not self.module_container.is_healthy():
+                logger.warning("One of subprocesses crashed, restarting the server")
                 break

             if self._should_choose_other_blocks():
@@ -252,8 +252,19 @@ class Server:
             gc.collect()  # In particular, this closes unused file descriptors

             cur_proc = psutil.Process()
-            num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
-            logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
+            num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
+            logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
+
+            if self.device.type == "cuda":
+                torch.cuda.empty_cache()
+
+                allocated_vram = torch.cuda.memory_allocated(self.device)
+                reserved_vram = torch.cuda.memory_reserved(self.device)
+                gib = 1024**3
+                logger.info(
+                    f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
+                    f"{reserved_vram / gib:.1f} GiB reserved memory"
+                )

     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
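The two numbers logged above measure different things, which is why `torch.cuda.empty_cache()` is called before reporting them. A quick illustration (assumes a CUDA device; not part of the commit):

import torch

gib = 1024 ** 3
x = torch.empty(256, 1024, 1024, device="cuda")  # ~1 GiB tensor
del x
print(torch.cuda.memory_allocated() / gib)  # ~0.0 GiB: only live tensors count as allocated
print(torch.cuda.memory_reserved() / gib)   # ~1.0 GiB: the caching allocator still holds the blocks
torch.cuda.empty_cache()                    # return cached blocks to the driver
print(torch.cuda.memory_reserved() / gib)   # ~0.0 GiB once the cache is emptied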
@@ -470,9 +481,10 @@ class ModuleContainer(threading.Thread):
         """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches

-    @property
-    def handlers_alive(self) -> bool:
-        return all(handler.is_alive() for handler in self.conn_handlers)
+    def is_healthy(self) -> bool:
+        return all(handler.is_alive() for handler in self.conn_handlers) and all(
+            pool.is_alive() for pool in self.runtime.pools
+        )

     def shutdown(self):
         """
@@ -510,6 +522,10 @@ class ModuleContainer(threading.Thread):
         logger.debug(f"Shutting down runtime")
         self.runtime.shutdown()

+        logger.debug("Shutting down backends")
+        for backend in self.module_backends.values():
+            backend.shutdown()
+
         logger.info("Module container shut down successfully")
