Rewrite MemoryCache alloc_timeout logic (#434)

-    rpc_inference: server will now accept allocation timeout from user, defaults to no timeout
-    bugfix: inference timeout is now measured from the moment the request is received
    -    previously, you would have to wait for your timeout plus the time it takes to sort through the queue (other users' timeout)
    -    now, you get AllocationFailed if you had to wait for over (timeout) seconds - regardless of other users
-    a request for inference with no timeout will now fail instantly if there is not enough memory available
-    dtype number of bytes is now correctly determined for int, bool & other types


---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
pull/482/head^2
justheuristic 8 months ago committed by GitHub
parent 90840dfea2
commit c08d09c4d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -96,9 +96,9 @@ def main():
parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--alloc_timeout', type=float, default=1,
help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
'before rejecting the request')
parser.add_argument('--max_alloc_timeout', type=float, default=600,
help="If the cache is full, the server will wait for memory to be freed up to this many seconds"
" before rejecting the request")
parser.add_argument('--revision', type=str, default=None,
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")

@ -16,7 +16,7 @@ from transformers import PretrainedConfig
from petals.data_structures import InferenceMetadata
from petals.server.memory_cache import MemoryCache
from petals.server.task_pool import PrioritizedTaskPool
from petals.utils.misc import is_dummy
from petals.utils.misc import get_size_in_bytes, is_dummy
logger = get_logger(__name__)
@ -63,7 +63,7 @@ class TransformerBackend(ModuleBackend):
)
self.dtype = backend_dtype
self.dtype_bytes = torch.finfo(self.dtype).bits // 8
self.dtype_bytes = get_size_in_bytes(self.dtype)
self.shard_num_heads = []
for shard in self.module.module_shards:
for submodule in shard.modules():
@ -83,7 +83,7 @@ class TransformerBackend(ModuleBackend):
self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)
def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
"""Create tensor descriptors for attention cache tensors used during inference_step"""

