From 5a8de2f1f8173bce927381eb575d83e91dc90315 Mon Sep 17 00:00:00 2001
From: justheuristic
Date: Wed, 19 Jul 2023 12:31:47 +0300
Subject: [PATCH] Fix handler memory leak, get rid of mp.Manager (#373)

This PR removes a memory leak in handler.py that was tied to mp.SyncManager.
Instead of sharing per-session push queues through a multiprocessing.Manager,
each connection handler now owns a plain mp.Queue of events (NEW_SESSION,
END_SESSION, PUSH, SHUTDOWN) and keeps its per-session asyncio.Queue objects
locally, so the manager process is no longer needed at all.
---
 src/petals/server/handler.py | 189 ++++++++++++++++++++++-------------
 src/petals/server/server.py |  11 +-
 2 files changed, 121 insertions(+), 79 deletions(-)

diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index d0531de..5d0a3d4 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -2,9 +2,9 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
-import multiprocessing.managers
+import multiprocessing as mp
 import sys
-from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
 from itertools import chain
 from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
@@ -42,20 +42,15 @@ logger = get_logger(__name__)
 
 # Fix pickling protobufs, see https://stackoverflow.com/a/74873028
 sys.modules["runtime_pb2"] = runtime_pb2
 
-# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256
-_OriginalAutoProxy = multiprocessing.managers.AutoProxy
-
-
-def patched_autoproxy(*args, manager_owned=True, **kwargs):
-    # Calling original AutoProxy without the unwanted key argument
-    return _OriginalAutoProxy(*args, **kwargs)
-
-
-multiprocessing.managers.AutoProxy = patched_autoproxy
+CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
 
 
-CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
+class Event(Enum):
+    NEW_SESSION = 0
+    END_SESSION = 1
+    PUSH = 2
+    SHUTDOWN = 3
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -70,8 +65,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         *,
         adapters: Optional[Sequence[str]],
         dht_prefix: str,
-        push_manager: multiprocessing.managers.SyncManager,
-        session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
+        handler_event_queues: Sequence[mp.Queue],
+        handler_index: int,
         inference_max_length: int,
         request_timeout: float,
         session_timeout: float,
@@ -83,18 +78,28 @@ class TransformerConnectionHandler(ConnectionHandler):
             assert isinstance(module_backend, TransformerBackend)
         self.dht_prefix = dht_prefix
         self.adapters = adapters
-        self._push_manager = push_manager
-        self._session_queues = session_queues
-        self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
+        self._handler_event_queues = handler_event_queues
+        self._handler_index = handler_index
+        self._own_event_queue = handler_event_queues[handler_index]
+        self._listener_task: Optional[asyncio.Task] = None
+        self._session_queues: Dict[str, asyncio.Queue] = {}
+        self._session_handlers: Dict[str, int] = {}
         self.inference_max_length = inference_max_length
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
         self._prioritizer = task_prioritizer
 
+    async def add_p2p_handlers(self, *args, **kwargs) -> None:
+        if self._listener_task is None:
+            # Start listening to our own event queue before we accept any requests
+            self._listener_task = asyncio.create_task(self._listen_to_event_queue())
+        await super().add_p2p_handlers(*args, **kwargs)
+
     def shutdown(self):
         if self.is_alive():
             self._outer_pipe.send("_shutdown")
+            self._own_event_queue.put((Event.SHUTDOWN, None, None))
             self.join(self.shutdown_timeout)
             if self.is_alive():
logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM") @@ -129,7 +134,6 @@ class TransformerConnectionHandler(ConnectionHandler): context: P2PContext, ) -> AsyncIterator[runtime_pb2.ExpertResponse]: """Compute a single step of inference using attention cache; update attention cache accordingly.""" - async with timeout(self.session_timeout): try: request = await asyncio.wait_for(anext(requests), self.step_timeout) @@ -146,7 +150,6 @@ class TransformerConnectionHandler(ConnectionHandler): active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) session_id = metadata.get("session_id") - if not requested_uids: raise ValueError("User must specify at least one block for inference, but got none") assert isinstance( @@ -235,6 +238,56 @@ class TransformerConnectionHandler(ConnectionHandler): finally: self._log_request("rpc_inference.close", requested_uids, context) + @contextlib.contextmanager + def _managed_session(self, session_id: str): + assert session_id not in self._session_queues, f"session id {session_id} is not unique" + try: + self._session_queues[session_id] = asyncio.Queue() + self._session_handlers[session_id] = self._handler_index + for other_index, other_queue in enumerate(self._handler_event_queues): + if other_index != self._handler_index: + other_queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index)) + yield + finally: + self._session_queues.pop(session_id).put_nowait(None) # put None so that the get task will not hang + del self._session_handlers[session_id] + for other_index, other_queue in enumerate(self._handler_event_queues): + if other_index != self._handler_index: + other_queue.put_nowait((Event.END_SESSION, session_id, self._handler_index)) + + def _put_into_session_queue(self, session_id: str, request: runtime_pb2.ExpertRequest): + handler_index = self._session_handlers.get(session_id) + if handler_index is None: + logger.debug(f"Ignored rpc_push to unknown session ID: {session_id}") + elif handler_index == self._handler_index: + self._session_queues[session_id].put_nowait(request) + else: + self._handler_event_queues[handler_index].put_nowait((Event.PUSH, session_id, request)) + + async def _get_from_session_queue(self, session_id: str) -> Optional[runtime_pb2.ExpertRequest]: + assert self._session_handlers[session_id] == self._handler_index, "session belongs to another handler" + return await self._session_queues[session_id].get() + + async def _listen_to_event_queue(self): + loop = asyncio.get_event_loop() + while True: + try: + event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get) + if event == Event.SHUTDOWN: + break + elif event == Event.NEW_SESSION: + self._session_handlers[session_id] = payload # index of the handler that owns that session + elif event == Event.END_SESSION: + self._session_handlers.pop(session_id, None) + elif event == Event.PUSH: + maybe_session_queue = self._session_queues.get(session_id) + if maybe_session_queue is not None: + maybe_session_queue.put_nowait(payload) + else: + raise RuntimeError(f"Unexpected event: {event}") + except Exception as e: + logger.exception(e) + async def _iterate_inference_steps( self, first_request: runtime_pb2.ExpertRequest, @@ -243,67 +296,60 @@ class TransformerConnectionHandler(ConnectionHandler): requested_uids: Sequence[str], context: P2PContext, ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]: - loop = asyncio.get_event_loop() - if session_id is not None: - push_queue = 
-            self._session_queues[session_id] = push_queue
-
         processed_step_ids = set()
         n_pushes = n_late_pushes = 0
         request = first_request
         anext_task = get_push_task = None
         try:
-            while request.tensors:  # iterate while user is willing to supply tensors
-                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-                step_id = metadata.get("step_id")
-
-                pushed = metadata.get("pushed")
-                if pushed:
-                    n_pushes += 1
-
-                if step_id is None or step_id not in processed_step_ids:
-                    yield request, metadata
-                    if step_id is not None:
-                        processed_step_ids.add(step_id)
-                elif pushed:
-                    n_late_pushes += 1
-                    self._log_request(
-                        "rpc_inference.push",
-                        requested_uids,
-                        context,
-                        warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
+            with self._managed_session(session_id) if session_id is not None else contextlib.nullcontext():
+                while request.tensors:  # iterate while user is willing to supply tensors
+                    metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+                    step_id = metadata.get("step_id")
+
+                    pushed = metadata.get("pushed")
+                    if pushed:
+                        n_pushes += 1
+                        self._log_request("rpc_inference.push", requested_uids, context, debug=f"session received push")
+
+                    if step_id is None or step_id not in processed_step_ids:
+                        yield request, metadata
+                        if step_id is not None:
+                            processed_step_ids.add(step_id)
+                    elif pushed:
+                        n_late_pushes += 1
+                        self._log_request(
+                            "rpc_inference.push",
+                            requested_uids,
+                            context,
+                            warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
+                        )
+
+                    # Wait for the next request, coming either from the `requests` iterator or `push_queue`
+                    if anext_task is None:
+                        anext_task = asyncio.create_task(anext(requests))
+                    if get_push_task is None:
+                        if session_id is not None:
+                            get_push_task = asyncio.create_task(self._get_from_session_queue(session_id))
+                        else:
+                            get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task
+                    done, _ = await asyncio.wait(
+                        [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
                     )
 
-                # Wait for the next request, coming either from the `requests` iterator or `push_queue`
-                if anext_task is None:
-                    anext_task = asyncio.create_task(anext(requests))
-                if get_push_task is None:
-                    if session_id is not None:
-                        get_push_task = loop.run_in_executor(self._executor, push_queue.get)
+                    if anext_task in done:
+                        request = await anext_task
+                        anext_task = None
+                    elif get_push_task in done:
+                        request = await get_push_task
+                        get_push_task = None
                     else:
-                        get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task
-                done, _ = await asyncio.wait(
-                    [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
-                )
-
-                if anext_task in done:
-                    request = await anext_task
-                    anext_task = None
-                elif get_push_task in done:
-                    request = await get_push_task
-                    get_push_task = None
-                else:
-                    self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
-                    anext_task.cancel()
-                    get_push_task.cancel()
-                    return
+                        self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
+                        anext_task.cancel()
+                        get_push_task.cancel()
+                        return
         except:
             logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
             raise
-        finally:
-            if session_id is not None:
-                push_queue.put(None)  # Stop thread for get_push_task
-                del self._session_queues[session_id]
 
     async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         """Directly push activation tensors from one server to another"""
 
@@ -312,8 +358,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         metadata = MSGPackSerializer.loads(request.metadata)
         session_id = metadata["session_id"]
         self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}")
-
-        self._session_queues[session_id].put(request)
+        self._put_into_session_queue(session_id, request)
         return runtime_pb2.ExpertResponse()
 
     async def _push_outputs(
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index d061d0a..72db9ce 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -528,23 +528,21 @@ class ModuleContainer(threading.Thread):
         self.dht, self.module_backends = dht, module_backends
         self.server_info, self.update_period, self.expiration = server_info, update_period, expiration
 
-        self.push_manager = mp.Manager()
-        self.push_manager.__enter__()
-        session_queues = self.push_manager.dict()
+        handler_event_queues = [mp.Queue() for _ in range(num_handlers)]
         self.conn_handlers = [
            TransformerConnectionHandler(
                 dht,
                 self.module_backends,
                 adapters=server_info.adapters,
                 dht_prefix=dht_prefix,
-                push_manager=self.push_manager,
-                session_queues=session_queues,
+                handler_event_queues=handler_event_queues,
+                handler_index=i,
                 inference_max_length=inference_max_length,
                 request_timeout=request_timeout,
                 session_timeout=session_timeout,
                 step_timeout=step_timeout,
             )
-            for _ in range(num_handlers)
+            for i in range(num_handlers)
         ]
 
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
@@ -607,7 +605,6 @@ class ModuleContainer(threading.Thread):
             logger.debug("Shutting down connection handlers")
             for handler in self.conn_handlers:
                 handler.shutdown()
-            self.push_manager.__exit__(None, None, None)
 
             logger.debug(f"Shutting down pools")
             for pool in self.runtime.pools:
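
For reference, below is a minimal, runnable sketch of the routing scheme this patch introduces: each handler owns one mp.Queue, a session belongs to exactly one handler, and pushes for a foreign session are forwarded to the owner's queue, whose listener relays them into a local asyncio.Queue. The names here (Handler, open_session, route_push, listen) are illustrative only and are not part of the petals API, and both "handlers" run in a single process for brevity, whereas the real server runs each TransformerConnectionHandler in its own process and shares only the mp.Queue objects between them.

import asyncio
import multiprocessing as mp
from enum import Enum


class Event(Enum):
    NEW_SESSION = 0
    END_SESSION = 1
    PUSH = 2
    SHUTDOWN = 3


class Handler:
    """One "handler": owns one mp.Queue and relays its events into local asyncio queues."""

    def __init__(self, handler_event_queues, handler_index):
        self.queues = handler_event_queues        # one mp.Queue per handler, visible to all handlers
        self.index = handler_index
        self.own_queue = handler_event_queues[handler_index]
        self.session_queues = {}                  # session_id -> asyncio.Queue, local to this handler
        self.session_handlers = {}                # session_id -> index of the handler that owns it

    def open_session(self, session_id):
        # Claim the session locally and announce ownership to every other handler
        self.session_queues[session_id] = asyncio.Queue()
        self.session_handlers[session_id] = self.index
        for i, queue in enumerate(self.queues):
            if i != self.index:
                queue.put_nowait((Event.NEW_SESSION, session_id, self.index))

    def route_push(self, session_id, payload):
        # Deliver locally if we own the session, otherwise forward to the owner's mp.Queue
        owner = self.session_handlers.get(session_id)
        if owner == self.index:
            self.session_queues[session_id].put_nowait(payload)
        elif owner is not None:
            self.queues[owner].put_nowait((Event.PUSH, session_id, payload))

    async def listen(self):
        loop = asyncio.get_running_loop()
        while True:
            # mp.Queue.get() blocks, so run it in a thread to keep the event loop responsive
            event, session_id, payload = await loop.run_in_executor(None, self.own_queue.get)
            if event == Event.SHUTDOWN:
                break
            elif event == Event.NEW_SESSION:
                self.session_handlers[session_id] = payload
            elif event == Event.END_SESSION:
                self.session_handlers.pop(session_id, None)
            elif event == Event.PUSH and session_id in self.session_queues:
                self.session_queues[session_id].put_nowait(payload)


async def main():
    queues = [mp.Queue() for _ in range(2)]       # the real server creates one per handler process
    first, second = Handler(queues, 0), Handler(queues, 1)
    listeners = [asyncio.create_task(first.listen()), asyncio.create_task(second.listen())]

    first.open_session("session-1")
    await asyncio.sleep(0.1)                      # give `second` time to learn who owns "session-1"
    second.route_push("session-1", "activation tensors")
    print(await first.session_queues["session-1"].get())  # -> "activation tensors"

    for queue in queues:
        queue.put_nowait((Event.SHUTDOWN, None, None))
    await asyncio.gather(*listeners)


if __name__ == "__main__":
    asyncio.run(main())

Unlike this toy example, the real handler does not rely on timing: a push for a session it does not know about is simply logged and dropped (see _put_into_session_queue above).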