Fix handler memory leak, get rid of mp.Manager (#373)

This PR removes a memory leak in handler.py related to mp.SyncManager by getting rid of the manager entirely. Instead of a manager process with proxied queues, each connection handler now owns a plain mp.Queue for cross-handler events, and inference sessions use lightweight asyncio.Queues inside the handler that opened them; rpc_push requests addressed to a session owned by another handler are forwarded through that handler's event queue (a small illustrative sketch of the pattern follows below).
pull/378/head
justheuristic 11 months ago committed by GitHub
parent 895327a0ae
commit 5a8de2f1f8

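For readers skimming the diff, here is a minimal, self-contained sketch of the pattern this commit switches to: every connection handler owns one plain mp.Queue, each inference session gets an in-process asyncio.Queue inside the handler that opened it, and pushes addressed to a session owned by another handler are forwarded through that handler's queue. This is an illustration, not Petals code: the ToyHandler class and its method names are invented for this sketch; only the Event enum and the queue/ownership layout mirror the actual diff.

# Minimal sketch (assumed names: ToyHandler, open_session, push, listen) of the
# per-handler event-queue pattern that replaces mp.Manager in this commit.
import asyncio
import multiprocessing as mp
from enum import Enum
from typing import Dict, Sequence


class Event(Enum):
    NEW_SESSION = 0
    END_SESSION = 1
    PUSH = 2
    SHUTDOWN = 3


class ToyHandler:
    """Each handler process owns one mp.Queue; peers forward pushes through it."""

    def __init__(self, handler_event_queues: Sequence[mp.Queue], handler_index: int):
        self._handler_event_queues = handler_event_queues
        self._handler_index = handler_index
        self._own_event_queue = handler_event_queues[handler_index]
        self._session_queues: Dict[str, asyncio.Queue] = {}  # sessions owned by this handler
        self._session_handlers: Dict[str, int] = {}  # session_id -> index of the owning handler

    def open_session(self, session_id: str) -> None:
        # The owner keeps a lightweight asyncio.Queue and tells every peer who owns the session
        self._session_queues[session_id] = asyncio.Queue()
        self._session_handlers[session_id] = self._handler_index
        for index, queue in enumerate(self._handler_event_queues):
            if index != self._handler_index:
                queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index))

    def push(self, session_id: str, payload) -> None:
        owner = self._session_handlers.get(session_id)
        if owner == self._handler_index:
            self._session_queues[session_id].put_nowait(payload)  # local, stays in-process
        elif owner is not None:
            self._handler_event_queues[owner].put_nowait((Event.PUSH, session_id, payload))

    def stop(self) -> None:
        self._own_event_queue.put_nowait((Event.SHUTDOWN, None, None))

    async def listen(self) -> None:
        loop = asyncio.get_event_loop()
        while True:
            # mp.Queue.get() blocks, so poll it from the default thread pool
            event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get)
            if event == Event.SHUTDOWN:
                break
            elif event == Event.NEW_SESSION:
                self._session_handlers[session_id] = payload
            elif event == Event.END_SESSION:
                self._session_handlers.pop(session_id, None)
            elif event == Event.PUSH:
                queue = self._session_queues.get(session_id)
                if queue is not None:
                    queue.put_nowait(payload)

The point of the change is that no SyncManager process or proxy objects are involved anymore; handlers only exchange small tuples over plain multiprocessing queues.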
@@ -2,9 +2,9 @@ from __future__ import annotations
 import asyncio
 import contextlib
-import multiprocessing.managers
+import multiprocessing as mp
 import sys
-from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
 from itertools import chain
 from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
@@ -42,20 +42,15 @@ logger = get_logger(__name__)
 # Fix pickling protobufs, see https://stackoverflow.com/a/74873028
 sys.modules["runtime_pb2"] = runtime_pb2
 
-# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256
-_OriginalAutoProxy = multiprocessing.managers.AutoProxy
-
-
-def patched_autoproxy(*args, manager_owned=True, **kwargs):
-    # Calling original AutoProxy without the unwanted key argument
-    return _OriginalAutoProxy(*args, **kwargs)
-
-
-multiprocessing.managers.AutoProxy = patched_autoproxy
-
-
-CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
+CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
+
+
+class Event(Enum):
+    NEW_SESSION = 0
+    END_SESSION = 1
+    PUSH = 2
+    SHUTDOWN = 3
 
 
 class TransformerConnectionHandler(ConnectionHandler):
@@ -70,8 +65,8 @@ class TransformerConnectionHandler(ConnectionHandler):
         *,
         adapters: Optional[Sequence[str]],
         dht_prefix: str,
-        push_manager: multiprocessing.managers.SyncManager,
-        session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
+        handler_event_queues: Sequence[mp.Queue],
+        handler_index: int,
         inference_max_length: int,
         request_timeout: float,
         session_timeout: float,
@@ -83,18 +78,28 @@ class TransformerConnectionHandler(ConnectionHandler):
             assert isinstance(module_backend, TransformerBackend)
         self.dht_prefix = dht_prefix
         self.adapters = adapters
-        self._push_manager = push_manager
-        self._session_queues = session_queues
-        self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
+        self._handler_event_queues = handler_event_queues
+        self._handler_index = handler_index
+        self._own_event_queue = handler_event_queues[handler_index]
+        self._listener_task: Optional[asyncio.Task] = None
+        self._session_queues: Dict[str, asyncio.Queue] = {}
+        self._session_handlers: Dict[str, int] = {}
         self.inference_max_length = inference_max_length
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
         self._prioritizer = task_prioritizer
 
+    async def add_p2p_handlers(self, *args, **kwargs) -> None:
+        if self._listener_task is None:
+            # Start listening to our own event queue before we accept any requests
+            self._listener_task = asyncio.create_task(self._listen_to_event_queue())
+        await super().add_p2p_handlers(*args, **kwargs)
+
     def shutdown(self):
         if self.is_alive():
             self._outer_pipe.send("_shutdown")
+            self._own_event_queue.put((Event.SHUTDOWN, None, None))
             self.join(self.shutdown_timeout)
             if self.is_alive():
                 logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
@@ -129,7 +134,6 @@ class TransformerConnectionHandler(ConnectionHandler):
         context: P2PContext,
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         """Compute a single step of inference using attention cache; update attention cache accordingly."""
-
         async with timeout(self.session_timeout):
             try:
                 request = await asyncio.wait_for(anext(requests), self.step_timeout)
@@ -146,7 +150,6 @@ class TransformerConnectionHandler(ConnectionHandler):
                 active_adapter = self._get_active_adapter(metadata)
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
-
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
                 assert isinstance(
@@ -235,6 +238,56 @@ class TransformerConnectionHandler(ConnectionHandler):
             finally:
                 self._log_request("rpc_inference.close", requested_uids, context)
 
+    @contextlib.contextmanager
+    def _managed_session(self, session_id: str):
+        assert session_id not in self._session_queues, f"session id {session_id} is not unique"
+        try:
+            self._session_queues[session_id] = asyncio.Queue()
+            self._session_handlers[session_id] = self._handler_index
+            for other_index, other_queue in enumerate(self._handler_event_queues):
+                if other_index != self._handler_index:
+                    other_queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index))
+            yield
+        finally:
+            self._session_queues.pop(session_id).put_nowait(None)  # put None so that the get task will not hang
+            del self._session_handlers[session_id]
+            for other_index, other_queue in enumerate(self._handler_event_queues):
+                if other_index != self._handler_index:
+                    other_queue.put_nowait((Event.END_SESSION, session_id, self._handler_index))
+
+    def _put_into_session_queue(self, session_id: str, request: runtime_pb2.ExpertRequest):
+        handler_index = self._session_handlers.get(session_id)
+        if handler_index is None:
+            logger.debug(f"Ignored rpc_push to unknown session ID: {session_id}")
+        elif handler_index == self._handler_index:
+            self._session_queues[session_id].put_nowait(request)
+        else:
+            self._handler_event_queues[handler_index].put_nowait((Event.PUSH, session_id, request))
+
+    async def _get_from_session_queue(self, session_id: str) -> Optional[runtime_pb2.ExpertRequest]:
+        assert self._session_handlers[session_id] == self._handler_index, "session belongs to another handler"
+        return await self._session_queues[session_id].get()
+
+    async def _listen_to_event_queue(self):
+        loop = asyncio.get_event_loop()
+        while True:
+            try:
+                event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get)
+                if event == Event.SHUTDOWN:
+                    break
+                elif event == Event.NEW_SESSION:
+                    self._session_handlers[session_id] = payload  # index of the handler that owns that session
+                elif event == Event.END_SESSION:
+                    self._session_handlers.pop(session_id, None)
+                elif event == Event.PUSH:
+                    maybe_session_queue = self._session_queues.get(session_id)
+                    if maybe_session_queue is not None:
+                        maybe_session_queue.put_nowait(payload)
+                else:
+                    raise RuntimeError(f"Unexpected event: {event}")
+            except Exception as e:
+                logger.exception(e)
+
     async def _iterate_inference_steps(
         self,
         first_request: runtime_pb2.ExpertRequest,
@@ -243,67 +296,60 @@ class TransformerConnectionHandler(ConnectionHandler):
         requested_uids: Sequence[str],
         context: P2PContext,
     ) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]:
-        loop = asyncio.get_event_loop()
-
-        if session_id is not None:
-            push_queue = self._push_manager.Queue()
-            self._session_queues[session_id] = push_queue
-
         processed_step_ids = set()
         n_pushes = n_late_pushes = 0
         request = first_request
         anext_task = get_push_task = None
         try:
-            while request.tensors:  # iterate while user is willing to supply tensors
-                metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-                step_id = metadata.get("step_id")
-
-                pushed = metadata.get("pushed")
-                if pushed:
-                    n_pushes += 1
-
-                if step_id is None or step_id not in processed_step_ids:
-                    yield request, metadata
-                    if step_id is not None:
-                        processed_step_ids.add(step_id)
-                elif pushed:
-                    n_late_pushes += 1
-                    self._log_request(
-                        "rpc_inference.push",
-                        requested_uids,
-                        context,
-                        warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
-                    )
-
-                # Wait for the next request, coming either from the `requests` iterator or `push_queue`
-                if anext_task is None:
-                    anext_task = asyncio.create_task(anext(requests))
-                if get_push_task is None:
-                    if session_id is not None:
-                        get_push_task = loop.run_in_executor(self._executor, push_queue.get)
-                    else:
-                        get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task
-                done, _ = await asyncio.wait(
-                    [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
-                )
-
-                if anext_task in done:
-                    request = await anext_task
-                    anext_task = None
-                elif get_push_task in done:
-                    request = await get_push_task
-                    get_push_task = None
-                else:
-                    self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
-                    anext_task.cancel()
-                    get_push_task.cancel()
-                    return
+            with self._managed_session(session_id) if session_id is not None else contextlib.nullcontext():
+                while request.tensors:  # iterate while user is willing to supply tensors
+                    metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+                    step_id = metadata.get("step_id")
+
+                    pushed = metadata.get("pushed")
+                    if pushed:
+                        n_pushes += 1
+                        self._log_request("rpc_inference.push", requested_uids, context, debug=f"session received push")
+
+                    if step_id is None or step_id not in processed_step_ids:
+                        yield request, metadata
+                        if step_id is not None:
+                            processed_step_ids.add(step_id)
+                    elif pushed:
+                        n_late_pushes += 1
+                        self._log_request(
+                            "rpc_inference.push",
+                            requested_uids,
+                            context,
+                            warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
+                        )
+
+                    # Wait for the next request, coming either from the `requests` iterator or `push_queue`
+                    if anext_task is None:
+                        anext_task = asyncio.create_task(anext(requests))
+                    if get_push_task is None:
+                        if session_id is not None:
+                            get_push_task = asyncio.create_task(self._get_from_session_queue(session_id))
+                        else:
+                            get_push_task = asyncio.create_task(asyncio.Event().wait())  # Dummy never-ending task
+                    done, _ = await asyncio.wait(
+                        [anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
+                    )
+
+                    if anext_task in done:
+                        request = await anext_task
+                        anext_task = None
+                    elif get_push_task in done:
+                        request = await get_push_task
+                        get_push_task = None
+                    else:
+                        self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
+                        anext_task.cancel()
+                        get_push_task.cancel()
+                        return
         except:
             logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
             raise
-        finally:
-            if session_id is not None:
-                push_queue.put(None)  # Stop thread for get_push_task
-                del self._session_queues[session_id]
 
     async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
         """Directly push activation tensors from one server to another"""
@@ -312,8 +358,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         metadata = MSGPackSerializer.loads(request.metadata)
         session_id = metadata["session_id"]
         self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}")
-
-        self._session_queues[session_id].put(request)
+        self._put_into_session_queue(session_id, request)
         return runtime_pb2.ExpertResponse()
 
     async def _push_outputs(

@@ -528,23 +528,21 @@
         self.dht, self.module_backends = dht, module_backends
         self.server_info, self.update_period, self.expiration = server_info, update_period, expiration
 
-        self.push_manager = mp.Manager()
-        self.push_manager.__enter__()
-        session_queues = self.push_manager.dict()
+        handler_event_queues = [mp.Queue() for _ in range(num_handlers)]
         self.conn_handlers = [
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
                 adapters=server_info.adapters,
                 dht_prefix=dht_prefix,
-                push_manager=self.push_manager,
-                session_queues=session_queues,
+                handler_event_queues=handler_event_queues,
+                handler_index=i,
                 inference_max_length=inference_max_length,
                 request_timeout=request_timeout,
                 session_timeout=session_timeout,
                 step_timeout=step_timeout,
             )
-            for _ in range(num_handlers)
+            for i in range(num_handlers)
         ]
 
         self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
@@ -607,7 +605,6 @@
         logger.debug("Shutting down connection handlers")
         for handler in self.conn_handlers:
             handler.shutdown()
-        self.push_manager.__exit__(None, None, None)
 
         logger.debug(f"Shutting down pools")
         for pool in self.runtime.pools:
