From 73df69a11771b5d6c7e34c763ac49aec9e6567ae Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Thu, 15 Dec 2022 00:11:46 +0400
Subject: [PATCH] Reset MemoryCache during rebalancings (#154)

Before this PR, if there were open inference sessions when rebalancing
was triggered, their cache was never properly destroyed. MemoryCache is
now created inside ModuleContainer.create(), so each rebalancing builds
a fresh cache and the old one is garbage-collected together with the
old container.
---
 src/petals/server/memory_cache.py | 12 ++----------
 src/petals/server/server.py       | 28 ++++++++++++++++++++--------
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py
index d4a50f3..0410069 100644
--- a/src/petals/server/memory_cache.py
+++ b/src/petals/server/memory_cache.py
@@ -33,12 +33,10 @@ class MemoryCache:
         self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
-        self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
-        self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
+        self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
         self.runtime_pid = os.getpid()
 
         self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
-        self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
 
         self._lock_acquire_memory = mp.Lock()
         self._memory_freed_event = mp.Event()
@@ -83,14 +81,12 @@ class MemoryCache:
                 allocated_handle = int(self.handle_counter)
                 self.current_size_bytes += allocated_size_bytes
                 self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pending_messages.value += 1
                 self._pipe_send.send((allocated_handle, descr))
 
             yield allocated_handle
         finally:
             if allocated_handle is not None:
                 async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    self._pending_messages.value += 1
                     self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                     self.current_size_bytes -= allocated_size_bytes
                 self._memory_freed_event.set()
@@ -122,13 +118,9 @@
 
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
         with self._lock_metadata:
-            if self._allocated_tensors is None:
-                self._allocated_tensors = {}
-
             # read creation/deletion requests from connection handlers
-            for i in range(int(self._pending_messages.value)):
+            while self._pipe_recv.poll():
                 recv_handle, recv_data = self._pipe_recv.recv()
-                self._pending_messages.value -= 1
                 if isinstance(recv_data, TensorDescriptor):
                     self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
                 elif recv_data is None:
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 22b5652..ee0b02e 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import gc
+import itertools
 import math
 import multiprocessing as mp
 import random
@@ -72,8 +73,8 @@ class Server:
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         balance_quality: float = 0.75,
-        mean_balance_check_period: float = 60,
-        mean_block_selection_delay: float = 0.5,
+        mean_balance_check_period: float = 120,
+        mean_block_selection_delay: float = 2.5,
         use_auth_token: Optional[str] = None,
         load_in_8bit: Optional[bool] = None,
         **kwargs,
@@ -119,7 +120,6 @@
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
-            logger.info("Please check that your server is reachable at http://health.petals.ml")
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
@@ -157,8 +157,8 @@ class Server:
         if attn_cache_size is None:
             # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
             attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
+        self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
         logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
-        self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
 
         if cache_dir is None:
             cache_dir = DEFAULT_CACHE_DIR
@@ -211,7 +211,8 @@
             prefix=self.prefix,
             converted_model_name_or_path=self.converted_model_name_or_path,
             block_config=self.block_config,
-            memory_cache=self.memory_cache,
+            attn_cache_size=self.attn_cache_size,
+            alloc_timeout=self.alloc_timeout,
             throughput=self.throughput,
             block_indices=block_indices,
             num_handlers=self.num_handlers,
@@ -310,7 +311,8 @@ class ModuleContainer(threading.Thread):
         prefix: str,
         converted_model_name_or_path: str,
         block_config: BloomConfig,
-        memory_cache: MemoryCache,
+        attn_cache_size: int,
+        alloc_timeout: float,
         throughput: float,
         block_indices: List[int],
         min_batch_size: int,
@@ -339,8 +341,9 @@ class ModuleContainer(threading.Thread):
         joining_announcer.start()
         logger.info(f"Announced that blocks {block_indices} are joining")
 
+        memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+        blocks = {}
         try:
-            blocks = {}
             for module_uid, block_index in zip(module_uids, block_indices):
                 block = load_pretrained_block(
                     converted_model_name_or_path,
@@ -380,6 +383,10 @@
                 max_batch_size=max_batch_size,
             )
         except:
+            logger.debug("Shutting down backends")
+            for backend in blocks.values():
+                backend.shutdown()
+
             joining_announcer.stop.set()
             joining_announcer.join()
             declare_active_modules(
@@ -563,7 +570,7 @@ class ModuleAnnouncerThread(threading.Thread):
         self.stop = threading.Event()
 
     def run(self) -> None:
-        while True:
+        for iter_no in itertools.count():
             declare_active_modules(
                 self.dht,
                 self.module_uids,
@@ -571,5 +578,10 @@
                 state=self.state,
                 throughput=self.throughput,
             )
+            if iter_no == 0 and self.state == ServerState.JOINING:
+                logger.info(
+                    f"Please ensure that your server is reachable. "
+                    f"For public swarm, open http://health.petals.ml and find peer_id = {self.dht.peer_id}"
+                )
             if self.stop.wait(self.update_period):
                 break
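
Note on the memory_cache.py half of the patch: it drops the shared
_pending_messages counter entirely. The runtime now drains its end of the
pipe with Connection.poll(), which already reports whether a message is
queued, so the pipe itself is the single source of truth and can never
disagree with a separately maintained counter. A minimal, self-contained
sketch of that drain pattern (hypothetical names, not the petals API):

    import multiprocessing as mp

    def drain_requests(pipe_recv, allocated: dict) -> None:
        # Handle every creation/deletion request currently queued; no separate
        # message counter is needed because poll() reflects the queue directly.
        while pipe_recv.poll():
            handle, payload = pipe_recv.recv()
            if payload is not None:
                allocated[handle] = payload  # "allocate": remember the descriptor
            else:
                del allocated[handle]        # "free": forget the handle

    if __name__ == "__main__":
        recv_end, send_end = mp.Pipe(duplex=False)
        send_end.send((0, "descr-0"))  # allocate handle 0
        send_end.send((1, "descr-1"))  # allocate handle 1
        send_end.send((0, None))       # free handle 0

        tensors = {}
        drain_requests(recv_end, tensors)
        print(tensors)  # {1: 'descr-1'}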
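
A smaller idiom in server.py is also worth calling out:
ModuleAnnouncerThread.run() replaces "while True" with
"for iter_no in itertools.count()" so the loop can act on its first
iteration only, logging the reachability hint exactly once when the
JOINING state is first announced. A standalone sketch of the pattern
(generic names, not the actual petals thread):

    import itertools
    import threading

    class Announcer(threading.Thread):
        def __init__(self, update_period: float = 0.1) -> None:
            super().__init__()
            self.update_period = update_period
            self.stop = threading.Event()

        def run(self) -> None:
            for iter_no in itertools.count():
                print("declare_active_modules(...)")  # re-announce every period
                if iter_no == 0:
                    print("first announcement made; check that the server is reachable")
                # Event.wait doubles as an interruptible sleep, so shutdown does
                # not have to wait out a full update period.
                if self.stop.wait(self.update_period):
                    break

    announcer = Announcer()
    announcer.start()
    announcer.stop.set()  # request shutdown; wait() then returns True immediately
    announcer.join()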