pull/434/head
Your Name 10 months ago
parent 458cf3339b
commit 9e63cff2c7

@ -10,11 +10,11 @@ import ctypes
import multiprocessing as mp
import os
import time
from typing import AsyncContextManager, Coroutine, Dict, Optional, Sequence
from typing import AsyncContextManager, Dict, Optional, Sequence
import async_timeout
import torch
from hivemind.utils import TensorDescriptor, anext, enter_asynchronously, get_logger
from hivemind.utils import TensorDescriptor, enter_asynchronously, get_logger
from petals.data_structures import Handle
from petals.utils.asyncio import shield_and_wait
@ -206,32 +206,5 @@ class MemoryCache:
yield tuple(self._allocated_tensors[handle] for handle in handles)
@contextlib.asynccontextmanager
async def wait_for_aenter(context: contextlib.AbstractAsyncContextManager, stop_when_completes: Coroutine):
"""Try to enter asynchronous context in time before stop_after task completes (e.g. timeout)"""
async def _enter_and_exit():
async with context:
yield
yield
async def _wait_for_deadline():
await stop_when_completes
aenter_and_aexit = _enter_and_exit()
aenter_task = asyncio.create_task(anext(aenter_and_aexit))
stop_task = asyncio.create_task(_wait_for_deadline())
try:
await asyncio.wait({aenter_task, stop_task}, return_when=asyncio.FIRST_COMPLETED)
if stop_task.done():
raise TimeoutError("Did not enter context in time")
yield aenter_task.result()
finally:
if aenter_task.done():
await anext(aenter_and_aexit) # exit normally
stop_task.cancel()
aenter_task.cancel()
class AllocationFailed(Exception):
pass

Loading…
Cancel
Save