From 6e4ebb94d2b84d8b278b328331c126811f1e0916 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Fri, 21 Jul 2023 11:09:24 +0400
Subject: [PATCH] Fix deadlocks in MemoryCache (#396)

- Fix deadlocks in MemoryCache
- Set default --alloc_timeout to 1 until the MemoryCache update
---
 src/petals/cli/run_server.py      |  2 +-
 src/petals/server/memory_cache.py | 45 +++++++++++++------------------
 2 files changed, 20 insertions(+), 27 deletions(-)

diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index 46b1163..a33e233 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -94,7 +94,7 @@ def main():
     parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--alloc_timeout', type=float, default=5,
+    parser.add_argument('--alloc_timeout', type=float, default=1,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
     parser.add_argument('--revision', type=str, default=None,
diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py
index a1e2f26..c2aa192 100644
--- a/src/petals/server/memory_cache.py
+++ b/src/petals/server/memory_cache.py
@@ -90,7 +90,7 @@ class MemoryCache:
             logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
-            await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
+            self._free(max_alloc_size, alloc_task)
 
     @staticmethod
     def get_allocation_size(*descriptors: TensorDescriptor) -> int:
@@ -111,25 +111,19 @@ class MemoryCache:
         async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
             if self.current_size_bytes + alloc_size > self.max_size_bytes:
                 await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
-            async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+            with self._lock_metadata:
                 handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
                 self.current_size_bytes += alloc_size
                 self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
                 self._pipe_send.send((handles, descriptors))
                 return handles
 
-    async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
-        """
-        This method should be called inside asyncio.shield() because:
-        - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
-        - _schedule_free() must finish freeing memory even in case of cancellation
-        """
-
+    def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
         if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()
 
-        async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+        with self._lock_metadata:
             self._pipe_send.send((handles, None))  # signal runtime to free these handles
             self.current_size_bytes -= alloc_size
         self._memory_freed_event.set()
@@ -160,22 +154,21 @@ class MemoryCache:
         assert os.getpid() == self.runtime_pid
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
 
-        with self._lock_metadata:
-            # read creation/deletion requests from connection handlers
-            while self._pipe_recv.poll():
-                recv_handles, recv_data = self._pipe_recv.recv()
-                if recv_data is not None:  # create new tensors
-                    assert len(recv_handles) == len(recv_data)
-                    for handle, descr in zip(recv_handles, recv_data):
-                        self._allocated_tensors[handle] = descr.make_zeros()
-                        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
-                else:  # delete tensors by handle
-                    for handle in recv_handles:
-                        if handle not in self._allocated_tensors:
-                            logger.warning(
-                                f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
-                            )
-                        self._allocated_tensors.pop(handle, None)
+        # read creation/deletion requests from connection handlers
+        while self._pipe_recv.poll():
+            recv_handles, recv_data = self._pipe_recv.recv()
+            if recv_data is not None:  # create new tensors
+                assert len(recv_handles) == len(recv_data)
+                for handle, descr in zip(recv_handles, recv_data):
+                    self._allocated_tensors[handle] = descr.make_zeros()
+                    assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
+            else:  # delete tensors by handle
+                for handle in recv_handles:
+                    if handle not in self._allocated_tensors:
+                        logger.warning(
+                            f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
+                        )
+                    self._allocated_tensors.pop(handle, None)
         yield tuple(self._allocated_tensors[handle] for handle in handles)
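
Note on the fix: the docstring removed above names the hazard this patch eliminates. hivemind.utils.enter_asynchronously() acquires a threading.Lock inside a thread-pool executor, and if the awaiting coroutine is cancelled at the wrong moment, the executor thread can still complete the acquire, leaving the lock held on behalf of a task that will never release it; every later allocation then blocks on _lock_metadata forever. Below is a minimal, self-contained sketch of that cancellation leak (illustrative only, not Petals code; the name guarded_section is made up):

# deadlock_sketch.py -- minimal reproduction of the cancellation leak
# (illustrative only; this is not Petals code)
import asyncio
import threading

lock = threading.Lock()

async def guarded_section():
    loop = asyncio.get_running_loop()
    # lock.acquire() keeps running in the executor thread even if this
    # coroutine is cancelled, so the acquire can "succeed" after the
    # coroutine that was supposed to release the lock is already gone
    await loop.run_in_executor(None, lock.acquire)
    try:
        pass  # critical section (never reached after cancellation)
    finally:
        lock.release()

async def main():
    lock.acquire()  # another party holds the lock for now
    task = asyncio.create_task(guarded_section())
    await asyncio.sleep(0.1)  # task is now blocked inside the executor
    task.cancel()  # raises CancelledError at the await above
    await asyncio.gather(task, return_exceptions=True)
    lock.release()  # the original holder lets go...
    await asyncio.sleep(0.1)  # ...and the orphaned acquire() grabs the lock
    print("lock leaked:", lock.locked())  # True: nobody will ever release it

asyncio.run(main())

Taking the lock synchronously, as the patch does, avoids this entirely: the metadata critical sections only update counters and push a message down a pipe, so briefly holding a plain threading.Lock on the event loop thread is cheap, and the now-synchronous _free() cannot be cancelled halfway through.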
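Operational note: with this patch, a server whose cache is full waits only 1 second (down from 5) for memory to be freed before rejecting a request. Operators who prefer the old behavior can set the flag explicitly, e.g. (the model name here is only a placeholder):

    python -m petals.cli.run_server bigscience/bloom-560m --alloc_timeout 5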