"""
A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.

For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.

"""
import asyncio
import contextlib
import ctypes
import multiprocessing as mp
import os
import time
from typing import AsyncContextManager, Dict, Optional, Sequence

import hivemind
import torch
from hivemind.utils import TensorDescriptor, get_logger

from petals.utils.asyncio import shield_and_wait
from petals.utils.misc import get_size_in_bytes

logger = get_logger(__name__)

Handle = int


class MemoryCache:
    """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""

    def __init__(self, max_size_bytes: Optional[int]):
        """
        :param max_size_bytes: total size limit for the cache in bytes; None means 2**64 - 1 (reported as "inf")
        """
        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
        # guards the shared counters below and the order of messages written to the pipe
        self._lock_metadata = mp.Lock()
        # shared counters are created with lock=False: every update happens under _lock_metadata
        self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
        self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
        # handle -> tensor; only populated and read inside use_cache (i.e. in the runtime process)
        self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
        # the process that constructs the cache is treated as the runtime process
        self.runtime_pid = os.getpid()

        self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
        # held for the whole duration of an allocation attempt, including the wait for free space
        self._lock_acquire_memory = mp.Lock()
        # set by _schedule_free after memory is released; wakes up _wait_until_available
        self._memory_freed_event = mp.Event()

    @property
    def current_size_bytes(self) -> int:
        """The number of bytes currently accounted as allocated (shared between processes)."""
        return self._current_size.value

    @current_size_bytes.setter
    def current_size_bytes(self, value: int):
        self._current_size.value = value

    @property
    def bytes_left(self) -> int:
        """The number of bytes that can still be allocated before reaching max_size_bytes."""
        return self.max_size_bytes - self.current_size_bytes

    @property
    def handle_counter(self) -> int:
        """The next unused handle value (shared between processes)."""
        return self._handle_counter.value

    @handle_counter.setter
    def handle_counter(self, value: int):
        self._handle_counter.value = value

    @contextlib.asynccontextmanager
    async def allocate_cache(
        self, *descriptors: TensorDescriptor, timeout: Optional[float] = None
    ) -> AsyncContextManager[Sequence[Handle]]:
        """
        Create handles that are associated with cache buffers, one handle per descriptor.
        The buffers are scheduled for deletion when the context is exited.
        If the cache is full and no memory is freed in time, raises AllocationFailed.

        :param descriptors: one or more descriptors of the tensors to allocate; each must specify a device
        :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit

        :note: if descriptors reside on different devices, it is expected that they are approximately balanced across
          devices; if not, it will count maximum tensor allocation across devices for the purposes of size limit

        :note: This function should be called by connection handlers, it can be called concurrently from multiple
          processes. Furthermore, it can be called concurrently with at most one use_cache call in runtime.
        """
        assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
        assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
        max_alloc_size = self.get_allocation_size(*descriptors)

        gib = 1024**3  # bytes per GiB: all log messages below report sizes in GiB
        cur_size, max_size = self.current_size_bytes, self.max_size_bytes
        friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
        # a zero-sized cache must not crash the log line; the allocation path reports the failure itself
        used_percent = cur_size / max_size * 100 if max_size > 0 else 0.0
        logger.info(
            f"rpc_inference.wait_for_alloc(size={max_alloc_size / gib:.2f} GiB), "
            f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({used_percent:.1f}%)"
        )

        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
        try:
            handles = await shield_and_wait(alloc_task)
            logger.info(f"rpc_inference.alloc-done(size={max_alloc_size / gib:.2f} GiB)")
            yield handles
        finally:
            # always runs, including when the allocation failed: _schedule_free checks the task outcome itself
            logger.info(f"rpc_inference.dealloc-began(size={max_alloc_size / gib:.2f} GiB)")
            await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
            logger.info(f"rpc_inference.dealloc-done(size={max_alloc_size / gib:.2f} GiB)")

    @staticmethod
    def get_allocation_size(*descriptors: TensorDescriptor) -> int:
        """
        Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum.
        If no descriptors are given, return 0.
        """
        alloc_size_by_device = {}
        for descr in descriptors:
            tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
            alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
        return max(alloc_size_by_device.values(), default=0)

    async def _schedule_alloc(
        self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]
    ) -> Sequence[Handle]:
        """
        Reserve alloc_size bytes, assign one new handle per descriptor and ask runtime to create the tensors.

        This method should be called inside asyncio.shield() because:
            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation

        :returns: a tuple of newly assigned handles, in the same order as descriptors
        :raises AllocationFailed: propagated from _wait_until_available if the memory could not be reserved
        """
        loop = asyncio.get_event_loop()
        async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
            if self.current_size_bytes + alloc_size > self.max_size_bytes:
                # the wait is blocking, so it runs in the default executor instead of the event loop
                await loop.run_in_executor(None, self._wait_until_available, alloc_size, timeout)
            async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
                self.current_size_bytes += alloc_size
                self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
                self._pipe_send.send((handles, descriptors))
                return handles

    async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
        """
        Release the memory reserved by a finished alloc_task; do nothing if that task ended with an exception.

        This method should be called inside asyncio.shield() because:
            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
            - _schedule_free() must finish freeing memory even in case of cancellation
        """
        if alloc_task.exception() is not None:
            return  # nothing was reserved, so there is nothing to free
        handles = alloc_task.result()

        async with hivemind.utils.enter_asynchronously(self._lock_metadata):
            self._pipe_send.send((handles, None))  # signal runtime to free these handles
            self.current_size_bytes -= alloc_size
        self._memory_freed_event.set()

    def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None):
        """
        Block until allocated_size bytes fit into the cache.

        :param allocated_size: the number of bytes that the caller wants to reserve
        :param timeout: maximum time to wait in seconds; None or float("inf") means no time limit
        :raises AllocationFailed: if allocated_size exceeds max_size_bytes or the timeout expires
        """
        # note: this function should only be called inside _lock_acquire_memory!
        if allocated_size > self.max_size_bytes:
            raise AllocationFailed(
                f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
            )
        if timeout == float("inf"):
            timeout = None  # an infinite timeout means "no limit"; do not pass inf to Event.wait
        deadline = None if timeout is None else time.perf_counter() + timeout
        while self.current_size_bytes + allocated_size > self.max_size_bytes:
            remaining_time = None if timeout is None else deadline - time.perf_counter()
            if not self._memory_freed_event.wait(remaining_time):
                raise AllocationFailed(
                    f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"
                )
            # reset the event, then let the loop condition re-check whether enough memory was freed
            self._memory_freed_event.clear()

    @contextlib.contextmanager
    def use_cache(self, *handles: Handle) -> Sequence[torch.Tensor]:
        """
        Return one or more tensors previously allocated with allocate_cache,

        :note: This method is called by ModuleBackend in runtime: a single process with NO process parallelism.
        However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
        """
        assert os.getpid() == self.runtime_pid
        # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here

        with self._lock_metadata:
            # read creation/deletion requests from connection handlers
            while self._pipe_recv.poll():
                recv_handles, recv_data = self._pipe_recv.recv()
                if recv_data is not None:  # create new tensors
                    assert len(recv_handles) == len(recv_data)
                    for handle, descr in zip(recv_handles, recv_data):
                        self._allocated_tensors[handle] = descr.make_zeros()
                        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
                else:  # delete tensors by handle
                    for handle in recv_handles:
                        if handle not in self._allocated_tensors:
                            logger.warning(
                                f"Sanity check failed: asked to delete handle {handle}, but there is no such handle"
                            )
                        self._allocated_tensors.pop(handle, None)
        yield tuple(self._allocated_tensors[handle] for handle in handles)


class AllocationFailed(Exception):
    """Raised when the requested cache memory could not be reserved."""

    pass