diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index d61470a..99165fe 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -16,7 +16,7 @@ from transformers import PretrainedConfig from petals.data_structures import InferenceMetadata from petals.server.memory_cache import MemoryCache from petals.server.task_pool import PrioritizedTaskPool -from petals.utils.misc import is_dummy +from petals.utils.misc import is_dummy, get_size_in_bytes logger = get_logger(__name__) @@ -74,7 +74,7 @@ class TransformerBackend(ModuleBackend): self.cache_bytes_per_token: Dict[torch.device, int] = Counter() for descr in self.get_inference_cache_descriptors(batch_size=1, max_length=1): - self.cache_bytes_per_token[descr.device] += descr.numel() * torch.finfo(descr.dtype).bits // 8 + self.cache_bytes_per_token[descr.device] += descr.numel() * get_size_in_bytes(descr.dtype) def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]: """Create tensor descriptors for attention cache tensors used during inference_step""" diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py index eb5300e..899dcb9 100644 --- a/src/petals/server/block_utils.py +++ b/src/petals/server/block_utils.py @@ -5,6 +5,7 @@ from accelerate import init_empty_weights from transformers import PretrainedConfig from petals.utils.convert_block import QuantType +from petals.utils.misc import get_size_in_bytes def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype: @@ -36,7 +37,7 @@ def get_block_size( if location == "memory": if quant_type == QuantType.NONE: dtype = resolve_block_dtype(config, dtype) - bytes_per_value = torch.finfo(dtype).bits // 8 + bytes_per_value = get_size_in_bytes(dtype) elif quant_type == QuantType.INT8: bytes_per_value = 1 elif quant_type == QuantType.NF4: @@ -45,6 +46,6 @@ def get_block_size( raise ValueError(f"Unsupported quant_type={quant_type}") elif location == "disk": dtype = resolve_block_dtype(config, "auto") - bytes_per_value = torch.finfo(dtype).bits // 8 + bytes_per_value = get_size_in_bytes(dtype) return round(n_params * bytes_per_value * (1 + eps)) diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 5d0a3d4..c47e644 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -150,6 +150,7 @@ class TransformerConnectionHandler(ConnectionHandler): active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) session_id = metadata.get("session_id") + alloc_timeout = float(metadata.get("alloc_timeout", 0.0)) if not requested_uids: raise ValueError("User must specify at least one block for inference, but got none") assert isinstance( @@ -167,7 +168,7 @@ class TransformerConnectionHandler(ConnectionHandler): batch_size = request.tensors[0].size[0] if request.tensors else 1 prefix_length = 0 - async with self._allocate_cache(requested_backends, batch_size, max_length) as cache_handles: + async with self._allocate_cache(requested_backends, batch_size, max_length, alloc_timeout) as cache_handles: assert len(cache_handles) == len(requested_backends) first_request = request background_tasks = set() @@ -567,7 +568,7 @@ class TransformerConnectionHandler(ConnectionHandler): @contextlib.asynccontextmanager async def _allocate_cache( - self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int + self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int, timeout: Optional[float], ) -> Sequence[Sequence[Handle]]: """ Allocate memory cache for all transformer blocks, return cache handle diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py index a1e2f26..2c5fcdf 100644 --- a/src/petals/server/memory_cache.py +++ b/src/petals/server/memory_cache.py @@ -17,6 +17,7 @@ import torch from hivemind.utils import TensorDescriptor, get_logger from petals.utils.asyncio import shield_and_wait +from petals.utils.misc import get_size_in_bytes logger = get_logger(__name__) @@ -26,9 +27,8 @@ Handle = int class MemoryCache: """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs""" - def __init__(self, max_size_bytes: Optional[int], alloc_timeout: float): + def __init__(self, max_size_bytes: Optional[int]): self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1) - self.alloc_timeout = alloc_timeout self._lock_metadata = mp.Lock() self._current_size = mp.Value(ctypes.c_int64, 0, lock=False) self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False) @@ -60,11 +60,12 @@ class MemoryCache: self._handle_counter.value = value @contextlib.asynccontextmanager - async def allocate_cache(self, *descriptors: TensorDescriptor) -> AsyncContextManager[Sequence[Handle]]: + async def allocate_cache(self, *descriptors: TensorDescriptor, timeout: Optional[float] = None) -> AsyncContextManager[Sequence[Handle]]: """ Create a handle that is associated with buffers on unique device. If cache full, raises AllocationFailed. :param descriptors: one or more tensors tensor of this size, dtype, etc + :param timeout: optional maximum time to wait for cache allocation; None (default) means no time limit :note: if descriptors reside on different devices, it is expected that they are approximately balanced across devices; if not, it will count maximum tensor allocation across devices for the purposes of size limit @@ -76,7 +77,7 @@ class MemoryCache: assert all(descr.device is not None for descr in descriptors), "please specify allocated devices" max_alloc_size = self.get_allocation_size(*descriptors) - gib = 1024**3 + gib = 1 cur_size, max_size = self.current_size_bytes, self.max_size_bytes friendly_max_size = f"{max_size / gib:.2f}" if max_size != 2**64 - 1 else "inf" logger.info( @@ -84,24 +85,26 @@ class MemoryCache: f"already used {cur_size / gib:.2f}/{friendly_max_size} GiB ({cur_size / max_size * 100:.1f}%)" ) - alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors)) + alloc_task = asyncio.create_task(self._schedule_alloc(max_alloc_size, *descriptors, timeout=timeout)) try: handles = await shield_and_wait(alloc_task) - logger.info(f"rpc_inference.alloc(size={max_alloc_size / gib:.2f} GiB)") + logger.info(f"rpc_inference.alloc-done(size={max_alloc_size / gib:.2f} GiB)") yield handles finally: + logger.info(f"rpc_inference.dealloc-began(size={max_alloc_size / gib:.2f} GiB)") await shield_and_wait(self._schedule_free(max_alloc_size, alloc_task)) + logger.info(f"rpc_inference.dealloc-done(size={max_alloc_size / gib:.2f} GiB)") @staticmethod def get_allocation_size(*descriptors: TensorDescriptor) -> int: """Return the memory size (bytes) to be allocated on a device. If there are many devices, return maximum""" alloc_size_by_device = {} for descr in descriptors: - tensor_size = descr.numel() * torch.finfo(descr.dtype).bits // 8 + tensor_size = descr.numel() * get_size_in_bytes(descr.dtype) alloc_size_by_device[descr.device] = alloc_size_by_device.get(descr.device, 0) + tensor_size return max(alloc_size_by_device.values()) - async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor) -> Sequence[Handle]: + async def _schedule_alloc(self, alloc_size: int, *descriptors: TensorDescriptor, timeout: Optional[float]) -> Sequence[Handle]: """ This method should be called inside asyncio.shield() because: - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation @@ -110,7 +113,7 @@ class MemoryCache: loop = asyncio.get_event_loop() async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory): if self.current_size_bytes + alloc_size > self.max_size_bytes: - await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout) + await loop.run_in_executor(None, self._wait_until_available, alloc_size, timeout) async with hivemind.utils.enter_asynchronously(self._lock_metadata): handles = tuple(int(self.handle_counter) + i for i in range(len(descriptors))) self.current_size_bytes += alloc_size @@ -124,7 +127,6 @@ class MemoryCache: - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation - _schedule_free() must finish freeing memory even in case of cancellation """ - if alloc_task.exception() is not None: return handles = alloc_task.result() diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 5cdca46..a8986c8 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -31,6 +31,7 @@ from petals.server.reachability import ReachabilityProtocol, check_direct_reacha from petals.server.throughput import get_dtype_name, get_server_throughput from petals.utils.auto_config import AutoDistributedConfig from petals.utils.convert_block import QuantType, check_device_balance, convert_block +from petals.utils.misc import get_size_in_bytes from petals.utils.ping import PingAggregator from petals.utils.random import sample_up_to from petals.utils.version import get_compatible_model_repo @@ -63,7 +64,6 @@ class Server: revision: Optional[str] = None, cache_dir: Optional[str] = None, max_disk_space: Optional[int] = None, - alloc_timeout: float = 5, device: Optional[Union[str, torch.device]] = None, compression=CompressionType.NONE, stats_report_interval: Optional[int] = None, @@ -189,7 +189,7 @@ class Server: attn_cache_tokens = 32768 if is_multiquery_attn else 8192 cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens cache_values_per_block //= self.block_config.num_key_value_groups - self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8 + self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype) # For disk cache self.cache_dir = cache_dir @@ -213,8 +213,6 @@ class Server: self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB") - self.alloc_timeout = alloc_timeout - assert isinstance(throughput, float) or throughput in ["auto", "eval"] if throughput in ["auto", "eval"]: throughput_info = get_server_throughput( @@ -306,7 +304,6 @@ class Server: converted_model_name_or_path=self.converted_model_name_or_path, block_config=self.block_config, attn_cache_bytes=self.attn_cache_bytes, - alloc_timeout=self.alloc_timeout, server_info=self.server_info, block_indices=block_indices, num_handlers=self.num_handlers, @@ -407,7 +404,6 @@ class ModuleContainer(threading.Thread): converted_model_name_or_path: str, block_config: PretrainedConfig, attn_cache_bytes: int, - alloc_timeout: float, server_info: ServerInfo, block_indices: List[int], min_batch_size: int, @@ -427,7 +423,7 @@ class ModuleContainer(threading.Thread): **kwargs, ) -> ModuleContainer: module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices] - memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout) + memory_cache = MemoryCache(attn_cache_bytes) server_info.state = ServerState.JOINING dht_announcer = ModuleAnnouncerThread( diff --git a/src/petals/utils/misc.py b/src/petals/utils/misc.py index 2f67202..1b754d4 100644 --- a/src/petals/utils/misc.py +++ b/src/petals/utils/misc.py @@ -5,3 +5,13 @@ DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter par def is_dummy(tensor: torch.Tensor): return tensor.numel() == 0 + + +SPECIAL_DTYPE_SIZES = {torch.bool: 1, torch.int8: 1, torch.qint32: 4} + + +def get_size_in_bytes(dtype: torch.dtype) -> int: + if dtype in SPECIAL_DTYPE_SIZES: + return SPECIAL_DTYPE_SIZES[dtype] + get_info = torch.finfo if dtype.is_floating_point else torch.iinfo + return (get_info(dtype).bits * (1 + dtype.is_complex)) // 8 diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index da25623..ec3b193 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -19,6 +19,7 @@ from transformers.utils import get_file_from_repo from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_block import QuantType from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for +from petals.utils.misc import get_size_in_bytes logger = get_logger(__name__) @@ -284,5 +285,5 @@ def estimate_adapter_memory_per_block( block, block_index=0, adapter_name=adapter, peft_config=peft_config, peft_state_dict=peft_state_dict ) adapter_parameters = sum(p.numel() for p in block.parameters()) - base_block_parameters - bytes_per_parameter = torch.finfo(resolve_block_dtype(block_config, torch_dtype)).bits / 8 + bytes_per_parameter = get_size_in_bytes(resolve_block_dtype(block_config, torch_dtype)) return adapter_parameters * bytes_per_parameter diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000..68bae3b --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,132 @@ +import random +from typing import Optional + +import pytest +import torch +from hivemind import TensorDescriptor + +from petals.server.memory_cache import MemoryCache, AllocationFailed +import asyncio +from petals.utils.misc import get_size_in_bytes +import multiprocessing as mp +import pytest_asyncio # make sure the module exists; otherwise the test will be skipped + + +def _make_tensor_descriptor(num_bytes: int, dtype: Optional[torch.dtype] = None): + if dtype is None: + dtype = random.choice((torch.int64, torch.int8, torch.uint8, torch.float32, torch.bfloat16, torch.bool)) + elem_size_bytes = get_size_in_bytes(dtype) + descr = TensorDescriptor.from_tensor(torch.empty((num_bytes // elem_size_bytes,), dtype=dtype)) + return descr + + +@pytest.mark.asyncio +async def test_cache_usage(): + cache = MemoryCache(max_size_bytes=2048) + alloc_event, dealloc_e_event, dealloc_bcd_event, dealloc_a_event = mp.Event(), mp.Event(), mp.Event(), mp.Event() + pipe_receiver, pipe_sender = mp.Pipe(duplex=False) + with pytest.raises(AssertionError): + async with cache.allocate_cache(_make_tensor_descriptor(123)): + pass # fails because cache must be allocated from another process + + descr_a = TensorDescriptor.from_tensor(torch.empty(768, dtype=torch.uint8)) # 768 bytes + descr_b = TensorDescriptor.from_tensor(torch.empty((), dtype=torch.float64)) # 8 bytes + descr_c = TensorDescriptor.from_tensor(torch.empty((33, ), dtype=torch.bool)) # 33 bytes + descr_d = TensorDescriptor.from_tensor(torch.empty((0, ), dtype=torch.int64)) # 0 bytes + descr_e = TensorDescriptor.from_tensor(torch.empty((96, 8), dtype=torch.bfloat16)) # 1536 bytes + descr_f = TensorDescriptor.from_tensor(torch.empty((1792,), dtype=torch.uint8)) # 1792 bytes + descr_g = TensorDescriptor.from_tensor(torch.empty((1793,), dtype=torch.uint8)) # 1792 bytes + + async def _allocate_and_wait(dealloc_event, *descrs, timeout=None): + loop = asyncio.get_event_loop() + async with cache.allocate_cache(*descrs, timeout=timeout) as handles: + pipe_sender.send(handles) + await loop.run_in_executor(None, dealloc_event.wait) + + async def _allocate_af(): + alloc_event.wait() + print("BEGAN AF") + try: + async with cache.allocate_cache(descr_g): + allocate_f_task = asyncio.create_task(_allocate_and_wait(mp.Event(), descr_f)) # klogs the cache + print("CANCELLED") + raise asyncio.CancelledError() + except asyncio.CancelledError: + pass + allocate_f_task.cancel() # unklog the cache + + allocate_a_task = asyncio.create_task(_allocate_and_wait(dealloc_a_event, descr_a)) + await allocate_a_task + + alloc_process1 = mp.Process(target=lambda: asyncio.run(_allocate_af()), daemon=True) + alloc_process1.start() + + async def _allocate_bcde(): + await asyncio.sleep(0.2) # ensure that the other tensor is always allocated (and sent through pipe) first + print("BEGAN BCDE") + allocate_bcd_task = asyncio.create_task(_allocate_and_wait(dealloc_bcd_event, descr_b, descr_c, descr_d)) + allocate_e_task = asyncio.create_task(_allocate_and_wait(dealloc_e_event, descr_e)) # doesn't fit + await asyncio.wait({allocate_e_task, allocate_bcd_task}, return_when=asyncio.ALL_COMPLETED) + + alloc_process2 = mp.Process(target=lambda: asyncio.run(_allocate_bcde()), daemon=True) + alloc_process2.start() + assert cache.current_size_bytes == 0 + alloc_event.set() + handle_a, = pipe_receiver.recv() + + handle_b, handle_c, handle_d = pipe_receiver.recv() + + with cache.use_cache(handle_a) as (tensor_a,): + assert tensor_a.dtype == torch.uint8 + tensor_a[2:5] = torch.tensor((42, 43, 44)) + + with cache.use_cache(handle_a, handle_b, handle_d) as (tensor_a, tensor_b, tensor_d): + assert tensor_b.dtype == torch.float64 and tensor_b.numel() == 1 and tensor_b.ndim == 0 + assert tensor_d.dtype == torch.int64 and tensor_d.numel() == 0 + tensor_a += 1 + tensor_b[...] = -1.337 + assert cache.current_size_bytes == 809 # this checks a,b,c,d are allocated but b still awaits memory + + dealloc_bcd_event.set() + await asyncio.sleep(0.1) + assert cache.current_size_bytes == 768 # only tensor a is allocated + with pytest.raises(KeyError): + with cache.use_cache(handle_a, handle_b): + pass # one of handles (c) is deallocated + with pytest.raises(KeyError): + with cache.use_cache(handle_d): + pass # handle_e is deallocated, even though it is never used + with cache.use_cache(handle_a) as (tensor_a,): + assert tuple(tensor_a[2:5]) == (43, 44, 45) + + dealloc_a_event.set() + handle_e, = pipe_receiver.recv() # e can finally be allocated + assert cache.current_size_bytes == 1536 # tensor e should finally be able to allocate + + with pytest.raises(KeyError): + with cache.use_cache(handle_a): + pass # tensor a is no longer allocated + with cache.use_cache(handle_e) as (tensor_e,): + assert tensor_e.dtype == torch.bfloat16 and tensor_e.shape == (96, 8) + + dealloc_e_event.set() + alloc_process1.join(1) + alloc_process2.join(1) + assert cache.current_size_bytes == 0 + assert alloc_process1.exitcode == 0, "allocation process 1 failed or did not finish, see stderr for details" + assert alloc_process2.exitcode == 0, "allocation process 2 failed or did not finish, see stderr for details" + + # cache.runtime_pid += 1 # pretend we're another process + # async with cache.allocate_cache(_make_tensor_descriptor(768)) as a: + # pass + # + # + # async with cache.allocate_cache(_make_tensor_descriptor(768)): + # async with cache.allocate_cache(_make_tensor_descriptor(1024)): + # async with cache.allocate_cache(_make_tensor_descriptor(512), _make_tensor_descriptor(64)): + # async with cache.allocate_cache(_make_tensor_descriptor(1536)): + # with pytest.raises(TimeoutError): + # async with cache.allocate_cache(_make_tensor_descriptor(256), ): + # pass + # async with cache.allocate_cache(_make_tensor_descriptor(192)): + # pass