|
|
|
@ -90,7 +90,7 @@ class MemoryCache:
|
|
|
|
|
logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
|
|
|
|
|
yield handles
|
|
|
|
|
finally:
|
|
|
|
|
await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
|
|
|
|
|
self._free(max_alloc_size, alloc_task)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def get_allocation_size(*descriptors: TensorDescriptor) -> int:
|
|
|
|
@ -111,25 +111,19 @@ class MemoryCache:
|
|
|
|
|
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
|
|
|
|
|
if self.current_size_bytes + alloc_size > self.max_size_bytes:
|
|
|
|
|
await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
|
|
|
|
|
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
|
|
|
|
|
with self._lock_metadata:
|
|
|
|
|
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
|
|
|
|
|
self.current_size_bytes += alloc_size
|
|
|
|
|
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
|
|
|
|
|
self._pipe_send.send((handles, descriptors))
|
|
|
|
|
return handles
|
|
|
|
|
|
|
|
|
|
async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
|
|
|
|
|
"""
|
|
|
|
|
This method should be called inside asyncio.shield() because:
|
|
|
|
|
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
|
|
|
|
|
- _schedule_free() must finish freeing memory even in case of cancellation
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
|
|
|
|
|
if alloc_task.exception() is not None:
|
|
|
|
|
return
|
|
|
|
|
handles = alloc_task.result()
|
|
|
|
|
|
|
|
|
|
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
|
|
|
|
|
with self._lock_metadata:
|
|
|
|
|
self._pipe_send.send((handles, None)) # signal runtime to free these handles
|
|
|
|
|
self.current_size_bytes -= alloc_size
|
|
|
|
|
self._memory_freed_event.set()
|
|
|
|
@ -160,7 +154,6 @@ class MemoryCache:
|
|
|
|
|
assert os.getpid() == self.runtime_pid
|
|
|
|
|
# note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
|
|
|
|
|
|
|
|
|
|
with self._lock_metadata:
|
|
|
|
|
# read creation/deletion requests from connection handlers
|
|
|
|
|
while self._pipe_recv.poll():
|
|
|
|
|
recv_handles, recv_data = self._pipe_recv.recv()
|
|
|
|
|