@ -5,6 +5,7 @@ from accelerate import init_empty_weights
from transformers import PretrainedConfig
from petals.utils.convert_block import QuantType
from petals.utils.misc import get_size_in_bytes
def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
@ -37,7 +38,7 @@ def get_block_size(
if location == "memory":
if quant_type == QuantType.NONE:
dtype = resolve_block_dtype(config, dtype)
bytes_per_value = torch.finfo(dtype).bits // 8
bytes_per_value = get_size_in_bytes(dtype)
elif quant_type == QuantType.INT8:
bytes_per_value = 1
elif quant_type == QuantType.NF4:
@ -46,6 +47,6 @@ def get_block_size(
raise ValueError(f"Unsupported quant_type={quant_type}")
elif location == "disk":
dtype = resolve_block_dtype(config, "auto")
bytes_per_value = torch.finfo(dtype).bits // 8
bytes_per_value = get_size_in_bytes(dtype)
return round(n_params * bytes_per_value * (1 + eps))

@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
max_length = metadata.get("max_length")
points = metadata.get("points", 0)
session_id = metadata.get("session_id")
alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
args_structure = metadata.get("args_structure")
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
@ -166,7 +167,9 @@ class TransformerConnectionHandler(ConnectionHandler):
batch_size = request.tensors[0].size[0] if request.tensors else 1
async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
async with self._allocate_cache(
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
) as cache_handles:
background_tasks = set()
async for output_tensors, can_push in iterate_rpc_inference(
requested_uids=requested_uids,
@ -528,14 +531,19 @@ class TransformerConnectionHandler(ConnectionHandler):
@contextlib.asynccontextmanager
async def _allocate_cache(
self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
self,
backends: Sequence[TransformerBackend],
*,
batch_size: int,
max_length: int,
timeout: Optional[float],
) -> Sequence[Sequence[Handle]]:
"""
Allocate memory cache for all transformer blocks, return cache handle
:returns: a list of {len(backends)} elements, where i-th element is a tuple of cache handles for i-th backend
"""
descriptors = [backend.get_inference_cache_descriptors(batch_size, max_length) for backend in backends]
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors)) as handles:
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
yield nested_pack(handles, descriptors)
def _log_request(

@ -12,12 +12,13 @@ import os
import time
from typing import AsyncContextManager, Dict, Optional, Sequence
import hivemind
import async_timeout
import torch
from hivemind.utils import TensorDescriptor, get_logger
from hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger
from petals.data_structures import Handle
from petals.utils.asyncio import shield_and_wait
from petals.utils.misc import get_size_in_bytes
logger = get_logger(__name__)
@ -25,11 +26,12 @@ logger = get_logger(__name__)
class MemoryCache:
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
def __init__(self, max_size_bytes: Optional[int], max_alloc_timeout: Optional[float] = None):
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
self.alloc_timeout = alloc_timeout
self.max_alloc_timeout = max_alloc_timeout
self._lock_metadata = mp.Lock()
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
self._enqueued_size = mp.Value(ctypes.c_int64, 0, lock=False)
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
self._allocated_tensors: Dict[Handle, torch.Tensor] = {}
self.runtime_pid = os.getpid()
@ -46,6 +48,14 @@ class MemoryCache:
def current_size_bytes(self, value: int):
self._current_size.value = value
@property
def enqueued_size_bytes(self) -> int:
return self._enqueued_size.value
@enqueued_size_bytes.setter
def enqueued_size_bytes(self, value: int):
self._enqueued_size.value = value
@property
def bytes_left(self) -> int:
return self.max_size_bytes - self.current_size_bytes
@ -59,11 +69,14 @@ class MemoryCache:
self._handle_counter.value = value
@contextlib.asynccontextmanager
async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
async def allocate_cache(
self, *descriptors: TensorDescriptor, timeout: float
) -> AsyncContextManager[Sequence[Handle]]:
"""
Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.
:param descriptors: one or more tensors tensor of this size, dtype, etc
:param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit
:note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
if not, it will count maximum tensor allocation across devices for the purposes of size limit
@ -73,6 +86,8 @@ class MemoryCache:
"""
assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
if self.max_alloc_timeout is not None:
timeout = min(timeout, self.max_alloc_timeout)
max_alloc_size = self.get_allocation_size(*descriptors)
gib = 1024**3
@ -83,10 +98,10 @@ class MemoryCache:
f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
)
alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
try:
handles = await shield_and_wait(alloc_task)
logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
logger.info(f"rpc_inference.alloc_done(size={max_alloc_size / gib:.2f} GiB)")
yield handles
finally:
self._free(max_alloc_size, alloc_task)
@ -96,28 +111,59 @@ class MemoryCache:
"""Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
alloc_size_by_device = {}
for descr in descriptors:
tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
return max(alloc_size_by_device.values())
async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
async def _schedule_alloc(
self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]
) -> Sequence[Handle]:
"""
This method should be called inside asyncio.shield() because:
- hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
"""
try:
async with self._wait_for_free_memory(alloc_size, timeout):
with self._lock_metadata:
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
self.current_size_bytes += alloc_size
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
self._pipe_send.send((handles, descriptors))
return handles
except TimeoutError:
raise AllocationFailed(f"Could not allocate {alloc_size} (timeout={timeout})")
@contextlib.asynccontextmanager
async def _wait_for_free_memory(self, alloc_size: int, timeout: Optional[float]):
start_time = time.perf_counter()
loop = asyncio.get_event_loop()
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
if self.current_size_bytes + alloc_size > self.max_size_bytes:
await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
with self._lock_metadata:
handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
self.current_size_bytes += alloc_size
self.handle_counter += len(handles) # note: this will eventually overflow and it is okay
self._pipe_send.send((handles, descriptors))
return handles
def _free(self, alloc_size: int, alloc_task: asyncio.Task) -> None:
self.enqueued_size_bytes += alloc_size
allocated = False
try:
context_manager = async_timeout.timeout(timeout) if timeout != 0 else contextlib.AsyncExitStack()
# contextlib.AsyncExitStack() is used as a null context here
async with context_manager:
if timeout == 0 and self.current_size_bytes + self.enqueued_size_bytes > self.max_size_bytes:
raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
async with enter_asynchronously(self._lock_acquire_memory):
if self.current_size_bytes + alloc_size > self.max_size_bytes:
if timeout == 0:
raise AllocationFailed(f"Could not allocate {alloc_size} bytes immediately: out of memory")
elapsed_time = time.perf_counter() - start_time
remaining_timeout = max(0.0, timeout - elapsed_time) if timeout is not None else None
await loop.run_in_executor(None, self._wait_until_available, alloc_size, remaining_timeout)
allocated = True
self.enqueued_size_bytes -= alloc_size
yield
except asyncio.TimeoutError:
raise AllocationFailed(f"Could not allocate {alloc_size} within {timeout} seconds")
finally:
if not allocated:
self.enqueued_size_bytes -= alloc_size
def _free(self, alloc_size: int, alloc_task: asyncio.Task):
if alloc_task.exception() is not None:
return
handles = alloc_task.result()
@ -133,9 +179,10 @@ class MemoryCache:
raise AllocationFailed(
f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
)
timeout = timeout if timeout != float("inf") else None
deadline = None if timeout is None else time.perf_counter() + timeout
while self.current_size_bytes + allocated_size > self.max_size_bytes:
remaining_time = deadline - time.perf_counter() if timeout is not None else None
remaining_time = None if timeout is None else deadline - time.perf_counter()
if not self._memory_freed_event.wait(remaining_time):
raise AllocationFailed(
f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"

@ -31,6 +31,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.convert_block import QuantType, check_device_balance, convert_block
from petals.utils.dht import declare_active_modules, get_remote_module_infos
from petals.utils.misc import get_size_in_bytes
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to
from petals.utils.version import get_compatible_model_repo
@ -59,12 +60,12 @@ class Server:
min_batch_size: int = 1,
max_batch_size: Optional[int] = None,
max_chunk_size_bytes: int = 256 * 1024 * 1024,
max_alloc_timeout: float = 600,
attn_cache_tokens: Optional[int] = None,
torch_dtype: str = "auto",
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
max_disk_space: Optional[int] = None,
alloc_timeout: float = 5,
device: Optional[Union[str, torch.device]] = None,
compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None,
@ -185,13 +186,14 @@ class Server:
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.inference_max_length = inference_max_length
self.max_chunk_size_bytes = max_chunk_size_bytes
self.max_alloc_timeout = max_alloc_timeout
# For attention cache in GPU or RAM
if attn_cache_tokens is None:
attn_cache_tokens = 32768 if is_multiquery_attn else 8192
cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
cache_values_per_block //= self.block_config.num_key_value_groups
self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
# For disk cache
self.cache_dir = cache_dir
@ -217,8 +219,6 @@ class Server:
self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
self.alloc_timeout = alloc_timeout
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
throughput_info = get_server_throughput(
@ -311,13 +311,13 @@ class Server:
converted_model_name_or_path=self.converted_model_name_or_path,
block_config=self.block_config,
attn_cache_bytes=self.attn_cache_bytes,
alloc_timeout=self.alloc_timeout,
server_info=self.server_info,
block_indices=block_indices,
num_handlers=self.num_handlers,
min_batch_size=self.min_batch_size,
max_batch_size=self.max_batch_size,
max_chunk_size_bytes=self.max_chunk_size_bytes,
max_alloc_timeout=self.max_alloc_timeout,
inference_max_length=self.inference_max_length,
torch_dtype=self.torch_dtype,
cache_dir=self.cache_dir,
@ -413,12 +413,12 @@ class ModuleContainer(threading.Thread):
converted_model_name_or_path: str,
block_config: PretrainedConfig,
attn_cache_bytes: int,
alloc_timeout: float,
server_info: ServerInfo,
block_indices: List[int],
min_batch_size: int,
max_batch_size: int,
max_chunk_size_bytes: int,
max_alloc_timeout: float,
torch_dtype: torch.dtype,
cache_dir: str,
max_disk_space: int,
@ -434,7 +434,7 @@ class ModuleContainer(threading.Thread):
**kwargs,
) -> ModuleContainer:
module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
server_info.state = ServerState.JOINING
dht_announcer = ModuleAnnouncerThread(
@ -663,7 +663,7 @@ class ModuleAnnouncerThread(threading.Thread):
self.server_info = server_info
self.memory_cache = memory_cache
self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
self.bytes_per_token //= block_config.num_key_value_groups
self.update_period = update_period

@ -9,6 +9,16 @@ def is_dummy(tensor: torch.Tensor) -> bool:
return tensor.numel() == 0
SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.qint8: 1, torch.qint32: 4}
def get_size_in_bytes(dtype: torch.dtype) -> int:
if dtype in SPECIAL_DTYPE_SIZES:
return SPECIAL_DTYPE_SIZES[dtype]
get_info = torch.finfo if dtype.is_floating_point else torch.iinfo
return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8
def docstring_from(source):
def add_docstring(dest):
dest.__doc__ = source.__doc__

@ -20,6 +20,7 @@ from transformers.utils import get_file_from_repo
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import get_size_in_bytes
logger = get_logger(__name__)
@ -285,5 +286,5 @@ def estimate_adapter_memory_per_block(
block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
)
adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
return adapter_parameters * bytes_per_parameter

@ -0,0 +1,184 @@
import asyncio
import multiprocessing as mp
import random
import time
from typing import Optional
import pytest
import pytest_asyncio # make sure the module exists; otherwise the test will be skipped
import torch
from hivemind import TensorDescriptor
from petals.server.memory_cache import AllocationFailed, MemoryCache
from petals.utils.misc import get_size_in_bytes
def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
if dtype is None:
dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
elem_size_bytes = get_size_in_bytes(dtype)
descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
return descr
@pytest.mark.asyncio
async def test_cache_timeout():
cache = MemoryCache(max_size_bytes=1024, max_alloc_timeout=0.5)
cache.runtime_pid += 1 # pretend we're another process
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0):
pass
async with cache.allocate_cache(_make_tensor_descriptor(100), timeout=999):
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
async with cache.allocate_cache(_make_tensor_descriptor(128), _make_tensor_descriptor(32), timeout=1):
t_start = time.perf_counter()
with pytest.raises(AllocationFailed):
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=0.1):
pass
assert 0.1 < time.perf_counter() - t_start < 0.2, "wait time exceeds alloc timeout"
async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")):
pass
t_start = time.perf_counter()
with pytest.raises(AllocationFailed):
async with cache.allocate_cache(_make_tensor_descriptor(384), timeout=1.0): # exceeds max timeout
pass
assert 0.5 < time.perf_counter() - t_start < 0.6, "wait time exceeds max alloc timeout"
# test memory allocation when another task frees the memory
async def _klog_the_cache():
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
pass
large_alloc_task = asyncio.create_task(_klog_the_cache())
t_start = time.perf_counter()
await asyncio.sleep(0.05) # wait for large alloc to enqueue
async with cache.allocate_cache(_make_tensor_descriptor(128), timeout=float("inf")): # exceeds max timeout
pass # this memory should allocate once the background task clears the queue
assert 0.2 < time.perf_counter() - t_start < 0.3, "memory should be allocated after background task clears"
with pytest.raises(AllocationFailed):
await large_alloc_task
# test that zero-timeout allocation fails instantaneously even if someone else is awaiting alloc
large_alloc_task = asyncio.create_task(_klog_the_cache())
t_start = time.perf_counter()
await asyncio.sleep(0.05) # wait for large alloc to enqueue
with pytest.raises(AllocationFailed):
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0):
pass # this memory should allocate once the background task clears the queue
assert time.perf_counter() - t_start < 0.1, "zero-timeout task should fail (or succeed) instantaneously"
with pytest.raises(AllocationFailed):
await large_alloc_task
@pytest.mark.asyncio
async def test_unlimited_timeout():
cache = MemoryCache(max_size_bytes=1024)
cache.runtime_pid += 1 # pretend we're another process
t_start = time.perf_counter()
async def _klog_the_cache():
async with cache.allocate_cache(_make_tensor_descriptor(512), timeout=0.2):
await asyncio.sleep(0.5)
alloc_task = asyncio.create_task(_klog_the_cache())
await asyncio.sleep(0.1)
async with cache.allocate_cache(_make_tensor_descriptor(768), timeout=float("inf")):
await alloc_task
assert 0.5 < time.perf_counter() - t_start < 0.6, "memory should be allocated after background task clears"
@pytest.mark.asyncio
async def test_cache_usage():
cache = MemoryCache(max_size_bytes=2048)
alloc_event, dealloc_a_event, dealloc_bcd_event, dealloc_e_event, dealloc_f_event = (mp.Event() for _ in range(5))
pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
with pytest.raises(AssertionError):
async with cache.allocate_cache(_make_tensor_descriptor(123), timeout=1):
pass # fails because cache must be allocated from another process
descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8)) # 768 bytes
descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64)) # 8 bytes
descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool)) # 33 bytes
descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64)) # 0 bytes
descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16)) # 1536 bytes
descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8)) # 1792 bytes
async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
loop = asyncio.get_event_loop()
async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
pipe_sender.send(handles)
await loop.run_in_executor(None, dealloc_event.wait)
async def _allocate_af():
alloc_event.wait()
allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
await allocate_a_task
allocate_f_task = asyncio.create_task(_allocate_and_wait(dealloc_f_event, descr_f)) # klogs the cache
await allocate_f_task
alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True)
alloc_process1.start()
async def _allocate_bcde():
alloc_event.wait()
await asyncio.sleep(0.1) # ensure that the other tensor is always allocated (and sent through pipe) first
allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e)) # doesn't fit
await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
alloc_process2.start()
assert cache.current_size_bytes == 0
alloc_event.set()
(handle_a,) = pipe_receiver.recv()
handle_b, handle_c, handle_d = pipe_receiver.recv()
with cache.use_cache(handle_a) as (tensor_a,):
assert tensor_a.dtype == torch.uint8
tensor_a[2:5] = torch.tensor((42, 43, 44))
with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
tensor_a += 1
tensor_b[...] = -1.337
assert cache.current_size_bytes == 809 # this checks a,b,c,d are allocated but b still awaits memory
dealloc_bcd_event.set()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 768 # only tensor a should be allocated
with pytest.raises(KeyError):
with cache.use_cache(handle_a, handle_b):
pass # one of handles (c) is deallocated
with pytest.raises(KeyError):
with cache.use_cache(handle_d):
pass # handle_d is deallocated correctly, even though it is never used
with cache.use_cache(handle_a) as (tensor_a,):
assert tuple(tensor_a[2:5]) == (43, 44, 45)
dealloc_a_event.set()
(handle_e,) = pipe_receiver.recv() # e can finally be allocated
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 1536 # tensor e should finally be able to allocate
with pytest.raises(KeyError):
with cache.use_cache(handle_a):
pass # tensor a is no longer allocated
with cache.use_cache(handle_e) as (tensor_e,):
assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)
dealloc_e_event.set()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 1792 # only tensor f is still allocated
dealloc_f_event.set()
alloc_process1.join()
alloc_process2.join()
await asyncio.sleep(0.1)
assert cache.current_size_bytes == 0
assert cache.current_size_bytes == 0
assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"
Loading…
Cancel
Save