diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index d85c8ac..720d64d 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -96,9 +96,9 @@ def main():
     parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--alloc_timeout', type=float, default=1,
-                        help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
-                             'before rejecting the request')
+    parser.add_argument('--max_alloc_timeout', type=float, default=600,
+                        help="If the cache is full, the server will wait for memory to be freed up to this many seconds"
+                             " before rejecting the request")
     parser.add_argument('--revision', type=str, default=None,
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 8b788b0..3a9b63e 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -16,7 +16,7 @@ from transformers import PretrainedConfig
 from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
-from petals.utils.misc import is_dummy
+from petals.utils.misc import get_size_in_bytes, is_dummy
 
 logger = get_logger(__name__)
 
@@ -63,7 +63,7 @@ class TransformerBackend(ModuleBackend):
         )
 
         self.dtype = backend_dtype
-        self.dtype_bytes = torch.finfo(self.dtype).bits // 8
+        self.dtype_bytes = get_size_in_bytes(self.dtype)
         self.shard_num_heads = []
         for shard in self.module.module_shards:
             for submodule in shard.modules():
@@ -83,7 +83,7 @@ class TransformerBackend(ModuleBackend):
 
         self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
         for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
-            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
+            self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)
 
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""
diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py
index effce82..ac0995d 100644
--- a/src/petals/server/block_utils.py
+++ b/src/petals/server/block_utils.py
@@ -5,6 +5,7 @@ from accelerate import init_empty_weights
 from transformers import PretrainedConfig
 
 from petals.utils.convert_block import QuantType
+from petals.utils.misc import get_size_in_bytes
 
 
 def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
@@ -37,7 +38,7 @@ def get_block_size(
     if location == "memory":
         if quant_type == QuantType.NONE:
             dtype = resolve_block_dtype(config, dtype)
-            bytes_per_value = torch.finfo(dtype).bits // 8
+            bytes_per_value = get_size_in_bytes(dtype)
         elif quant_type == QuantType.INT8:
             bytes_per_value = 1
         elif quant_type == QuantType.NF4:
@@ -46,6 +47,6 @@ def get_block_size(
             raise ValueError(f"Unsupported quant_type={quant_type}")
     elif location == "disk":
         dtype = resolve_block_dtype(config, "auto")
-        bytes_per_value = torch.finfo(dtype).bits // 8
+        bytes_per_value = get_size_in_bytes(dtype)
 
     return round(n_params * bytes_per_value * (1 + eps))
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 0dd63bd..d8f0ec0 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             max_length = metadata.get("max_length")
             points = metadata.get("points", 0)
             session_id = metadata.get("session_id")
+            alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
             args_structure = metadata.get("args_structure")
             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
@@ -166,7 +167,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             batch_size = request.tensors[0].size[0] if request.tensors else 1
 
-            async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
+            async with self._allocate_cache(
+                requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
+            ) as cache_handles:
                 background_tasks = set()
                 async for output_tensors, can_push in iterate_rpc_inference(
                     requested_uids=requested_uids,
@@ -528,14 +531,19 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     @contextlib.asynccontextmanager
     async def _allocate_cache(
-        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+        self,
+        backends: Sequence[TransformerBackend],
+        *,
+        batch_size: int,
+        max_length: int,
+        timeout: Optional[float],
     ) -> Sequence[Sequence[Handle]]:
         """
         Allocate memory cache for all transformer blocks, return cache handle
         :returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
         """
         descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
-        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
+        async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
             yield nested_pack(handles, descriptors)
 
     def _log_request(
diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py
index 9e79f17..6fb3895 100644
--- a/src/petals/server/memory_cache.py
+++ b/src/petals/server/memory_cache.py
@@ -12,12 +12,13 @@ import os
 import time
 from typing import AsyncContextManager, Dict, Optional, Sequence
 
-import hivemind
+import async_timeout
 import torch
-from hivemind.utils import TensorDescriptor, get_logger
+from hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger
 
 from petals.data_structures import Handle
 from petals.utils.asyncio import shield_and_wait
+from petals.utils.misc import get_size_in_bytes
 
 logger = get_logger(__name__)
 
@@ -25,11 +26,12 @@ logger = get_logger(__name__)
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
 
-    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[float] = None):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
-        self.alloc_timeout = alloc_timeout
+        self.max_alloc_timeout = max_alloc_timeout
         self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
+        self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
         self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
         self.runtime_pid = os.getpid()
@@ -46,6 +48,14 @@ class MemoryCache:
     def current_size_bytes(self, value: int):
         self._current_size.value = value
 
+    @property
+    def enqueued_size_bytes(self) -> int:
+        return self._enqueued_size.value
+
+    @enqueued_size_bytes.setter
+    def enqueued_size_bytes(self, value: int):
+        self._enqueued_size.value = value
+
     @property
     def bytes_left(self) -> int:
         return self.max_size_bytes - self.current_size_bytes
@@ -59,11 +69,14 @@ class MemoryCache:
         self._handle_counter.value = value
 
     @contextlib.asynccontextmanager
-    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
+    async def allocate_cache(
+        self, *descriptors: TensorDescriptor, timeout: float
+    ) -> AsyncContextManager[Sequence[Handle]]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
 
         :param descriptors: one or more tensors tensor of this size, dtype, etc
+        :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit
 
         :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
           if not, it will count maximum tensor allocation across devices for the purposes of size limit
@@ -73,6 +86,8 @@ class MemoryCache:
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
         assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
+        if self.max_alloc_timeout is not None:
+            timeout = min(timeout, self.max_alloc_timeout)
         max_alloc_size = self.get_allocation_size(*descriptors)
 
         gib = 1024**3
@@ -83,10 +98,10 @@ class MemoryCache:
             f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
         )
 
-        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
         try:
             handles = await shield_and_wait(alloc_task)
-            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            logger.info(f"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
             self._free(max_alloc_size, alloc_task)
@@ -96,28 +111,59 @@ class MemoryCache:
         """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
         alloc_size_by_device = {}
         for descr in descriptors:
-            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
             alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
         return max(alloc_size_by_device.values())
 
-    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
+    async def _schedule_alloc(
+        self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]
+    ) -> Sequence[Handle]:
         """
         This method should be called inside asyncio.shield() because:
             - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
         """
+        try:
+            async with self._wait_for_free_memory(alloc_size, timeout):
+                with self._lock_metadata:
+                    handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
+                    self.current_size_bytes += alloc_size
+                    self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
+                    self._pipe_send.send((handles, descriptors))
+                    return handles
+        except TimeoutError:
+            raise AllocationFailed(f"Could not allocate {alloc_size} (timeout={timeout})")
 
+    @contextlib.asynccontextmanager
+    async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float]):
+        start_time = time.perf_counter()
         loop = asyncio.get_event_loop()
-        async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
-            if self.current_size_bytes + alloc_size > self.max_size_bytes:
-                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
-            with self._lock_metadata:
-                handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
-                self.current_size_bytes += alloc_size
-                self.handle_counter += len(handles)  # note: this will eventually overflow and it is okay
-                self._pipe_send.send((handles, descriptors))
-                return handles
-
-    def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
+
+        self.enqueued_size_bytes += alloc_size
+        allocated = False
+        try:
+            context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
+            # contextlib.AsyncExitStack() is used as a null context here
+            async with context_manager:
+                if timeout == 0 and self.current_size_bytes + self.enqueued_size_bytes > self.max_size_bytes:
+                    raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
+                async with enter_asynchronously(self._lock_acquire_memory):
+                    if self.current_size_bytes + alloc_size > self.max_size_bytes:
+                        if timeout == 0:
+                            raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
+                        elapsed_time = time.perf_counter() - start_time
+                        remaining_timeout = max(0.0, timeout - elapsed_time) if timeout is not None else None
+                        await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
+
+            allocated = True
+            self.enqueued_size_bytes -= alloc_size
+            yield
+        except asyncio.TimeoutError:
+            raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
+        finally:
+            if not allocated:
+                self.enqueued_size_bytes -= alloc_size
+
+    def _free(self, alloc_size: int, alloc_task: asyncio.Task):
         if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()
@@ -133,9 +179,10 @@ class MemoryCache:
             raise AllocationFailed(
                 f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
             )
+        timeout = timeout if timeout != float("inf") else None
         deadline = None if timeout is None else time.perf_counter() + timeout
         while self.current_size_bytes + allocated_size > self.max_size_bytes:
-            remaining_time = deadline - time.perf_counter() if timeout is not None else None
+            remaining_time = None if timeout is None else deadline - time.perf_counter()
             if not self._memory_freed_event.wait(remaining_time):
                 raise AllocationFailed(
                     f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index ba0403c..40865aa 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -31,6 +31,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.dht import declare_active_modules, get_remote_module_infos
+from petals.utils.misc import get_size_in_bytes
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo
@@ -59,12 +60,12 @@ class Server:
         min_batch_size: int = 1,
         max_batch_size: Optional[int] = None,
         max_chunk_size_bytes: int = 256 * 1024 * 1024,
+        max_alloc_timeout: float = 600,
         attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -185,13 +186,14 @@ class Server:
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.max_chunk_size_bytes = max_chunk_size_bytes
+        self.max_alloc_timeout = max_alloc_timeout
 
         # For attention cache in GPU or RAM
         if attn_cache_tokens is None:
             attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
-        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+        self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
 
         # For disk cache
         self.cache_dir = cache_dir
@@ -217,8 +219,6 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
-        self.alloc_timeout = alloc_timeout
-
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput_info = get_server_throughput(
@@ -311,13 +311,13 @@ class Server:
             converted_model_name_or_path=self.converted_model_name_or_path,
             block_config=self.block_config,
             attn_cache_bytes=self.attn_cache_bytes,
-            alloc_timeout=self.alloc_timeout,
             server_info=self.server_info,
             block_indices=block_indices,
             num_handlers=self.num_handlers,
             min_batch_size=self.min_batch_size,
             max_batch_size=self.max_batch_size,
             max_chunk_size_bytes=self.max_chunk_size_bytes,
+            max_alloc_timeout=self.max_alloc_timeout,
             inference_max_length=self.inference_max_length,
             torch_dtype=self.torch_dtype,
             cache_dir=self.cache_dir,
@@ -413,12 +413,12 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
-        alloc_timeout: float,
         server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
         max_chunk_size_bytes: int,
+        max_alloc_timeout: float,
         torch_dtype: torch.dtype,
         cache_dir: str,
         max_disk_space: int,
@@ -434,7 +434,7 @@ class ModuleContainer(threading.Thread):
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
 
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(
@@ -663,7 +663,7 @@ class ModuleAnnouncerThread(threading.Thread):
         self.server_info = server_info
         self.memory_cache = memory_cache
 
-        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
         self.bytes_per_token //= block_config.num_key_value_groups
 
         self.update_period = update_period
diff --git a/src/petals/utils/misc.py b/src/petals/utils/misc.py
index afe9fc4..d0cfd7c 100644
--- a/src/petals/utils/misc.py
+++ b/src/petals/utils/misc.py
@@ -9,6 +9,16 @@ def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
 
 
+SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.qint8: 1, torch.qint32: 4}
+
+
+def get_size_in_bytes(dtype: torch.dtype) -> int:
+    if dtype in SPECIAL_DTYPE_SIZES:
+        return SPECIAL_DTYPE_SIZES[dtype]
+    get_info = torch.finfo if dtype.is_floating_point else torch.iinfo
+    return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8
+
+
 def docstring_from(source):
     def add_docstring(dest):
         dest.__doc__ = source.__doc__
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index c7e3d05..e4d29fc 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -20,6 +20,7 @@ from transformers.utils import get_file_from_repo
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.misc import get_size_in_bytes
 
 logger = get_logger(__name__)
 
@@ -285,5 +286,5 @@ def estimate_adapter_memory_per_block(
         block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
     )
     adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
-    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
     return adapter_parameters * bytes_per_parameter
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 0000000..6d40db1
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,184 @@
+import asyncio
+import multiprocessing as mp
+import random
+import time
+from typing import Optional
+
+import pytest
+import pytest_asyncio  # make sure the module exists; otherwise the test will be skipped
+import torch
+from hivemind import TensorDescriptor
+
+from petals.server.memory_cache import AllocationFailed, MemoryCache
+from petals.utils.misc import get_size_in_bytes
+
+
+def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
+    if dtype is None:
+        dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
+    elem_size_bytes = get_size_in_bytes(dtype)
+    descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
+    return descr
+
+
+@pytest.mark.asyncio
+async def test_cache_timeout():
+    cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5)
+    cache.runtime_pid += 1  # pretend we're another process
+    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0):
+        pass
+
+    async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999):
+        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
+            async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1):
+                t_start = time.perf_counter()
+                with pytest.raises(AllocationFailed):
+                    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1):
+                        pass
+                assert 0.1 < time.perf_counter() - t_start < 0.2, "wait time exceeds alloc timeout"
+                async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):
+                    pass
+
+                t_start = time.perf_counter()
+                with pytest.raises(AllocationFailed):
+                    async with cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0):  # exceeds max timeout
+                        pass
+                assert 0.5 < time.perf_counter() - t_start < 0.6, "wait time exceeds max alloc timeout"
+
+    # test memory allocation when another task frees the memory
+    async def _klog_the_cache():
+        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
+            pass
+
+    large_alloc_task = asyncio.create_task(_klog_the_cache())
+
+    t_start = time.perf_counter()
+    await asyncio.sleep(0.05)  # wait for large alloc to enqueue
+    async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):  # exceeds max timeout
+        pass  # this memory should allocate once the background task clears the queue
+    assert 0.2 < time.perf_counter() - t_start < 0.3, "memory should be allocated after background task clears"
+    with pytest.raises(AllocationFailed):
+        await large_alloc_task
+
+    # test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc
+    large_alloc_task = asyncio.create_task(_klog_the_cache())
+    t_start = time.perf_counter()
+    await asyncio.sleep(0.05)  # wait for large alloc to enqueue
+    with pytest.raises(AllocationFailed):
+        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
+            pass  # this memory should allocate once the background task clears the queue
+    assert time.perf_counter() - t_start < 0.1, "zero-timeout task should fail (or succeed) instantaneously"
+    with pytest.raises(AllocationFailed):
+        await large_alloc_task
+
+
+@pytest.mark.asyncio
+async def test_unlimited_timeout():
+    cache = MemoryCache(max_size_bytes=1024)
+    cache.runtime_pid += 1  # pretend we're another process
+    t_start = time.perf_counter()
+
+    async def _klog_the_cache():
+        async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
+            await asyncio.sleep(0.5)
+
+    alloc_task = asyncio.create_task(_klog_the_cache())
+    await asyncio.sleep(0.1)
+    async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float("inf")):
+        await alloc_task
+    assert 0.5 < time.perf_counter() - t_start < 0.6, "memory should be allocated after background task clears"
+
+
+@pytest.mark.asyncio
+async def test_cache_usage():
+    cache = MemoryCache(max_size_bytes=2048)
+    alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5))
+    pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
+    with pytest.raises(AssertionError):
+        async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1):
+            pass  # fails because cache must be allocated from another process
+
+    descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8))  # 768 bytes
+    descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64))  # 8 bytes
+    descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool))  # 33 bytes
+    descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64))  # 0 bytes
+    descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16))  # 1536 bytes
+    descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8))  # 1792 bytes
+
+    async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
+        loop = asyncio.get_event_loop()
+        async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
+            pipe_sender.send(handles)
+            await loop.run_in_executor(None, dealloc_event.wait)
+
+    async def _allocate_af():
+        alloc_event.wait()
+        allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
+        await allocate_a_task
+
+        allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f))  # klogs the cache
+        await allocate_f_task
+
+    alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True)
+    alloc_process1.start()
+
+    async def _allocate_bcde():
+        alloc_event.wait()
+        await asyncio.sleep(0.1)  # ensure that the other tensor is always allocated (and sent through pipe) first
+        allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
+        allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e))  # doesn't fit
+        await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
+
+    alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
+    alloc_process2.start()
+    assert cache.current_size_bytes == 0
+    alloc_event.set()
+    (handle_a,) = pipe_receiver.recv()
+
+    handle_b, handle_c, handle_d = pipe_receiver.recv()
+
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tensor_a.dtype == torch.uint8
+        tensor_a[2:5] = torch.tensor((42, 43, 44))
+
+    with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
+        assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
+        assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
+        tensor_a += 1
+        tensor_b[...] = -1.337
+    assert cache.current_size_bytes == 809  # this checks a,b,c,d are allocated but e still awaits memory
+
+    dealloc_bcd_event.set()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 768  # only tensor a should be allocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a, handle_b):
+            pass  # one of handles (b) is deallocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_d):
+            pass  # handle_d is deallocated correctly, even though it is never used
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tuple(tensor_a[2:5]) == (43, 44, 45)
+
+    dealloc_a_event.set()
+    (handle_e,) = pipe_receiver.recv()  # e can finally be allocated
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 1536  # tensor e should finally be able to allocate
+
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a):
+            pass  # tensor a is no longer allocated
+    with cache.use_cache(handle_e) as (tensor_e,):
+        assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)
+
+    dealloc_e_event.set()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 1792  # only tensor f is still allocated
+    dealloc_f_event.set()
+
+    alloc_process1.join()
+    alloc_process2.join()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 0
+    assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
+    assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"
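
A minimal usage sketch of the allocation semantics introduced by this patch (illustrative only, not part of the diff; it reuses the MemoryCache constructor and allocate_cache(timeout=...) API shown above and the same single-process trick as tests/test_cache.py, with example sizes chosen arbitrarily): timeout=0 fails immediately when the cache is full, a positive timeout waits for handles to be freed, and any requested timeout is clamped server-side to max_alloc_timeout.

import asyncio

import torch
from hivemind import TensorDescriptor

from petals.server.memory_cache import AllocationFailed, MemoryCache


async def demo():
    # max_alloc_timeout caps how long any single allocation may wait for free memory
    cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=600)
    cache.runtime_pid += 1  # pretend to be a ConnectionHandler process, as the tests do

    descr = TensorDescriptor.from_tensor(torch.empty((256,), dtype=torch.uint8))
    try:
        # timeout=0: fail instantly if the cache is full instead of queueing
        async with cache.allocate_cache(descr, timeout=0) as (handle,):
            print("allocated cache handle", handle)
    except AllocationFailed:
        print("cache is full and waiting was not allowed")


asyncio.run(demo())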