the (still) reasonable version

Your Name 2023-07-21 08:00:34 +03:00
parent b6b3ae964f
commit cc67c332a6
8 changed files with 167 additions and 24 deletions

View File

@@ -16,7 +16,7 @@ from transformers import PretrainedConfig
 from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
-from petals.utils.misc import is_dummy
+from petals.utils.misc import is_dummy, get_size_in_bytes

 logger = get_logger(__name__)

@@ -74,7 +74,7 @@ class TransformerBackend(ModuleBackend):
         self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
         for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
-            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
+            self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)

     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""

View File

@@ -5,6 +5,7 @@ from accelerate import init_empty_weights
 from transformers import PretrainedConfig

 from petals.utils.convert_block import QuantType
+from petals.utils.misc import get_size_in_bytes


 def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:

@@ -36,7 +37,7 @@ def get_block_size(
     if location == "memory":
         if quant_type == QuantType.NONE:
             dtype = resolve_block_dtype(config, dtype)
-            bytes_per_value = torch.finfo(dtype).bits // 8
+            bytes_per_value = get_size_in_bytes(dtype)
         elif quant_type == QuantType.INT8:
             bytes_per_value = 1
         elif quant_type == QuantType.NF4:

@@ -45,6 +46,6 @@ def get_block_size(
             raise ValueError(f"Unsupported quant_type={quant_type}")
     elif location == "disk":
         dtype = resolve_block_dtype(config, "auto")
-        bytes_per_value = torch.finfo(dtype).bits // 8
+        bytes_per_value = get_size_in_bytes(dtype)

     return round(n_params * bytes_per_value * (1 + eps))
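
For a sense of scale, the size estimate above is just the parameter count times the per-value size, padded by a small safety margin eps. A worked example with hypothetical numbers (not taken from any real model config):

# Hypothetical transformer block: 500M parameters stored as bfloat16 (2 bytes per value), eps = 0.01.
n_params, bytes_per_value, eps = 500_000_000, 2, 0.01
block_size = round(n_params * bytes_per_value * (1 + eps))
print(block_size)            # 1010000000 bytes
print(block_size / 1024**3)  # ~0.94 GiB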

View File

@@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             session_id = metadata.get("session_id")
+            alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
             assert isinstance(

@@ -167,7 +168,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             batch_size = request.tensors[0].size[0] if request.tensors else 1
             prefix_length = 0

-            async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
+            async with self._allocate_cache(requested_backends, batch_size, max_length, alloc_timeout) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 first_request = request
                 background_tasks = set()

@@ -567,7 +568,7 @@ class TransformerConnectionHandler(ConnectionHandler):
     @contextlib.asynccontextmanager
     async def _allocate_cache(
-        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int, timeout: Optional[float],
     ) -> Sequence[Sequence[Handle]]:
         """
         Allocate memory cache for all transformer blocks, return cache handle

View File

@@ -17,6 +17,7 @@ import torch
 from hivemind.utils import TensorDescriptor, get_logger

 from petals.utils.asyncio import shield_and_wait
+from petals.utils.misc import get_size_in_bytes

 logger = get_logger(__name__)

@@ -26,9 +27,8 @@ Handle = int
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""

-    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int]):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
-        self.alloc_timeout = alloc_timeout
         self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
@@ -60,11 +60,12 @@ class MemoryCache:
         self._handle_counter.value = value

     @contextlib.asynccontextmanager
-    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
+    async def allocate_cache(self, *descriptors: TensorDescriptor, timeout: Optional[float] = None) -> AsyncContextManager[Sequence[Handle]]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.

         :param descriptors: one or more tensors tensor of this size, dtype, etc
+        :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit

         :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
           if not, it will count maximum tensor allocation across devices for the purposes of size limit

@@ -76,7 +77,7 @@ class MemoryCache:
         assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
         max_alloc_size = self.get_allocation_size(*descriptors)

         gib = 1024**3
         cur_size, max_size = self.current_size_bytes, self.max_size_bytes
         friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
         logger.info(

@@ -84,24 +85,26 @@ class MemoryCache:
             f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
         )

-        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
         try:
             handles = await shield_and_wait(alloc_task)
-            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            logger.info(f"rpc_inference.alloc-done(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
+            logger.info(f"rpc_inference.dealloc-began(size={max_alloc_size / gib:.2f} GiB)")
             await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
+            logger.info(f"rpc_inference.dealloc-done(size={max_alloc_size / gib:.2f} GiB)")

     @staticmethod
     def get_allocation_size(*descriptors: TensorDescriptor) -> int:
         """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
         alloc_size_by_device = {}
         for descr in descriptors:
-            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
             alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
         return max(alloc_size_by_device.values())

-    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
+    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]) -> Sequence[Handle]:
         """
         This method should be called inside asyncio.shield() because:
         - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
@@ -110,7 +113,7 @@ class MemoryCache:
         loop = asyncio.get_event_loop()
         async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
             if self.current_size_bytes + alloc_size > self.max_size_bytes:
-                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
+                await loop.run_in_executor(None, self._wait_until_available, alloc_size, timeout)
             async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                 handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
                 self.current_size_bytes += alloc_size

@@ -124,7 +127,6 @@ class MemoryCache:
         - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
         - _schedule_free() must finish freeing memory even in case of cancellation
         """
        if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()
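
Together with the handler change earlier in this commit, the allocation timeout now travels with each request instead of being fixed when the cache is constructed. A rough sketch of the new calling convention; the function below and its local names are illustrative, not code from this commit:

async def run_inference_with_cache(memory_cache, backends, batch_size, max_length, alloc_timeout):
    # One descriptor per attention-cache tensor, collected across all requested blocks.
    descriptors = []
    for backend in backends:
        descriptors.extend(backend.get_inference_cache_descriptors(batch_size, max_length))
    # The client-supplied timeout is forwarded per request; timeout=None (the default) means no time limit.
    async with memory_cache.allocate_cache(*descriptors, timeout=alloc_timeout) as handles:
        ...  # run inference steps here; one handle per descriptor stays reserved while this block is open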

View File

@@ -31,6 +31,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.misc import get_size_in_bytes
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo

@@ -63,7 +64,6 @@ class Server:
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -189,7 +189,7 @@ class Server:
         attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
-        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+        self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)

         # For disk cache
         self.cache_dir = cache_dir

@@ -213,8 +213,6 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")

-        self.alloc_timeout = alloc_timeout
-
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput_info = get_server_throughput(

@@ -306,7 +304,6 @@ class Server:
             converted_model_name_or_path=self.converted_model_name_or_path,
             block_config=self.block_config,
             attn_cache_bytes=self.attn_cache_bytes,
-            alloc_timeout=self.alloc_timeout,
             server_info=self.server_info,
             block_indices=block_indices,
             num_handlers=self.num_handlers,
@@ -407,7 +404,6 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
-        alloc_timeout: float,
         server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,

@@ -427,7 +423,7 @@ class ModuleContainer(threading.Thread):
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes)
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(

View File

@@ -5,3 +5,13 @@ DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter par

 def is_dummy(tensor: torch.Tensor):
     return tensor.numel() == 0
+
+
+SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.int8: 1, torch.qint32: 4}
+
+
+def get_size_in_bytes(dtype: torch.dtype) -> int:
+    if dtype in SPECIAL_DTYPE_SIZES:
+        return SPECIAL_DTYPE_SIZES[dtype]
+    get_info = torch.finfo if dtype.is_floating_point else torch.iinfo
+    return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8
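
This helper generalizes the old torch.finfo(dtype).bits // 8 expression replaced throughout this commit: torch.finfo rejects integer and boolean dtypes, so the helper falls back to torch.iinfo and to a small lookup table for dtypes neither accepts. A quick sanity check of its behaviour (not part of the commit; the sizes are standard PyTorch dtype widths):

import torch

from petals.utils.misc import get_size_in_bytes  # the helper added above

assert get_size_in_bytes(torch.bfloat16) == 2  # torch.finfo path: 16 bits // 8, same as the old expression
assert get_size_in_bytes(torch.int64) == 8     # torch.iinfo path: torch.finfo would raise TypeError here
assert get_size_in_bytes(torch.bool) == 1      # lookup table: neither finfo nor iinfo accepts torch.bool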

View File

@@ -19,6 +19,7 @@ from transformers.utils import get_file_from_repo
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.misc import get_size_in_bytes

 logger = get_logger(__name__)

@@ -284,5 +285,5 @@ def estimate_adapter_memory_per_block(
         block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
     )
     adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
-    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
     return adapter_parameters * bytes_per_parameter

tests/test_cache.py (new file, 132 lines added)

View File

@@ -0,0 +1,132 @@
import random
from typing import Optional
import pytest
import torch
from hivemind import TensorDescriptor
from petals.server.memory_cache import MemoryCache, AllocationFailed
import asyncio
from petals.utils.misc import get_size_in_bytes
import multiprocessing as mp
import pytest_asyncio # make sure the module exists; otherwise the test will be skipped

def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
    if dtype is None:
        dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
    elem_size_bytes = get_size_in_bytes(dtype)
    descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
    return descr

@pytest.mark.asyncio
async def test_cache_usage():
    cache = MemoryCache(max_size_bytes=2048)
    alloc_event, dealloc_e_event, dealloc_bcd_event, dealloc_a_event = mp.Event(), mp.Event(), mp.Event(), mp.Event()
    pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
    with pytest.raises(AssertionError):
        async with cache.allocate_cache(_make_tensor_descriptor(123)):
            pass  # fails because cache must be allocated from another process

    descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8))  # 768 bytes
    descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64))  # 8 bytes
    descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool))  # 33 bytes
    descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64))  # 0 bytes
    descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16))  # 1536 bytes
    descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8))  # 1792 bytes
    descr_g = TensorDescriptor.from_tensor(torch.empty((1793,), dtype=torch.uint8))  # 1793 bytes

    async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
        loop = asyncio.get_event_loop()
        async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
            pipe_sender.send(handles)
            await loop.run_in_executor(None, dealloc_event.wait)

    async def _allocate_af():
        alloc_event.wait()
        print("BEGAN AF")
        try:
            async with cache.allocate_cache(descr_g):
                allocate_f_task = asyncio.create_task(_allocate_and_wait(mp.Event(), descr_f))  # klogs the cache
                print("CANCELLED")
                raise asyncio.CancelledError()
        except asyncio.CancelledError:
            pass
        allocate_f_task.cancel()  # unklog the cache

        allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
        await allocate_a_task

    alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True)
    alloc_process1.start()

    async def _allocate_bcde():
        await asyncio.sleep(0.2)  # ensure that the other tensor is always allocated (and sent through pipe) first
        print("BEGAN BCDE")
        allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
        allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e))  # doesn't fit
        await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)

    alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
    alloc_process2.start()

    assert cache.current_size_bytes == 0
    alloc_event.set()
    handle_a, = pipe_receiver.recv()
    handle_b, handle_c, handle_d = pipe_receiver.recv()

    with cache.use_cache(handle_a) as (tensor_a,):
        assert tensor_a.dtype == torch.uint8
        tensor_a[2:5] = torch.tensor((42, 43, 44))

    with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
        assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
        assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
        tensor_a += 1
        tensor_b[...] = -1.337

    assert cache.current_size_bytes == 809  # this checks a, b, c, d are allocated but e still awaits memory
    dealloc_bcd_event.set()
    await asyncio.sleep(0.1)
    assert cache.current_size_bytes == 768  # only tensor a is allocated

    with pytest.raises(KeyError):
        with cache.use_cache(handle_a, handle_b):
            pass  # one of the handles (b) is deallocated
    with pytest.raises(KeyError):
        with cache.use_cache(handle_d):
            pass  # handle_d is deallocated, even though it is never used
    with cache.use_cache(handle_a) as (tensor_a,):
        assert tuple(tensor_a[2:5]) == (43, 44, 45)

    dealloc_a_event.set()
    handle_e, = pipe_receiver.recv()  # e can finally be allocated
    assert cache.current_size_bytes == 1536  # tensor e should finally be able to allocate

    with pytest.raises(KeyError):
        with cache.use_cache(handle_a):
            pass  # tensor a is no longer allocated
    with cache.use_cache(handle_e) as (tensor_e,):
        assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)

    dealloc_e_event.set()
    alloc_process1.join(1)
    alloc_process2.join(1)
    assert cache.current_size_bytes == 0
    assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
    assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"

# cache.runtime_pid += 1 # pretend we're another process
# async with cache.allocate_cache(_make_tensor_descriptor(768)) as a:
# pass
#
#
# async with cache.allocate_cache(_make_tensor_descriptor(768)):
# async with cache.allocate_cache(_make_tensor_descriptor(1024)):
# async with cache.allocate_cache(_make_tensor_descriptor(512), _make_tensor_descriptor(64)):
# async with cache.allocate_cache(_make_tensor_descriptor(1536)):
# with pytest.raises(TimeoutError):
# async with cache.allocate_cache(_make_tensor_descriptor(256), ):
# pass
# async with cache.allocate_cache(_make_tensor_descriptor(192)):
# pass