Fix deadlocks in MemoryCache (#396)

- Fix deadlocks in MemoryCache
- Set default --alloc_timeout to 1 until the MemoryCache update
Branch: pull/397/head
Author: Alexander Borzunov (committed by GitHub)
Parent: b6b3ae964f
Commit: 6e4ebb94d2

@@ -94,7 +94,7 @@ def main():
     parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--alloc_timeout', type=float, default=5,
+    parser.add_argument('--alloc_timeout', type=float, default=1,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
     parser.add_argument('--revision', type=str, default=None,
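
For context, a minimal standalone sketch (plain argparse, not the project's actual parser setup) showing what the new default means in practice: a full cache now makes an incoming request wait at most 1 second for memory to be freed by default, and the longer wait is still available by passing the flag explicitly.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--alloc_timeout', type=float, default=1,
                    help='Seconds to wait for cache memory to be freed before rejecting the request')

print(parser.parse_args([]).alloc_timeout)                        # 1.0  (new default)
print(parser.parse_args(['--alloc_timeout', '5']).alloc_timeout)  # 5.0  (old default, now opt-in)
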

@@ -90,7 +90,7 @@ class MemoryCache:
             logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
-            await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
+            self._free(max_alloc_size, alloc_task)
 
     @staticmethod
     def get_allocation_size(*descriptors: TensorDescriptor) -> int:
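
The change above removes the only await from the cleanup path: memory is now released by a plain synchronous call in the finally block, so cancellation of the caller cannot interrupt it and no lock has to be entered asynchronously during cleanup. A minimal runnable sketch of the pattern, with simplified stand-in names (TinyCache, schedule_alloc, and free are illustrative, and plain asyncio.shield is used in place of hivemind's shield_and_wait helper):

import asyncio
import contextlib


class TinyCache:
    """Stand-in for MemoryCache: tracks only a byte counter."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.used = 0

    async def schedule_alloc(self, size: int) -> int:
        if self.used + size > self.capacity:
            raise RuntimeError("out of cache memory")
        self.used += size
        return size  # stands in for the real tensor handles

    def free(self, size: int, alloc_task: asyncio.Task) -> None:
        # Synchronous: no awaits and no async locks, so it is safe to run in `finally`
        if not alloc_task.done() or alloc_task.cancelled() or alloc_task.exception() is not None:
            return  # the allocation never completed, so there is nothing to free
        self.used -= size


@contextlib.asynccontextmanager
async def allocate(cache: TinyCache, size: int):
    alloc_task = asyncio.create_task(cache.schedule_alloc(size))
    try:
        yield await asyncio.shield(alloc_task)
    finally:
        cache.free(size, alloc_task)  # runs even if the caller was cancelled


async def main() -> None:
    cache = TinyCache(capacity=100)
    async with allocate(cache, 60):
        print("allocated:", cache.used)  # 60
    print("released:", cache.used)       # 0


asyncio.run(main())
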
@@ -111,25 +111,19 @@ class MemoryCache:
         async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
             if self.current_size_bytes + alloc_size > self.max_size_bytes:
                 await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
-            async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+            with self._lock_metadata:
                 handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
                 self.current_size_bytes += alloc_size
                 self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
                 self._pipe_send.send((handles, descriptors))
                 return handles
 
-    async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
-        """
-        This method should be called inside asyncio.shield() because:
-        - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
-        - _schedule_free() must finish freeing memory even in case of cancellation
-        """
-
+    def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
         if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()
 
-        async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+        with self._lock_metadata:
             self._pipe_send.send((handles, None))  # signal runtime to free these handles
             self.current_size_bytes -= alloc_size
         self._memory_freed_event.set()
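
The docstring deleted above spelled out the original hazard: hivemind.utils.enter_asynchronously() acquires a lock in a worker thread, and that thread keeps going even if the awaiting coroutine is cancelled, so the lock can end up held with no one left to release it. The hunk replaces those async lock entries with short, await-free `with self._lock_metadata:` blocks. A self-contained demonstration of that underlying hazard using only asyncio and threading (illustrative names, not hivemind's implementation):

import asyncio
import threading

lock = threading.Lock()


async def acquire_in_executor() -> None:
    loop = asyncio.get_running_loop()
    # The worker thread keeps blocking on lock.acquire() even if this await is cancelled...
    await loop.run_in_executor(None, lock.acquire)
    # ...so on cancellation we never reach this line to release it.
    lock.release()


async def main() -> None:
    lock.acquire()                                   # simulate contention
    task = asyncio.create_task(acquire_in_executor())
    await asyncio.sleep(0.1)                         # let the worker thread start blocking
    task.cancel()                                    # the caller gives up
    lock.release()                                   # the worker thread now grabs the lock...
    await asyncio.sleep(0.1)
    print("lock still held after cancellation:", lock.locked())  # True -> later users deadlock


asyncio.run(main())
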
@@ -160,22 +154,21 @@ class MemoryCache:
         assert os.getpid() == self.runtime_pid
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
 
-        with self._lock_metadata:
-            # read creation/deletion requests from connection handlers
-            while self._pipe_recv.poll():
-                recv_handles, recv_data = self._pipe_recv.recv()
-                if recv_data is not None:  # create new tensors
-                    assert len(recv_handles) == len(recv_data)
-                    for handle, descr in zip(recv_handles, recv_data):
-                        self._allocated_tensors[handle] = descr.make_zeros()
-                        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
-                else:  # delete tensors by handle
-                    for handle in recv_handles:
-                        if handle not in self._allocated_tensors:
-                            logger.warning(
-                                f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
-                            )
-                        self._allocated_tensors.pop(handle, None)
+        # read creation/deletion requests from connection handlers
+        while self._pipe_recv.poll():
+            recv_handles, recv_data = self._pipe_recv.recv()
+            if recv_data is not None:  # create new tensors
+                assert len(recv_handles) == len(recv_data)
+                for handle, descr in zip(recv_handles, recv_data):
+                    self._allocated_tensors[handle] = descr.make_zeros()
+                    assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
+            else:  # delete tensors by handle
+                for handle in recv_handles:
+                    if handle not in self._allocated_tensors:
+                        logger.warning(
+                            f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
+                        )
+                    self._allocated_tensors.pop(handle, None)
 
         yield tuple(self._allocated_tensors[handle] for handle in handles)
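
With the metadata lock gone from the runtime-side loop, the runtime simply drains whatever creation/deletion requests are queued in the pipe; the pipe delivers them in order, and (as the earlier hunk suggests) the lock's remaining job is to guard the handler-side counters. A tiny self-contained sketch of that drain pattern, with illustrative names and plain strings standing in for tensor descriptors:

import multiprocessing as mp

recv_end, send_end = mp.Pipe(duplex=False)
allocated = {}

# A connection handler would send these from its own side:
send_end.send(((1, 2), ["descriptor_a", "descriptor_b"]))  # create two dummy "tensors"
send_end.send(((1,), None))                                # later: free handle 1

# Runtime side: drain everything currently queued, no lock needed
while recv_end.poll():
    handles, data = recv_end.recv()
    if data is not None:                    # create new entries
        allocated.update(zip(handles, data))
    else:                                   # delete entries by handle
        for handle in handles:
            allocated.pop(handle, None)

print(allocated)  # {2: 'descriptor_b'}
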
