|
|
@ -90,7 +90,7 @@ class MemoryCache:
|
|
|
|
logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
|
|
|
|
logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
|
|
|
|
yield handles
|
|
|
|
yield handles
|
|
|
|
finally:
|
|
|
|
finally:
|
|
|
|
await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
|
|
|
|
self._free(max_alloc_size, alloc_task)
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def get_allocation_size(*descriptors: TensorDescriptor) -> int:
|
|
|
|
def get_allocation_size(*descriptors: TensorDescriptor) -> int:
|
|
|
@ -111,25 +111,19 @@ class MemoryCache:
|
|
|
|
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
|
|
|
|
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
|
|
|
|
if self.current_size_bytes + alloc_size > self.max_size_bytes:
|
|
|
|
if self.current_size_bytes + alloc_size > self.max_size_bytes:
|
|
|
|
await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
|
|
|
|
await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
|
|
|
|
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
|
|
|
|
with self._lock_metadata:
|
|
|
|
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
|
|
|
|
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
|
|
|
|
self.current_size_bytes += alloc_size
|
|
|
|
self.current_size_bytes += alloc_size
|
|
|
|
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
|
|
|
|
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
|
|
|
|
self._pipe_send.send((handles, descriptors))
|
|
|
|
self._pipe_send.send((handles, descriptors))
|
|
|
|
return handles
|
|
|
|
return handles
|
|
|
|
|
|
|
|
|
|
|
|
async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
|
|
|
|
def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
|
|
|
|
"""
|
|
|
|
|
|
|
|
This method should be called inside asyncio.shield() because:
|
|
|
|
|
|
|
|
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
|
|
|
|
|
|
|
|
- _schedule_free() must finish freeing memory even in case of cancellation
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if alloc_task.exception() is not None:
|
|
|
|
if alloc_task.exception() is not None:
|
|
|
|
return
|
|
|
|
return
|
|
|
|
handles = alloc_task.result()
|
|
|
|
handles = alloc_task.result()
|
|
|
|
|
|
|
|
|
|
|
|
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
|
|
|
|
with self._lock_metadata:
|
|
|
|
self._pipe_send.send((handles, None)) # signal runtime to free these handles
|
|
|
|
self._pipe_send.send((handles, None)) # signal runtime to free these handles
|
|
|
|
self.current_size_bytes -= alloc_size
|
|
|
|
self.current_size_bytes -= alloc_size
|
|
|
|
self._memory_freed_event.set()
|
|
|
|
self._memory_freed_event.set()
|
|
|
@ -160,22 +154,21 @@ class MemoryCache:
|
|
|
|
assert os.getpid() == self.runtime_pid
|
|
|
|
assert os.getpid() == self.runtime_pid
|
|
|
|
# note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
|
|
|
|
# note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
|
|
|
|
|
|
|
|
|
|
|
|
with self._lock_metadata:
|
|
|
|
# read creation/deletion requests from connection handlers
|
|
|
|
# read creation/deletion requests from connection handlers
|
|
|
|
while self._pipe_recv.poll():
|
|
|
|
while self._pipe_recv.poll():
|
|
|
|
recv_handles, recv_data = self._pipe_recv.recv()
|
|
|
|
recv_handles, recv_data = self._pipe_recv.recv()
|
|
|
|
if recv_data is not None: # create new tensors
|
|
|
|
if recv_data is not None: # create new tensors
|
|
|
|
assert len(recv_handles) == len(recv_data)
|
|
|
|
assert len(recv_handles) == len(recv_data)
|
|
|
|
for handle, descr in zip(recv_handles, recv_data):
|
|
|
|
for handle, descr in zip(recv_handles, recv_data):
|
|
|
|
self._allocated_tensors[handle] = descr.make_zeros()
|
|
|
|
self._allocated_tensors[handle] = descr.make_zeros()
|
|
|
|
assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
|
|
|
|
assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
|
|
|
|
else: # delete tensors by handle
|
|
|
|
else: # delete tensors by handle
|
|
|
|
for handle in recv_handles:
|
|
|
|
for handle in recv_handles:
|
|
|
|
if handle not in self._allocated_tensors:
|
|
|
|
if handle not in self._allocated_tensors:
|
|
|
|
logger.warning(
|
|
|
|
logger.warning(
|
|
|
|
f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
|
|
|
|
f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
|
|
|
|
)
|
|
|
|
)
|
|
|
|
self._allocated_tensors.pop(handle, None)
|
|
|
|
self._allocated_tensors.pop(handle, None)
|
|
|
|
|
|
|
|
yield tuple(self._allocated_tensors[handle] for handle in handles)
|
|
|
|
yield tuple(self._allocated_tensors[handle] for handle in handles)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|