From 9997ada3bbeacfa2264f8df1f13f4d1d783f48e5 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Thu, 22 Dec 2022 20:05:57 +0400
Subject: [PATCH] Shield alloc & free from cancellation (#163)

A handler's RPC code may be cancelled due to a request timeout or a client
closing the connection. Before this PR:

- If `.cancel()` happens while waiting for `hivemind.utils.enter_asynchronously()`,
  the lock will never be released.
- If `.cancel()` happens while waiting for that lock before freeing memory,
  the memory will never be freed.

This PR fixes it by deferring the cancellation with
[asyncio.shield()](https://docs.python.org/3/library/asyncio-task.html#asyncio.shield).
Now, the cancellation happens only after all locks are released and
alloc/free has completed.
---
 src/petals/server/handler.py      | 34 +++++++++++++-----
 src/petals/server/memory_cache.py | 60 ++++++++++++++++++++-----------
 src/petals/utils/asyncio.py       | 21 +++++++++++
 3 files changed, 87 insertions(+), 28 deletions(-)
 create mode 100644 src/petals/utils/asyncio.py

diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index a7b72c6..ff66e4b 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -1,6 +1,6 @@
 import asyncio
 import contextlib
-from typing import Any, AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 from async_timeout import timeout
@@ -93,7 +93,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
 
         async with timeout(self.session_timeout):
-            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            try:
+                request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            except asyncio.TimeoutError:
+                self._log_request("rpc_inference.open", None, context, warning="timed out")
+                return
+
             requested_uids = self._check_uids(request.uid)
             self._log_request("rpc_inference.open", requested_uids, context)
             try:
@@ -193,7 +198,11 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                         # prepare for next step
                         prefix_length += hidden_states.shape[1]
-                        request = await asyncio.wait_for(anext(requests), self.step_timeout)
+                        try:
+                            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+                        except asyncio.TimeoutError:
+                            self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
+                            return
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
@@ -369,14 +378,23 @@ class TransformerConnectionHandler(ConnectionHandler):
             logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
             yield handle
 
-    def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
-        friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
-        friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
-        friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+    def _log_request(
+        self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None
+    ) -> None:
+        if uids is not None:
+            friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
+            friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
+            friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+        else:
+            friendly_uids = "n/a"
 
         friendly_remote_id = "..." + str(context.remote_id)[-6:]
 
-        logger.info(f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})")
+        message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})"
+        if warning is None:
+            logger.info(message)
+        else:
+            logger.warning(f"{message}: {warning}")
 
 
 async def _rpc_forward(
diff --git a/src/petals/server/memory_cache.py b/src/petals/server/memory_cache.py
index ac7af41..53c1a7d 100644
--- a/src/petals/server/memory_cache.py
+++ b/src/petals/server/memory_cache.py
@@ -16,6 +16,8 @@ import hivemind
 import torch
 from hivemind.utils import TensorDescriptor, get_logger
 
+from petals.utils.asyncio import shield_and_wait
+
 logger = get_logger(__file__)
 
 Handle = int
@@ -66,28 +68,46 @@
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
         assert descr.device is None and descr
-        allocated_handle = None
-        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
-        loop = asyncio.get_event_loop()
+
+        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr))
         try:
-            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
-                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    await loop.run_in_executor(
-                        None, self._wait_until_available, allocated_size_bytes, self.alloc_timeout
-                    )
-                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    allocated_handle = int(self.handle_counter)
-                    self.current_size_bytes += allocated_size_bytes
-                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                    self._pipe_send.send((allocated_handle, descr))
-
-            yield allocated_handle
+            yield await shield_and_wait(alloc_task)
         finally:
-            if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
-                    self.current_size_bytes -= allocated_size_bytes
-                    self._memory_freed_event.set()
+            await shield_and_wait(self._schedule_free(alloc_size, alloc_task))
+
+    async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle:
+        """
+        This method should be called inside asyncio.shield() because:
+            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
+        """
+
+        loop = asyncio.get_event_loop()
+        async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
+            if self.current_size_bytes + alloc_size > self.max_size_bytes:
+                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
+            async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                handle = int(self.handle_counter)
+                self.current_size_bytes += alloc_size
+                self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                self._pipe_send.send((handle, descr))
+                return handle
+
+    async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
+        """
+        This method should be called inside asyncio.shield() because:
+            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
+            - _schedule_free() must finish freeing memory even in case of cancellation
+        """
+
+        if alloc_task.exception() is not None:
+            return
+        handle = alloc_task.result()
+
+        async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+            self._pipe_send.send((handle, None))  # signal runtime to free that handle
+            self.current_size_bytes -= alloc_size
+            self._memory_freed_event.set()
 
     def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None):
         # note: this function should only be called inside _lock_acquire_memory!
diff --git a/src/petals/utils/asyncio.py b/src/petals/utils/asyncio.py
new file mode 100644
index 0000000..72d85ce
--- /dev/null
+++ b/src/petals/utils/asyncio.py
@@ -0,0 +1,21 @@
+import asyncio
+
+
+async def shield_and_wait(task):
+    """
+    Works like asyncio.shield(), but waits for the task to finish before raising CancelledError to the caller.
+    """
+
+    if not isinstance(task, asyncio.Task):
+        task = asyncio.create_task(task)
+
+    cancel_exc = None
+    while True:
+        try:
+            result = await asyncio.shield(task)
+            break
+        except asyncio.CancelledError as e:
+            cancel_exc = e
+    if cancel_exc is not None:
+        raise cancel_exc
+    return result
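
For reference, here is a minimal sketch (not part of the patch) of the behavior this change relies on: a coroutine wrapped with `shield_and_wait()` keeps running to completion even if its caller is cancelled, and the `CancelledError` is re-raised only afterwards. Only `petals.utils.asyncio.shield_and_wait` comes from this patch; the other coroutine names (`critical_section`, `handler`) are illustrative.

```python
import asyncio

from petals.utils.asyncio import shield_and_wait  # added by this patch


async def critical_section():
    # Stands in for the shielded alloc/free code that must not be interrupted midway.
    await asyncio.sleep(0.2)
    print("critical section completed")


async def handler():
    # Even if this coroutine is cancelled while waiting here, shield_and_wait()
    # lets critical_section() run to completion and only then re-raises CancelledError.
    await shield_and_wait(critical_section())


async def main():
    task = asyncio.create_task(handler())
    await asyncio.sleep(0.05)
    task.cancel()  # simulates a request timeout or a client closing the connection
    try:
        await task
    except asyncio.CancelledError:
        print("handler was cancelled only after the critical section completed")


asyncio.run(main())
```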