|
|
|
@ -31,7 +31,7 @@ class MemoryCache:
|
|
|
|
|
self.max_alloc_timeout = max_alloc_timeout
|
|
|
|
|
self._lock_metadata = mp.Lock()
|
|
|
|
|
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
|
|
|
self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
|
|
|
self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=True)
|
|
|
|
|
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
|
|
|
|
|
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
|
|
|
|
|
self.runtime_pid = os.getpid()
|
|
|
|
@ -138,7 +138,8 @@ class MemoryCache:
|
|
|
|
|
start_time = time.perf_counter()
|
|
|
|
|
loop = asyncio.get_event_loop()
|
|
|
|
|
|
|
|
|
|
self.enqueued_size_bytes += alloc_size
|
|
|
|
|
with self._enqueued_size.get_lock():
|
|
|
|
|
self._enqueued_size.value += alloc_size
|
|
|
|
|
allocated = False
|
|
|
|
|
try:
|
|
|
|
|
context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
|
|
|
|
@ -155,13 +156,15 @@ class MemoryCache:
|
|
|
|
|
await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
|
|
|
|
|
|
|
|
|
|
allocated = True
|
|
|
|
|
self.enqueued_size_bytes -= alloc_size
|
|
|
|
|
with self._enqueued_size.get_lock():
|
|
|
|
|
self._enqueued_size.value -= alloc_size
|
|
|
|
|
yield
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
|
|
raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
|
|
|
|
|
finally:
|
|
|
|
|
if not allocated:
|
|
|
|
|
self.enqueued_size_bytes -= alloc_size
|
|
|
|
|
with self._enqueued_size.get_lock():
|
|
|
|
|
self._enqueued_size.value -= alloc_size
|
|
|
|
|
|
|
|
|
|
def _free(self, alloc_size: int, alloc_task: asyncio.Task):
|
|
|
|
|
if alloc_task.exception() is not None:
|
|
|
|
|