memory cache for attention KVs

pull/9/head
justheuristic 2 years ago
parent e5e8c9ed12
commit e2e9d0e94c

@@ -1,43 +1,129 @@
"""
A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.
For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
TODO In future, one could modify cache to implement, among other things,
- in allocate_cache, if there is not enough memory, wait for memory to be freed by existing tasks up to a given timeout.
- allocate cache as one contiguous buffer to avoid fragmentation
- quantize cached values using bitsandbytes
- LRU offloading from gpu to ram
"""
import contextlib
import ctypes
import multiprocessing as mp
import os
from typing import Dict, Optional, Union

import hivemind
import torch
from hivemind.utils import TensorDescriptor, get_logger

logger = get_logger(__file__)

Handle = int


class MemoryCache:
    """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""

    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2 ** 64 - 1)
        self.device = device
        self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
        self._current_size = mp.Value(ctypes.c_uint64, 0, lock=False)
        self._handle_counter = mp.Value(ctypes.c_uint64, 0, lock=False)
        self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
        self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
        self.runtime_pid = os.getpid()

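        # connection handlers send (handle, descriptor) over this pipe to request an allocation and
        # (handle, None) to request deallocation; _pending_messages counts messages not yet drained by the runtime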
        self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
        self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)

    @property
    def current_size_bytes(self) -> int:
        return self._current_size.value

    @current_size_bytes.setter
    def current_size_bytes(self, value: int):
        self._current_size.value = value

    @property
    def handle_counter(self) -> int:
        return self._handle_counter.value

    @handle_counter.setter
    def handle_counter(self, value: int):
        self._handle_counter.value = value

    @contextlib.asynccontextmanager
    async def allocate_cache(self, descr: TensorDescriptor) -> Handle:
        """
        Create a handle that is associated with buffers on a unique device. If the cache is full, raises AllocationFailed.

        :param descr: allocate a tensor of this size, dtype, etc.

        :note: This function should be called by connection handlers; it can be called concurrently from multiple processes.
          Furthermore, it can be called concurrently with at most one use_cache call in runtime.
        """
        assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
        assert descr.device is None and descr
        allocated_handle = None
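        # allocation size in bytes = numel * (bits per element // 8); note that torch.finfo assumes descr.dtype is a floating-point type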
        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
        try:
            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
                    raise AllocationFailed(
                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
                    )
                allocated_handle = int(self.handle_counter)
                self.current_size_bytes += allocated_size_bytes
                self.handle_counter += 1  # note: this will eventually overflow and it is okay
                self._pending_messages.value += 1
                self._pipe_send.send((allocated_handle, descr))

            yield allocated_handle
        finally:
            if allocated_handle is not None:
                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
                    self._pending_messages.value += 1
                    self._pipe_send.send((allocated_handle, None))  # signal the runtime to free that handle
                    self.current_size_bytes -= allocated_size_bytes

    @contextlib.contextmanager
    def use_cache(self, handle: Handle) -> torch.Tensor:
        """
        Return a tensor that was previously allocated with allocate_cache.

        :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
          However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache.
        """
        assert os.getpid() == self.runtime_pid
        # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here

        with self.lock_metadata:
            if self._allocated_tensors is None:
                self._allocated_tensors = {}

            # read creation/deletion requests from connection handlers
            for i in range(int(self._pending_messages.value)):
                recv_handle, recv_data = self._pipe_recv.recv()
                self._pending_messages.value -= 1
                if isinstance(recv_data, TensorDescriptor):
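                    # a descriptor means "allocate": the tensor is materialized lazily here, in the runtime
                    # process, rather than in the connection handler that reserved the space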
                    self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
                elif recv_data is None:
                    if recv_handle not in self._allocated_tensors:
                        logger.warning(
                            f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
                        )
                    self._allocated_tensors.pop(recv_handle, None)
                else:
                    logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")

        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
        yield self._allocated_tensors[handle]


class AllocationFailed(Exception):
    pass
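

# Illustrative usage sketch: hypothetical caller functions showing the intended split between a
# ConnectionHandler process (which reserves space) and the runtime process (which uses the tensor).
# The function names and the tensor shape below are made-up examples, not part of the cache API.

async def _example_connection_handler(cache: MemoryCache):
    # runs in a ConnectionHandler process: reserve cache space and obtain a handle
    descr = TensorDescriptor(size=(2, 1, 2048, 8, 64), dtype=torch.float16)
    async with cache.allocate_cache(descr) as handle:
        ...  # pass `handle` to the runtime alongside the inference request


def _example_runtime_step(cache: MemoryCache, handle: Handle):
    # runs in the runtime process: the tensor is created on first use and reused afterwards
    with cache.use_cache(handle) as attention_kvs:
        attention_kvs.zero_()  # read or write past attention KVs in-place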
