Fix race condition in MemoryCache (#487)

pull/489/head
Alexander Borzunov 9 months ago committed by GitHub
parent dc0072fde1
commit 02fc71eb25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -31,7 +31,7 @@ class MemoryCache:
 self.max_alloc_timeout = max_alloc_timeout
 self._lock_metadata = mp.Lock()
 self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
-self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False)
+self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True)
 self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
 self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
 self.runtime_pid = os.getpid()
@@ -138,7 +138,8 @@ class MemoryCache:
 start_time = time.perf_counter()
 loop = asyncio.get_event_loop()
-self.enqueued_size_bytes += alloc_size
+with self._enqueued_size.get_lock():
+self._enqueued_size.value += alloc_size
 allocated = False
 try:
 context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
@@ -155,13 +156,15 @@ class MemoryCache:
 await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
 allocated = True
-self.enqueued_size_bytes -= alloc_size
+with self._enqueued_size.get_lock():
+self._enqueued_size.value -= alloc_size
 yield
 except asyncio.TimeoutError:
 raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
 finally:
 if not allocated:
-self.enqueued_size_bytes -= alloc_size
+with self._enqueued_size.get_lock():
+self._enqueued_size.value -= alloc_size
 def _free(self, alloc_size: int, alloc_task: asyncio.Task):
 if alloc_task.exception() is not None:

Loading…
Cancel
Save