Make attention cache wait until memory is freed (#53)

Previously, attempting to allocate memory in a MemoryCache that did not have enough free space would immediately raise AllocationFailed.

This PR changes the behavior to the following (a caller-side sketch follows the list):
- by default, wait until memory is freed by other tenants (in FIFO order)
- if the memory cannot be allocated within the timeout, raise AllocationFailed
- if the requested size is too big to fit even in an empty cache, raise AllocationFailed immediately
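In caller terms, the change looks roughly like this (a minimal sketch, not code from this PR; `handle_inference_request`, `cache`, and `descr` are illustrative names, and the `...` bodies stand in for real handler logic):

```python
# Illustrative sketch only: how a ConnectionHandler-side coroutine experiences
# the new allocation semantics. `cache` is a MemoryCache, `descr` is a
# hivemind TensorDescriptor describing the attention buffers to allocate.
async def handle_inference_request(cache, descr):
    try:
        # Instead of failing immediately when the cache is full, allocate_cache
        # now waits (in a background executor, queued behind earlier requests)
        # until other tenants free enough memory.
        async with cache.allocate_cache(descr) as handle:
            ...  # run inference steps that use the cached attention tensors via `handle`
    except AllocationFailed:
        # Raised right away if the request is larger than the entire cache, or
        # after the timeout if _wait_until_available was given one
        # (with the default timeout=None it waits indefinitely).
        ...  # report the failure to the client
```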

- [x] passes existing tests
- [x] passes manual load tests

P.S. In case anyone wondered: using mp.Condition would not make the code simpler; its lock behavior is slightly different from what we need here.
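One plausible reading of that note, for anyone curious (my interpretation, not the author's example): `Condition.wait()` must be called while holding the condition's own lock and releases it while blocked, so it would not keep waiters queued the way holding `_lock_acquire_memory` does:

```python
import multiprocessing as mp

# Hypothetical alternative, NOT what the PR does: waiting on a Condition.
cond = mp.Condition()  # bundles its own lock

def wait_for_memory(cache, nbytes):
    with cond:  # the waiter must hold the condition's lock to call wait()
        while cache.current_size_bytes + nbytes > cache.max_size_bytes:
            # wait() releases the lock while blocked, so a later arrival can
            # acquire it and allocate first; the Event plus _lock_acquire_memory
            # combination in this PR keeps waiters queued in FIFO order instead.
            cond.wait()
```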

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
justheuristic authored 2 years ago, committed by GitHub
commit f3984b192a · parent 8a0c056929

@@ -4,10 +4,12 @@ A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and u
 For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
 """
+import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import time
 from typing import AsyncContextManager, Dict, Optional, Union
 
 import hivemind
@@ -27,7 +29,7 @@ class MemoryCache:
     def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
         self.device = device
-        self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
+        self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
@@ -36,6 +38,8 @@ class MemoryCache:
         self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
         self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._lock_acquire_memory = mp.Lock()
+        self._memory_freed_event = mp.Event()
 
     @property
     def current_size_bytes(self) -> int:
@@ -67,27 +71,39 @@
         assert descr.device is None and descr
         allocated_handle = None
         allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        loop = asyncio.get_event_loop()
         try:
-            async with hivemind.utils.enter_asynchronously(self.lock_metadata):
-                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    raise AllocationFailed(
-                        f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
-                        f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
-                    )
-                allocated_handle = int(self.handle_counter)
-                self.current_size_bytes += allocated_size_bytes
-                self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                self._pending_messages.value += 1
-                self._pipe_send.send((allocated_handle, descr))
+            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
+                await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                    allocated_handle = int(self.handle_counter)
+                    self.current_size_bytes += allocated_size_bytes
+                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                    self._pending_messages.value += 1
+                    self._pipe_send.send((allocated_handle, descr))
 
             yield allocated_handle
         finally:
             if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self.lock_metadata):
+                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                     self._pending_messages.value += 1
                     self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                     self.current_size_bytes -= allocated_size_bytes
+                    self._memory_freed_event.set()
+
+    def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
+        # note: this function should only be called inside _lock_acquire_memory!
+        if allocated_size_bytes > self.max_size_bytes:
+            raise AllocationFailed(
+                f"Could not allocate {allocated_size_bytes} bytes, max cache size = {self.max_size_bytes} bytes"
+            )
+        deadline = None if timeout is None else time.perf_counter() + timeout
+        while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
+            remaining_time = deadline - time.perf_counter() if timeout is not None else None
+            if not self._memory_freed_event.wait(remaining_time):
+                raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
+            self._memory_freed_event.clear()
 
     @contextlib.contextmanager
     def use_cache(self, handle: Handle) -> torch.Tensor:
@@ -100,7 +116,7 @@ class MemoryCache:
         assert os.getpid() == self.runtime_pid
         # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here
-        with self.lock_metadata:
+        with self._lock_metadata:
             if self._allocated_tensors is None:
                 self._allocated_tensors = {}
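For completeness, a hedged test-style sketch of the two failure modes of `_wait_until_available` (not a test from this PR; it assumes `pytest` is available, that the import path below is where `MemoryCache` and `AllocationFailed` live, and that `current_size_bytes` simply reads `_current_size.value`):

```python
import pytest  # assumed to be available in the test environment

from src.server.memory_cache import AllocationFailed, MemoryCache  # assumed module path


def test_wait_until_available_failure_modes():
    cache = MemoryCache(device="cpu", max_size_bytes=1024)

    # A request larger than the whole cache can never succeed, so it fails immediately.
    with pytest.raises(AllocationFailed):
        cache._wait_until_available(4096, timeout=0.1)

    # Pretend the cache is already full; with no other tenant freeing memory,
    # the wait on _memory_freed_event expires and AllocationFailed is raised.
    cache._current_size.value = 1024
    with pytest.raises(AllocationFailed):
        cache._wait_until_available(512, timeout=0.1)
```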
