Shield alloc & free from cancellation (#163)

A handler's RPC code may be cancelled due to a request timeout or a client closing the connection. Before this PR:

- If `.cancel()` happens while waiting for `hivemind.utils.enter_asynchronously()`, the lock will never be released.
- If `.cancel()` happens while entering that lock on the way to freeing memory, the memory will never be freed.

This PR fixes it by deferring cancellation with a `shield_and_wait()` helper built on [asyncio.shield()](https://docs.python.org/3/library/asyncio-task.html#asyncio.shield): the `CancelledError` is re-raised only after all locks have been released and the alloc/free has completed.
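To see why plain `asyncio.shield()` is not enough on its own: shielding protects the inner task from cancellation, but the *awaiting* coroutine still receives `CancelledError` immediately, so its cleanup code can run before the shielded work has finished. A minimal, self-contained sketch of that failure mode (hypothetical names, not Petals code):

```python
import asyncio


async def slow_alloc() -> int:
    """Stand-in for an allocation that is slow because it waits on locks."""
    await asyncio.sleep(0.2)
    print("alloc finished")
    return 42  # hypothetical handle


async def handler():
    task = asyncio.create_task(slow_alloc())
    try:
        handle = await asyncio.shield(task)
    except asyncio.CancelledError:
        # Raised as soon as .cancel() arrives, even though `task` is still
        # running -- cleanup here cannot know which handle to free yet.
        print("cancelled; alloc done?", task.done())  # prints False
        raise


async def main():
    t = asyncio.create_task(handler())
    await asyncio.sleep(0.05)
    t.cancel()
    try:
        await t
    except asyncio.CancelledError:
        pass
    await asyncio.sleep(0.3)  # the orphaned alloc still completes afterwards


asyncio.run(main())
```

The `shield_and_wait()` helper introduced below closes this gap by holding back the `CancelledError` until the shielded task has completed, then re-raising it.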
Branch: pull/165/head
Commit: 9997ada3bb (parent: d6992fca63)
Author: Alexander Borzunov, committed via GitHub

@@ -1,6 +1,6 @@
 import asyncio
 import contextlib
-from typing import Any, AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 from async_timeout import timeout
@@ -93,7 +93,12 @@ class TransformerConnectionHandler(ConnectionHandler):
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
         async with timeout(self.session_timeout):
-            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            try:
+                request = await asyncio.wait_for(anext(requests), self.step_timeout)
+            except asyncio.TimeoutError:
+                self._log_request("rpc_inference.open", None, context, warning="timed out")
+                return
+
             requested_uids = self._check_uids(request.uid)
             self._log_request("rpc_inference.open", requested_uids, context)
             try:
@@ -193,7 +198,11 @@ class TransformerConnectionHandler(ConnectionHandler):
                         # prepare for next step
                         prefix_length += hidden_states.shape[1]
-                        request = await asyncio.wait_for(anext(requests), self.step_timeout)
+                        try:
+                            request = await asyncio.wait_for(anext(requests), self.step_timeout)
+                        except asyncio.TimeoutError:
+                            self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
+                            return
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
@@ -369,14 +378,23 @@ class TransformerConnectionHandler(ConnectionHandler):
             logger.info(f"rpc_inference.alloc(size={alloc_size / gib:.2f} GiB)")
             yield handle
 
-    def _log_request(self, method: str, uids: Sequence[ModuleUID], context: P2PContext) -> None:
-        friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
-        friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
-        friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+    def _log_request(
+        self, method: str, uids: Optional[Sequence[ModuleUID]], context: P2PContext, *, warning: Optional[str] = None
+    ) -> None:
+        if uids is not None:
+            friendly_uids = [uid.split(".")[-1] for uid in uids if "." in uid]
+            friendly_uids = [int(uid) for uid in friendly_uids if uid.isdigit()]
+            friendly_uids = f"{min(friendly_uids)}:{max(friendly_uids) + 1}" if friendly_uids else uids
+        else:
+            friendly_uids = "n/a"
 
         friendly_remote_id = "..." + str(context.remote_id)[-6:]
 
-        logger.info(f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})")
+        message = f"{method}(blocks={friendly_uids}, remote_peer={friendly_remote_id})"
+        if warning is None:
+            logger.info(message)
+        else:
+            logger.warning(f"{message}: {warning}")
 
     async def _rpc_forward(

@@ -16,6 +16,8 @@ import hivemind
 import torch
 from hivemind.utils import TensorDescriptor, get_logger
 
+from petals.utils.asyncio import shield_and_wait
+
 logger = get_logger(__file__)
 
 Handle = int
@@ -66,28 +68,46 @@ class MemoryCache:
         """
         assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
         assert descr.device is None and descr
-        allocated_handle = None
-        allocated_size_bytes = descr.numel() * torch.finfo(descr.dtype).bits // 8
-        loop = asyncio.get_event_loop()
+
+        alloc_size = descr.numel() * torch.finfo(descr.dtype).bits // 8
+        alloc_task = asyncio.create_task(self._schedule_alloc(alloc_size, descr))
         try:
-            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
-                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
-                    await loop.run_in_executor(
-                        None, self._wait_until_available, allocated_size_bytes, self.alloc_timeout
-                    )
-                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    allocated_handle = int(self.handle_counter)
-                    self.current_size_bytes += allocated_size_bytes
-                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
-                    self._pipe_send.send((allocated_handle, descr))
-
-            yield allocated_handle
+            yield await shield_and_wait(alloc_task)
         finally:
-            if allocated_handle is not None:
-                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
-                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
-                    self.current_size_bytes -= allocated_size_bytes
-                    self._memory_freed_event.set()
+            await shield_and_wait(self._schedule_free(alloc_size, alloc_task))
+
+    async def _schedule_alloc(self, alloc_size: int, descr: TensorDescriptor) -> Handle:
+        """
+        This method should be called inside asyncio.shield() because:
+            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
+        """
+        loop = asyncio.get_event_loop()
+        async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
+            if self.current_size_bytes + alloc_size > self.max_size_bytes:
+                await loop.run_in_executor(None, self._wait_until_available, alloc_size, self.alloc_timeout)
+            async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+                handle = int(self.handle_counter)
+                self.current_size_bytes += alloc_size
+                self.handle_counter += 1  # note: this will eventually overflow and it is okay
+                self._pipe_send.send((handle, descr))
+                return handle
+
+    async def _schedule_free(self, alloc_size: int, alloc_task: asyncio.Task):
+        """
+        This method should be called inside asyncio.shield() because:
+            - hivemind.utils.enter_asynchronously() does not always release the lock on cancellation
+            - _schedule_free() must finish freeing memory even in case of cancellation
+        """
+        if alloc_task.exception() is not None:
+            return
+        handle = alloc_task.result()
+
+        async with hivemind.utils.enter_asynchronously(self._lock_metadata):
+            self._pipe_send.send((handle, None))  # signal runtime to free that handle
+            self.current_size_bytes -= alloc_size
+            self._memory_freed_event.set()
 
     def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None):
         # note: this function should only be called inside _lock_acquire_memory!
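The restructuring above isolates a reusable pattern: run the allocation in its own task before entering `try`, shield both the allocation and the matching free, and let the free step inspect the allocation task so it only releases what was actually acquired. A generic sketch of that pattern, assuming the `shield_and_wait()` helper from the new module introduced below (class and method names are hypothetical):

```python
import asyncio

from petals.utils.asyncio import shield_and_wait  # added by this PR


class ResourcePool:
    async def _acquire(self) -> int:
        ...  # take locks, reserve the resource, return a handle

    async def _release(self, acquire_task: asyncio.Task) -> None:
        # acquire_task is guaranteed to be done here, since shield_and_wait()
        # does not return (or raise) until the wrapped task has finished
        if acquire_task.exception() is not None:
            return  # acquisition failed, so there is nothing to release
        handle = acquire_task.result()
        ...  # take locks again and release `handle`

    async def use(self):
        acquire_task = asyncio.create_task(self._acquire())
        try:
            handle = await shield_and_wait(acquire_task)
            ...  # caller's code; may be cancelled at any await point
        finally:
            # Runs on normal exit and on cancellation alike; shielded so
            # the release itself cannot be interrupted halfway through.
            await shield_and_wait(self._release(acquire_task))
```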

@@ -0,0 +1,21 @@
+import asyncio
+
+
+async def shield_and_wait(task):
+    """
+    Works like asyncio.shield(), but waits for the task to finish before raising CancelledError to the caller.
+    """
+    if not isinstance(task, asyncio.Task):
+        task = asyncio.create_task(task)
+
+    cancel_exc = None
+    while True:
+        try:
+            result = await asyncio.shield(task)
+            break
+        except asyncio.CancelledError as e:
+            cancel_exc = e
+
+    if cancel_exc is not None:
+        raise cancel_exc
+    return result
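A quick demonstration of the helper's semantics (the `critical_section()` name is made up for illustration): cancelling a caller that is blocked on `shield_and_wait()` lets the wrapped task run to completion first, and only then re-raises the `CancelledError`:

```python
import asyncio

from petals.utils.asyncio import shield_and_wait


async def critical_section() -> str:
    await asyncio.sleep(0.2)  # e.g. acquiring locks, allocating memory
    return "done"


async def caller():
    try:
        print(await shield_and_wait(critical_section()))
    except asyncio.CancelledError:
        # Only raised *after* critical_section() ran to completion,
        # so any locks it took were released and its work finished.
        print("cancelled, but the critical section still completed")
        raise


async def main():
    t = asyncio.create_task(caller())
    await asyncio.sleep(0.05)
    t.cancel()
    try:
        await t
    except asyncio.CancelledError:
        pass


asyncio.run(main())
```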