diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py index 6fb3895..fa4db21 100644 --- a/src/petals/server/memory_cache.py +++ b/src/petals/server/memory_cache.py @@ -31,7 +31,7 @@ class MemoryCache: self.max_alloc_timeout = max_alloc_timeout self._lock_metadata = mp.Lock() self._current_size = mp.Value(ctypes.c_int64, 0, lock=False) - self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False) + self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True) self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False) self._allocated_tensors: Dict[Handle, torch.Tensor] = {} self.runtime_pid = os.getpid() @@ -138,7 +138,8 @@ class MemoryCache: start_time = time.perf_counter() loop = asyncio.get_event_loop() - self.enqueued_size_bytes += alloc_size + with self._enqueued_size.get_lock(): + self._enqueued_size.value += alloc_size allocated = False try: context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack() @@ -155,13 +156,15 @@ class MemoryCache: await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout) allocated = True - self.enqueued_size_bytes -= alloc_size + with self._enqueued_size.get_lock(): + self._enqueued_size.value -= alloc_size yield except asyncio.TimeoutError: raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds") finally: if not allocated: - self.enqueued_size_bytes -= alloc_size + with self._enqueued_size.get_lock(): + self._enqueued_size.value -= alloc_size def _free(self, alloc_size: int, alloc_task: asyncio.Task): if alloc_task.exception() is not None: