@@ -4,10 +4,12 @@ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and u
 For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
 """
+import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import time
 from typing import AsyncContextManager, Dict, Optional, Union
 import hivemind
@@ -27,7 +29,7 @@ class MemoryCache:
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.device = device
-        self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+        self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
@@ -36,6 +38,8 @@ class MemoryCache:
         self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
         self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
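+        # _lock_acquire_memory serializes allocations that must wait for free cache space;
+        # _memory_freed_event is set whenever a handle is freed so that waiters can re-check capacity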
+        self._lock_acquire_memory = mp.Lock()
+        self._memory_freed_event = mp.Event()
     @property
     def current_size_bytes(self) -> int:
@@ -67,27 +71,39 @@ class MemoryCache:
         assert descr.device is None and descr
         allocated_handle = None
         allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        loop = asyncio.get_event_loop()
         try:
-            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
-                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    raise AllocationFailed(
-                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
-                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
-                    )
-                allocated_handle = int(self.handle_counter)
-                self.current_size_bytes += allocated_size_bytes
-                self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pending_messages.value += 1
-                self._pipe_send.send((allocated_handle, descr))
+                await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
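+                # once there is room, update allocation bookkeeping under the metadata lock and
+                # ask the runtime process (via the pipe) to actually allocate the tensor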
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                    allocated_handle = int(self.handle_counter)
+                    self.current_size_bytes += allocated_size_bytes
+                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, descr))
             yield allocated_handle
         finally:
             if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                     self._pending_messages.value += 1
                     self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                     self.current_size_bytes -= allocated_size_bytes
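+                    # wake up any allocation that is blocked in _wait_until_available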
+                    self._memory_freed_event.set()
+    def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
+        # note: this function should only be called inside _lock_acquire_memory!
+        if allocated_size_bytes > self.max_size_bytes:
+            raise AllocationFailed(
+                f"Could not allocate {allocated_size_bytes} bytes, max cache size = {self.max_size_bytes} bytes"
+            )
+        deadline = None if timeout is None else time.perf_counter() + timeout
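+        # block until other handles free enough memory, or fail if the optional deadline expires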
+        while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+            remaining_time = deadline - time.perf_counter() if timeout is not None else None
+            if not self._memory_freed_event.wait(remaining_time):
+                raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
+            self._memory_freed_event.clear()
     @contextlib.contextmanager
     def use_cache(self, handle: Handle) -> torch.Tensor:
@@ -100,7 +116,7 @@ class MemoryCache:
         assert os.getpid() == self.runtime_pid
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
-        with self.lock_metadata:
+        with self._lock_metadata:
             if self._allocated_tensors is None:
                 self._allocated_tensors = {}