Shield alloc & free from cancellation (#163)

A handler's RPC code may be cancelled due to a request timeout or a client closing the connection. Before this PR:

- If `.cancel()` happens while the code is waiting to acquire a lock via `hivemind.utils.enter_asynchronously()`, the lock will never be released.
- If `.cancel()` happens after memory has been allocated but before it is freed, the memory will never be freed.

This PR fixes both issues by deferring the cancellation with [asyncio.shield()](https://docs.python.org/3/library/asyncio-task.html#asyncio.shield). Now the cancellation is delivered only after all locks have been released and the alloc/free has completed.
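To see why a dedicated helper is needed, here is a minimal sketch (toy names, not code from this PR): with bare `asyncio.shield()`, the inner task keeps running, but the *caller* receives `CancelledError` immediately, so its `finally` block can run before the allocation has finished.

```python
import asyncio


async def allocate():
    # Stand-in for the real allocation path: it takes locks and does slow work.
    print("alloc started")
    await asyncio.sleep(1)
    print("alloc finished")
    return 42


async def handler():
    alloc_task = asyncio.create_task(allocate())
    try:
        # Bare shield: allocate() is protected from cancellation, but this await
        # raises CancelledError immediately once handler() is cancelled...
        handle = await asyncio.shield(alloc_task)
    finally:
        # ...so this runs right away, possibly before "alloc finished" is printed.
        print("freeing (alloc may still be in flight!)")


async def main():
    task = asyncio.create_task(handler())
    await asyncio.sleep(0.1)
    task.cancel()  # simulates a request timeout or a client closing the connection
    await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())
```

The `shield_and_wait()` helper added in this PR (see the new `petals/utils/asyncio.py` below) keeps re-awaiting the shielded task after a cancellation and only re-raises `CancelledError` once the task has finished, so cleanup code always observes a completed alloc/free.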
Commit 9997ada3bb (parent d6992fca63), committed by Alexander Borzunov.

@@ -1,6 +1,6 @@
 import asyncio
 import contextlib
-from typing import Any, AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import torch
 from async_timeout import timeout
@@ -93,7 +93,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         """Compute a single step of inference using attention cache; update attention cache accordingly."""

         async with timeout(self.session_timeout):
-            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            try:
+                request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            except asyncio.TimeoutError:
+                self._log_request("rpc_inference.open", None, context, warning="timed out")
+                return
+
             requested_uids = self._check_uids(request.uid)
             self._log_request("rpc_inference.open", requested_uids, context)
             try:
@@ -193,7 +198,11 @@ class TransformerConnectionHandler(ConnectionHandler):
                         # prepare for next step
                         prefix_length += hidden_states.shape[1]

-                        request = await asyncio.wait_for(anext(requests), self.step_timeout)
+                        try:
+                            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+                        except asyncio.TimeoutError:
+                            self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
+                            return

             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
@@ -369,14 +378,23 @@ class TransformerConnectionHandler(ConnectionHandler):
             logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
             yield handle

-    def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
-        friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
-        friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
-        friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+    def _log_request(
+        self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None
+    ) -> None:
+        if uids is not None:
+            friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
+            friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
+            friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+        else:
+            friendly_uids = "n/a"

         friendly_remote_id = "..." + str(context.remote_id)[-6:]

-        logger.info(f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})")
+        message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})"
+        if warning is None:
+            logger.info(message)
+        else:
+            logger.warning(f"{message}: {warning}")


 async def _rpc_forward(

@@ -16,6 +16,8 @@ import hivemind
 import torch
 from hivemind.utils import TensorDescriptor, get_logger

+from petals.utils.asyncio import shield_and_wait
+
 logger = get_logger(__file__)

 Handle = int
@@ -66,28 +68,46 @@ class MemoryCache:
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
         assert descr.device is None and descr
-        allocated_handle = None
-        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
-        loop = asyncio.get_event_loop()
+
+        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr))
         try:
-            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
-                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    await loop.run_in_executor(
-                        None, self._wait_until_available, allocated_size_bytes, self.alloc_timeout
-                    )
-                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    allocated_handle = int(self.handle_counter)
-                    self.current_size_bytes += allocated_size_bytes
-                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                    self._pipe_send.send((allocated_handle, descr))
-
-            yield allocated_handle
+            yield await shield_and_wait(alloc_task)
         finally:
-            if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
-                    self.current_size_bytes -= allocated_size_bytes
-                    self._memory_freed_event.set()
+            await shield_and_wait(self._schedule_free(alloc_size, alloc_task))
+
+    async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle:
+        """
+        This method should be called inside asyncio.shield() because:
+            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
+        """
+
+        loop = asyncio.get_event_loop()
+        async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
+            if self.current_size_bytes + alloc_size > self.max_size_bytes:
+                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
+            async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                handle = int(self.handle_counter)
+                self.current_size_bytes += alloc_size
+                self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                self._pipe_send.send((handle, descr))
+        return handle
+
+    async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
+        """
+        This method should be called inside asyncio.shield() because:
+            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
+            - _schedule_free() must finish freeing memory even in case of cancellation
+        """
+
+        if alloc_task.exception() is not None:
+            return
+        handle = alloc_task.result()
+
+        async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+            self._pipe_send.send((handle, None))  # signal runtime to free that handle
+            self.current_size_bytes -= alloc_size
+            self._memory_freed_event.set()

     def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None):
         # note: this function should only be called inside _lock_acquire_memory!

@@ -0,0 +1,21 @@
+import asyncio
+
+
+async def shield_and_wait(task):
+    """
+    Works like asyncio.shield(), but waits for the task to finish before raising CancelledError to the caller.
+    """
+    if not isinstance(task, asyncio.Task):
+        task = asyncio.create_task(task)
+    cancel_exc = None
+    while True:
+        try:
+            result = await asyncio.shield(task)
+            break
+        except asyncio.CancelledError as e:
+            cancel_exc = e
+
+    if cancel_exc is not None:
+        raise cancel_exc
+
+    return result
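As a quick sanity check of the deferral behavior, a standalone sketch (assuming `petals` is installed so `petals.utils.asyncio` is importable; the toy `slow_alloc` coroutine is illustrative only, not part of this PR):

```python
import asyncio

from petals.utils.asyncio import shield_and_wait


async def slow_alloc():
    print("alloc started")
    await asyncio.sleep(1)
    print("alloc finished")  # always reached, even if the caller is cancelled
    return 42


async def handler():
    try:
        handle = await shield_and_wait(slow_alloc())
        print(f"got handle {handle}")  # skipped: CancelledError is re-raised after alloc completes
    finally:
        print("cleanup runs only after alloc has completed")


async def main():
    task = asyncio.create_task(handler())
    await asyncio.sleep(0.1)
    task.cancel()  # simulate a request timeout / client disconnect
    await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())
```

Cancelling the caller mid-allocation prints "alloc started", then "alloc finished", then the cleanup message: the `CancelledError` only reaches `handler()` once the shielded work is done, which is exactly the property the alloc/free path relies on.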