mirror of https://github.com/bigscience-workshop/petals (synced 2024-10-31 09:20:41 +00:00)

Commit cc67c332a6 ("the (still) reasonable version"), parent commit b6b3ae964f
src/petals/server/backend.py

@@ -16,7 +16,7 @@ from transformers import PretrainedConfig
 from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
-from petals.utils.misc import is_dummy
+from petals.utils.misc import is_dummy, get_size_in_bytes

 logger = get_logger(__name__)

@@ -74,7 +74,7 @@ class TransformerBackend(ModuleBackend):

         self.cache_bytes_per_token: Dict[torch.device, int] = Counter()
         for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1):
-            self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8
+            self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype)

     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""
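Illustration (not part of the diff): the move away from torch.finfo(...).bits // 8 matters because torch.finfo only accepts floating-point dtypes, so sizing an int8 or bool cache tensor that way raises a TypeError. A minimal sketch of the failure mode and the replacement path:

import torch

print(torch.finfo(torch.float16).bits // 8)  # 2 bytes per value; works for float dtypes
try:
    torch.finfo(torch.int8)  # raises TypeError: finfo requires a floating-point dtype
except TypeError:
    print("torch.finfo cannot size integer dtypes")
print(torch.iinfo(torch.int8).bits // 8)  # 1 byte per value; get_size_in_bytes picks finfo/iinfo automatically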
src/petals/server/block_utils.py

@@ -5,6 +5,7 @@ from accelerate import init_empty_weights
 from transformers import PretrainedConfig

 from petals.utils.convert_block import QuantType
+from petals.utils.misc import get_size_in_bytes


 def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
@@ -36,7 +37,7 @@ def get_block_size(
     if location == "memory":
         if quant_type == QuantType.NONE:
             dtype = resolve_block_dtype(config, dtype)
-            bytes_per_value = torch.finfo(dtype).bits // 8
+            bytes_per_value = get_size_in_bytes(dtype)
         elif quant_type == QuantType.INT8:
             bytes_per_value = 1
         elif quant_type == QuantType.NF4:
@@ -45,6 +46,6 @@ def get_block_size(
             raise ValueError(f"Unsupported quant_type={quant_type}")
     elif location == "disk":
         dtype = resolve_block_dtype(config, "auto")
-        bytes_per_value = torch.finfo(dtype).bits // 8
+        bytes_per_value = get_size_in_bytes(dtype)

     return round(n_params * bytes_per_value * (1 + eps))
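For context, get_block_size estimates one block's footprint as n_params * bytes_per_value * (1 + eps). A back-of-the-envelope check with made-up numbers (none of these values come from the diff):

n_params = 100_000_000     # hypothetical parameter count of a single transformer block
bytes_per_value = 2        # e.g. bfloat16 with quant_type == QuantType.NONE
eps = 0.01                 # hypothetical safety margin
print(round(n_params * bytes_per_value * (1 + eps)))  # 202000000 bytes, i.e. about 0.19 GiB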
src/petals/server/handler.py

@@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             session_id = metadata.get("session_id")
+            alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
             if not requested_uids:
                 raise ValueError("User must specify at least one block for inference, but got none")
             assert isinstance(
@@ -167,7 +168,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             batch_size = request.tensors[0].size[0] if request.tensors else 1
             prefix_length = 0

-            async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles:
+            async with self._allocate_cache(requested_backends, batch_size, max_length, alloc_timeout) as cache_handles:
                 assert len(cache_handles) == len(requested_backends)
                 first_request = request
                 background_tasks = set()
@@ -567,7 +568,7 @@ class TransformerConnectionHandler(ConnectionHandler):

     @contextlib.asynccontextmanager
     async def _allocate_cache(
-        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
+        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int, timeout: Optional[float],
     ) -> Sequence[Sequence[Handle]]:
         """
         Allocate memory cache for all transformer blocks, return cache handle
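In short, the allocation timeout now travels with each inference request instead of being a server-wide constant. A sketch of the metadata fallback introduced above (the dictionaries are hypothetical):

metadata = {"points": 0, "session_id": "abc123"}  # hypothetical request metadata without alloc_timeout
alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
assert alloc_timeout == 0.0  # clients that do not send the field get the zero-timeout default

metadata_with_timeout = {"alloc_timeout": 10}
assert float(metadata_with_timeout.get("alloc_timeout", 0.0)) == 10.0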
src/petals/server/memory_cache.py

@@ -17,6 +17,7 @@ import torch
 from hivemind.utils import TensorDescriptor, get_logger

 from petals.utils.asyncio import shield_and_wait
+from petals.utils.misc import get_size_in_bytes

 logger = get_logger(__name__)

@@ -26,9 +27,8 @@ Handle = int
 class MemoryCache:
     """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""

-    def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float):
+    def __init__(self, max_size_bytes: Optional[int]):
         self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
-        self.alloc_timeout = alloc_timeout
         self._lock_metadata = mp.Lock()
         self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
         self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
@@ -60,11 +60,12 @@ class MemoryCache:
         self._handle_counter.value = value

     @contextlib.asynccontextmanager
-    async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]:
+    async def allocate_cache(self, *descriptors: TensorDescriptor, timeout: Optional[float] = None) -> AsyncContextManager[Sequence[Handle]]:
         """
         Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed.

         :param descriptors: one or more tensors tensor of this size, dtype, etc
+        :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit

         :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices;
             if not, it will count maximum tensor allocation across devices for the purposes of size limit
@@ -76,7 +77,7 @@ class MemoryCache:
         assert all(descr.device is not None for descr in descriptors), "please specify allocated devices"
         max_alloc_size = self.get_allocation_size(*descriptors)

-        gib = 1024**3
+        gib = 1
         cur_size, max_size = self.current_size_bytes, self.max_size_bytes
         friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf"
         logger.info(
@@ -84,24 +85,26 @@ class MemoryCache:
             f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)"
         )

-        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors))
+        alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout))
         try:
             handles = await shield_and_wait(alloc_task)
-            logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)")
+            logger.info(f"rpc_inference.alloc-done(size={max_alloc_size / gib:.2f} GiB)")
             yield handles
         finally:
+            logger.info(f"rpc_inference.dealloc-began(size={max_alloc_size / gib:.2f} GiB)")
             await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task))
+            logger.info(f"rpc_inference.dealloc-done(size={max_alloc_size / gib:.2f} GiB)")

     @staticmethod
     def get_allocation_size(*descriptors: TensorDescriptor) -> int:
         """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum"""
         alloc_size_by_device = {}
         for descr in descriptors:
-            tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+            tensor_size = descr.numel() * get_size_in_bytes(descr.dtype)
             alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size
         return max(alloc_size_by_device.values())

-    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]:
+    async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]) -> Sequence[Handle]:
         """
         This method should be called inside asyncio.shield() because:
         - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
@@ -110,7 +113,7 @@ class MemoryCache:
         loop = asyncio.get_event_loop()
         async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
             if self.current_size_bytes + alloc_size > self.max_size_bytes:
-                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
+                await loop.run_in_executor(None, self._wait_until_available, alloc_size, timeout)
             async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                 handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors)))
                 self.current_size_bytes += alloc_size
@@ -124,7 +127,6 @@ class MemoryCache:
         - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
         - _schedule_free() must finish freeing memory even in case of cancellation
         """
-
         if alloc_task.exception() is not None:
             return
         handles = alloc_task.result()
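A hypothetical caller of the new per-call timeout (sketch only; the descriptor shape, dtype, and the ten-second deadline are made up, and real servers obtain their descriptors from TransformerBackend.get_inference_cache_descriptors):

import torch
from hivemind.utils import TensorDescriptor

async def reserve_cache_with_deadline(cache, batch_size: int, max_length: int, hidden_size: int):
    descr = TensorDescriptor.from_tensor(
        torch.empty((batch_size, max_length, hidden_size), dtype=torch.bfloat16)
    )
    # Wait at most 10 seconds for enough free cache instead of relying on a server-wide setting.
    async with cache.allocate_cache(descr, timeout=10.0) as (handle,):
        ...  # the runtime process would later access the buffer via cache.use_cache(handle)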
src/petals/server/server.py

@@ -31,6 +31,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.misc import get_size_in_bytes
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo
@@ -63,7 +64,6 @@ class Server:
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -189,7 +189,7 @@ class Server:
         attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
-        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+        self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)

         # For disk cache
         self.cache_dir = cache_dir
@@ -213,8 +213,6 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")

-        self.alloc_timeout = alloc_timeout
-
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput_info = get_server_throughput(
@@ -306,7 +304,6 @@ class Server:
             converted_model_name_or_path=self.converted_model_name_or_path,
             block_config=self.block_config,
             attn_cache_bytes=self.attn_cache_bytes,
-            alloc_timeout=self.alloc_timeout,
             server_info=self.server_info,
             block_indices=block_indices,
             num_handlers=self.num_handlers,
@@ -407,7 +404,6 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
-        alloc_timeout: float,
         server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,
@@ -427,7 +423,7 @@ class ModuleContainer(threading.Thread):
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes)

         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(
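A worked example of the attention-cache sizing arithmetic in the Server hunk above (all numbers are hypothetical, not taken from the diff):

hidden_size = 4096
attn_cache_tokens = 8192                 # the non-multiquery default shown above
num_key_value_groups = 1                 # assume no grouped-query attention
cache_values_per_block = 2 * hidden_size * attn_cache_tokens  # keys and values
cache_values_per_block //= num_key_value_groups
bytes_per_value = 2                      # e.g. get_size_in_bytes(torch.bfloat16)
print(cache_values_per_block * bytes_per_value / 1024**3)     # 0.125 GiB per block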
src/petals/utils/misc.py

@@ -5,3 +5,13 @@ DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter par

 def is_dummy(tensor: torch.Tensor):
     return tensor.numel() == 0
+
+
+SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.int8: 1, torch.qint32: 4}
+
+
+def get_size_in_bytes(dtype: torch.dtype) -> int:
+    if dtype in SPECIAL_DTYPE_SIZES:
+        return SPECIAL_DTYPE_SIZES[dtype]
+    get_info = torch.finfo if dtype.is_floating_point else torch.iinfo
+    return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8
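A quick sanity check of get_size_in_bytes (illustration only, not part of the diff):

import torch
from petals.utils.misc import get_size_in_bytes

assert get_size_in_bytes(torch.float32) == 4    # finfo path: 32 bits // 8
assert get_size_in_bytes(torch.bfloat16) == 2   # finfo path: 16 bits // 8
assert get_size_in_bytes(torch.int64) == 8      # iinfo path: 64 bits // 8
assert get_size_in_bytes(torch.bool) == 1       # taken from SPECIAL_DTYPE_SIZES above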
src/petals/utils/peft.py

@@ -19,6 +19,7 @@ from transformers.utils import get_file_from_repo
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.misc import get_size_in_bytes

 logger = get_logger(__name__)

@@ -284,5 +285,5 @@ def estimate_adapter_memory_per_block(
         block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict
     )
     adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters
-    bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8
+    bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype))
     return adapter_parameters * bytes_per_parameter
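A back-of-the-envelope example of the adapter memory estimate in estimate_adapter_memory_per_block (numbers are made up):

adapter_parameters = 4_000_000           # hypothetical LoRA parameters added to one block
bytes_per_parameter = 2                  # e.g. get_size_in_bytes(torch.bfloat16)
print(adapter_parameters * bytes_per_parameter)  # 8000000 bytes, roughly 7.6 MiB per block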
tests/test_cache.py  (new file, 132 lines)

@@ -0,0 +1,132 @@
+import random
+from typing import Optional
+
+import pytest
+import torch
+from hivemind import TensorDescriptor
+
+from petals.server.memory_cache import MemoryCache, AllocationFailed
+import asyncio
+from petals.utils.misc import get_size_in_bytes
+import multiprocessing as mp
+import pytest_asyncio  # make sure the module exists; otherwise the test will be skipped
+
+
+def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None):
+    if dtype is None:
+        dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool))
+    elem_size_bytes = get_size_in_bytes(dtype)
+    descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype))
+    return descr
+
+
+@pytest.mark.asyncio
+async def test_cache_usage():
+    cache = MemoryCache(max_size_bytes=2048)
+    alloc_event, dealloc_e_event, dealloc_bcd_event, dealloc_a_event = mp.Event(), mp.Event(), mp.Event(), mp.Event()
+    pipe_receiver, pipe_sender = mp.Pipe(duplex=False)
+    with pytest.raises(AssertionError):
+        async with cache.allocate_cache(_make_tensor_descriptor(123)):
+            pass  # fails because cache must be allocated from another process
+
+    descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8))  # 768 bytes
+    descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64))  # 8 bytes
+    descr_c = TensorDescriptor.from_tensor(torch.empty((33,), dtype=torch.bool))  # 33 bytes
+    descr_d = TensorDescriptor.from_tensor(torch.empty((0,), dtype=torch.int64))  # 0 bytes
+    descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16))  # 1536 bytes
+    descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8))  # 1792 bytes
+    descr_g = TensorDescriptor.from_tensor(torch.empty((1793,), dtype=torch.uint8))  # 1793 bytes
+
+    async def _allocate_and_wait(dealloc_event, *descrs, timeout=None):
+        loop = asyncio.get_event_loop()
+        async with cache.allocate_cache(*descrs, timeout=timeout) as handles:
+            pipe_sender.send(handles)
+            await loop.run_in_executor(None, dealloc_event.wait)
+
+    async def _allocate_af():
+        alloc_event.wait()
+        print("BEGAN AF")
+        try:
+            async with cache.allocate_cache(descr_g):
+                allocate_f_task = asyncio.create_task(_allocate_and_wait(mp.Event(), descr_f))  # klogs the cache
+                print("CANCELLED")
+                raise asyncio.CancelledError()
+        except asyncio.CancelledError:
+            pass
+        allocate_f_task.cancel()  # unklog the cache
+
+        allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a))
+        await allocate_a_task
+
+    alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True)
+    alloc_process1.start()
+
+    async def _allocate_bcde():
+        await asyncio.sleep(0.2)  # ensure that the other tensor is always allocated (and sent through pipe) first
+        print("BEGAN BCDE")
+        allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d))
+        allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e))  # doesn't fit
+        await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED)
+
+    alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True)
+    alloc_process2.start()
+    assert cache.current_size_bytes == 0
+    alloc_event.set()
+    handle_a, = pipe_receiver.recv()
+
+    handle_b, handle_c, handle_d = pipe_receiver.recv()
+
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tensor_a.dtype == torch.uint8
+        tensor_a[2:5] = torch.tensor((42, 43, 44))
+
+    with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d):
+        assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0
+        assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0
+        tensor_a += 1
+        tensor_b[...] = -1.337
+    assert cache.current_size_bytes == 809  # this checks a, b, c, d are allocated but e still awaits memory
+
+    dealloc_bcd_event.set()
+    await asyncio.sleep(0.1)
+    assert cache.current_size_bytes == 768  # only tensor a is allocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a, handle_b):
+            pass  # one of the handles (b) is deallocated
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_d):
+            pass  # handle_d is deallocated, even though it is never used
+    with cache.use_cache(handle_a) as (tensor_a,):
+        assert tuple(tensor_a[2:5]) == (43, 44, 45)
+
+    dealloc_a_event.set()
+    handle_e, = pipe_receiver.recv()  # e can finally be allocated
+    assert cache.current_size_bytes == 1536  # tensor e should finally be able to allocate
+
+    with pytest.raises(KeyError):
+        with cache.use_cache(handle_a):
+            pass  # tensor a is no longer allocated
+    with cache.use_cache(handle_e) as (tensor_e,):
+        assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8)
+
+    dealloc_e_event.set()
+    alloc_process1.join(1)
+    alloc_process2.join(1)
+    assert cache.current_size_bytes == 0
+    assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details"
+    assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details"
+
+    # cache.runtime_pid += 1  # pretend we're another process
+    # async with cache.allocate_cache(_make_tensor_descriptor(768)) as a:
+    #     pass
+    #
+    #
+    # async with cache.allocate_cache(_make_tensor_descriptor(768)):
+    #     async with cache.allocate_cache(_make_tensor_descriptor(1024)):
+    #         async with cache.allocate_cache(_make_tensor_descriptor(512), _make_tensor_descriptor(64)):
+    #             async with cache.allocate_cache(_make_tensor_descriptor(1536)):
+    #                 with pytest.raises(TimeoutError):
+    #                     async with cache.allocate_cache(_make_tensor_descriptor(256), ):
+    #                         pass
+    #                 async with cache.allocate_cache(_make_tensor_descriptor(192)):
+    #                     pass