|
|
@ -2,9 +2,9 @@ from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
import asyncio
|
|
|
|
import contextlib
|
|
|
|
import contextlib
|
|
|
|
import multiprocessing.managers
|
|
|
|
import multiprocessing as mp
|
|
|
|
import sys
|
|
|
|
import sys
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
from enum import Enum
|
|
|
|
from itertools import chain
|
|
|
|
from itertools import chain
|
|
|
|
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
|
|
|
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
|
|
|
|
|
|
|
|
|
|
@ -42,20 +42,15 @@ logger = get_logger(__name__)
|
|
|
|
# Fix pickling protobufs, see https://stackoverflow.com/a/74873028
|
|
|
|
# Fix pickling protobufs, see https://stackoverflow.com/a/74873028
|
|
|
|
sys.modules["runtime_pb2"] = runtime_pb2
|
|
|
|
sys.modules["runtime_pb2"] = runtime_pb2
|
|
|
|
|
|
|
|
|
|
|
|
# Fix queues in multiprocessing.Manager in Python < 3.9.7, see https://bugs.python.org/issue30256
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_OriginalAutoProxy = multiprocessing.managers.AutoProxy
|
|
|
|
CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def patched_autoproxy(*args, manager_owned=True, **kwargs):
|
|
|
|
|
|
|
|
# Calling original AutoProxy without the unwanted key argument
|
|
|
|
|
|
|
|
return _OriginalAutoProxy(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
multiprocessing.managers.AutoProxy = patched_autoproxy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CACHE_TOKENS_AVAILABLE = "cache_tokens_available"
|
|
|
|
class Event(Enum):
|
|
|
|
|
|
|
|
NEW_SESSION = 0
|
|
|
|
|
|
|
|
END_SESSION = 1
|
|
|
|
|
|
|
|
PUSH = 2
|
|
|
|
|
|
|
|
SHUTDOWN = 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
class TransformerConnectionHandler(ConnectionHandler):
|
|
|
@ -70,8 +65,8 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
*,
|
|
|
|
*,
|
|
|
|
adapters: Optional[Sequence[str]],
|
|
|
|
adapters: Optional[Sequence[str]],
|
|
|
|
dht_prefix: str,
|
|
|
|
dht_prefix: str,
|
|
|
|
push_manager: multiprocessing.managers.SyncManager,
|
|
|
|
handler_event_queues: Sequence[mp.Queue],
|
|
|
|
session_queues: Dict[str, multiprocessing.managers.BaseProxy], # BaseProxy for queue.Queue
|
|
|
|
handler_index: int,
|
|
|
|
inference_max_length: int,
|
|
|
|
inference_max_length: int,
|
|
|
|
request_timeout: float,
|
|
|
|
request_timeout: float,
|
|
|
|
session_timeout: float,
|
|
|
|
session_timeout: float,
|
|
|
@ -83,18 +78,28 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
assert isinstance(module_backend, TransformerBackend)
|
|
|
|
assert isinstance(module_backend, TransformerBackend)
|
|
|
|
self.dht_prefix = dht_prefix
|
|
|
|
self.dht_prefix = dht_prefix
|
|
|
|
self.adapters = adapters
|
|
|
|
self.adapters = adapters
|
|
|
|
self._push_manager = push_manager
|
|
|
|
self._handler_event_queues = handler_event_queues
|
|
|
|
self._session_queues = session_queues
|
|
|
|
self._handler_index = handler_index
|
|
|
|
self._executor = ThreadPoolExecutor(max_workers=float("inf")) # For waiting on self.session_queues
|
|
|
|
self._own_event_queue = handler_event_queues[handler_index]
|
|
|
|
|
|
|
|
self._listener_task: Optional[asyncio.Task] = None
|
|
|
|
|
|
|
|
self._session_queues: Dict[str, asyncio.Queue] = {}
|
|
|
|
|
|
|
|
self._session_handlers: Dict[str, int] = {}
|
|
|
|
|
|
|
|
|
|
|
|
self.inference_max_length = inference_max_length
|
|
|
|
self.inference_max_length = inference_max_length
|
|
|
|
self.request_timeout = request_timeout
|
|
|
|
self.request_timeout = request_timeout
|
|
|
|
self.session_timeout, self.step_timeout = session_timeout, step_timeout
|
|
|
|
self.session_timeout, self.step_timeout = session_timeout, step_timeout
|
|
|
|
self._prioritizer = task_prioritizer
|
|
|
|
self._prioritizer = task_prioritizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def add_p2p_handlers(self, *args, **kwargs) -> None:
|
|
|
|
|
|
|
|
if self._listener_task is None:
|
|
|
|
|
|
|
|
# Start listening to our own event queue before we accept any requests
|
|
|
|
|
|
|
|
self._listener_task = asyncio.create_task(self._listen_to_event_queue())
|
|
|
|
|
|
|
|
await super().add_p2p_handlers(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
def shutdown(self):
|
|
|
|
def shutdown(self):
|
|
|
|
if self.is_alive():
|
|
|
|
if self.is_alive():
|
|
|
|
self._outer_pipe.send("_shutdown")
|
|
|
|
self._outer_pipe.send("_shutdown")
|
|
|
|
|
|
|
|
self._own_event_queue.put((Event.SHUTDOWN, None, None))
|
|
|
|
self.join(self.shutdown_timeout)
|
|
|
|
self.join(self.shutdown_timeout)
|
|
|
|
if self.is_alive():
|
|
|
|
if self.is_alive():
|
|
|
|
logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
|
|
|
|
logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
|
|
|
@ -129,7 +134,6 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
context: P2PContext,
|
|
|
|
context: P2PContext,
|
|
|
|
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
|
|
|
|
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
|
|
|
|
"""Compute a single step of inference using attention cache; update attention cache accordingly."""
|
|
|
|
"""Compute a single step of inference using attention cache; update attention cache accordingly."""
|
|
|
|
|
|
|
|
|
|
|
|
async with timeout(self.session_timeout):
|
|
|
|
async with timeout(self.session_timeout):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
request = await asyncio.wait_for(anext(requests), self.step_timeout)
|
|
|
|
request = await asyncio.wait_for(anext(requests), self.step_timeout)
|
|
|
@ -146,7 +150,6 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
active_adapter = self._get_active_adapter(metadata)
|
|
|
|
active_adapter = self._get_active_adapter(metadata)
|
|
|
|
points = metadata.get("points", 0)
|
|
|
|
points = metadata.get("points", 0)
|
|
|
|
session_id = metadata.get("session_id")
|
|
|
|
session_id = metadata.get("session_id")
|
|
|
|
|
|
|
|
|
|
|
|
if not requested_uids:
|
|
|
|
if not requested_uids:
|
|
|
|
raise ValueError("User must specify at least one block for inference, but got none")
|
|
|
|
raise ValueError("User must specify at least one block for inference, but got none")
|
|
|
|
assert isinstance(
|
|
|
|
assert isinstance(
|
|
|
@ -235,6 +238,56 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
finally:
|
|
|
|
finally:
|
|
|
|
self._log_request("rpc_inference.close", requested_uids, context)
|
|
|
|
self._log_request("rpc_inference.close", requested_uids, context)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
|
|
|
|
def _managed_session(self, session_id: str):
|
|
|
|
|
|
|
|
assert session_id not in self._session_queues, f"session id {session_id} is not unique"
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
self._session_queues[session_id] = asyncio.Queue()
|
|
|
|
|
|
|
|
self._session_handlers[session_id] = self._handler_index
|
|
|
|
|
|
|
|
for other_index, other_queue in enumerate(self._handler_event_queues):
|
|
|
|
|
|
|
|
if other_index != self._handler_index:
|
|
|
|
|
|
|
|
other_queue.put_nowait((Event.NEW_SESSION, session_id, self._handler_index))
|
|
|
|
|
|
|
|
yield
|
|
|
|
|
|
|
|
finally:
|
|
|
|
|
|
|
|
self._session_queues.pop(session_id).put_nowait(None) # put None so that the get task will not hang
|
|
|
|
|
|
|
|
del self._session_handlers[session_id]
|
|
|
|
|
|
|
|
for other_index, other_queue in enumerate(self._handler_event_queues):
|
|
|
|
|
|
|
|
if other_index != self._handler_index:
|
|
|
|
|
|
|
|
other_queue.put_nowait((Event.END_SESSION, session_id, self._handler_index))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _put_into_session_queue(self, session_id: str, request: runtime_pb2.ExpertRequest):
|
|
|
|
|
|
|
|
handler_index = self._session_handlers.get(session_id)
|
|
|
|
|
|
|
|
if handler_index is None:
|
|
|
|
|
|
|
|
logger.debug(f"Ignored rpc_push to unknown session ID: {session_id}")
|
|
|
|
|
|
|
|
elif handler_index == self._handler_index:
|
|
|
|
|
|
|
|
self._session_queues[session_id].put_nowait(request)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self._handler_event_queues[handler_index].put_nowait((Event.PUSH, session_id, request))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _get_from_session_queue(self, session_id: str) -> Optional[runtime_pb2.ExpertRequest]:
|
|
|
|
|
|
|
|
assert self._session_handlers[session_id] == self._handler_index, "session belongs to another handler"
|
|
|
|
|
|
|
|
return await self._session_queues[session_id].get()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _listen_to_event_queue(self):
|
|
|
|
|
|
|
|
loop = asyncio.get_event_loop()
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
event, session_id, payload = await loop.run_in_executor(None, self._own_event_queue.get)
|
|
|
|
|
|
|
|
if event == Event.SHUTDOWN:
|
|
|
|
|
|
|
|
break
|
|
|
|
|
|
|
|
elif event == Event.NEW_SESSION:
|
|
|
|
|
|
|
|
self._session_handlers[session_id] = payload # index of the handler that owns that session
|
|
|
|
|
|
|
|
elif event == Event.END_SESSION:
|
|
|
|
|
|
|
|
self._session_handlers.pop(session_id, None)
|
|
|
|
|
|
|
|
elif event == Event.PUSH:
|
|
|
|
|
|
|
|
maybe_session_queue = self._session_queues.get(session_id)
|
|
|
|
|
|
|
|
if maybe_session_queue is not None:
|
|
|
|
|
|
|
|
maybe_session_queue.put_nowait(payload)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise RuntimeError(f"Unexpected event: {event}")
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
|
|
logger.exception(e)
|
|
|
|
|
|
|
|
|
|
|
|
async def _iterate_inference_steps(
|
|
|
|
async def _iterate_inference_steps(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
first_request: runtime_pb2.ExpertRequest,
|
|
|
|
first_request: runtime_pb2.ExpertRequest,
|
|
|
@ -243,67 +296,60 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
requested_uids: Sequence[str],
|
|
|
|
requested_uids: Sequence[str],
|
|
|
|
context: P2PContext,
|
|
|
|
context: P2PContext,
|
|
|
|
) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]:
|
|
|
|
) -> AsyncIterator[Tuple[runtime_pb2.ExpertRequest, dict]]:
|
|
|
|
loop = asyncio.get_event_loop()
|
|
|
|
|
|
|
|
if session_id is not None:
|
|
|
|
|
|
|
|
push_queue = self._push_manager.Queue()
|
|
|
|
|
|
|
|
self._session_queues[session_id] = push_queue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
processed_step_ids = set()
|
|
|
|
processed_step_ids = set()
|
|
|
|
n_pushes = n_late_pushes = 0
|
|
|
|
n_pushes = n_late_pushes = 0
|
|
|
|
request = first_request
|
|
|
|
request = first_request
|
|
|
|
anext_task = get_push_task = None
|
|
|
|
anext_task = get_push_task = None
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
while request.tensors: # iterate while user is willing to supply tensors
|
|
|
|
with self._managed_session(session_id) if session_id is not None else contextlib.nullcontext():
|
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
|
while request.tensors: # iterate while user is willing to supply tensors
|
|
|
|
step_id = metadata.get("step_id")
|
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
|
|
|
|
|
step_id = metadata.get("step_id")
|
|
|
|
pushed = metadata.get("pushed")
|
|
|
|
|
|
|
|
if pushed:
|
|
|
|
pushed = metadata.get("pushed")
|
|
|
|
n_pushes += 1
|
|
|
|
if pushed:
|
|
|
|
|
|
|
|
n_pushes += 1
|
|
|
|
if step_id is None or step_id not in processed_step_ids:
|
|
|
|
self._log_request("rpc_inference.push", requested_uids, context, debug=f"session received push")
|
|
|
|
yield request, metadata
|
|
|
|
|
|
|
|
if step_id is not None:
|
|
|
|
if step_id is None or step_id not in processed_step_ids:
|
|
|
|
processed_step_ids.add(step_id)
|
|
|
|
yield request, metadata
|
|
|
|
elif pushed:
|
|
|
|
if step_id is not None:
|
|
|
|
n_late_pushes += 1
|
|
|
|
processed_step_ids.add(step_id)
|
|
|
|
self._log_request(
|
|
|
|
elif pushed:
|
|
|
|
"rpc_inference.push",
|
|
|
|
n_late_pushes += 1
|
|
|
|
requested_uids,
|
|
|
|
self._log_request(
|
|
|
|
context,
|
|
|
|
"rpc_inference.push",
|
|
|
|
warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
|
|
|
|
requested_uids,
|
|
|
|
|
|
|
|
context,
|
|
|
|
|
|
|
|
warning=f"arrived late {n_late_pushes / n_pushes * 100:.1f}% of the time",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Wait for the next request, coming either from the `requests` iterator or `push_queue`
|
|
|
|
|
|
|
|
if anext_task is None:
|
|
|
|
|
|
|
|
anext_task = asyncio.create_task(anext(requests))
|
|
|
|
|
|
|
|
if get_push_task is None:
|
|
|
|
|
|
|
|
if session_id is not None:
|
|
|
|
|
|
|
|
get_push_task = asyncio.create_task(self._get_from_session_queue(session_id))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
get_push_task = asyncio.create_task(asyncio.Event().wait()) # Dummy never-ending task
|
|
|
|
|
|
|
|
done, _ = await asyncio.wait(
|
|
|
|
|
|
|
|
[anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
# Wait for the next request, coming either from the `requests` iterator or `push_queue`
|
|
|
|
if anext_task in done:
|
|
|
|
if anext_task is None:
|
|
|
|
request = await anext_task
|
|
|
|
anext_task = asyncio.create_task(anext(requests))
|
|
|
|
anext_task = None
|
|
|
|
if get_push_task is None:
|
|
|
|
elif get_push_task in done:
|
|
|
|
if session_id is not None:
|
|
|
|
request = await get_push_task
|
|
|
|
get_push_task = loop.run_in_executor(self._executor, push_queue.get)
|
|
|
|
get_push_task = None
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
get_push_task = asyncio.create_task(asyncio.Event().wait()) # Dummy never-ending task
|
|
|
|
self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
|
|
|
|
done, _ = await asyncio.wait(
|
|
|
|
anext_task.cancel()
|
|
|
|
[anext_task, get_push_task], timeout=self.step_timeout, return_when=asyncio.FIRST_COMPLETED
|
|
|
|
get_push_task.cancel()
|
|
|
|
)
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
if anext_task in done:
|
|
|
|
|
|
|
|
request = await anext_task
|
|
|
|
|
|
|
|
anext_task = None
|
|
|
|
|
|
|
|
elif get_push_task in done:
|
|
|
|
|
|
|
|
request = await get_push_task
|
|
|
|
|
|
|
|
get_push_task = None
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self._log_request("rpc_inference.step", requested_uids, context, warning="timed out")
|
|
|
|
|
|
|
|
anext_task.cancel()
|
|
|
|
|
|
|
|
get_push_task.cancel()
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
except:
|
|
|
|
except:
|
|
|
|
logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
|
|
|
|
logger.warning("rpc_inference._iterate_inference_steps() exception:", exc_info=True)
|
|
|
|
raise
|
|
|
|
raise
|
|
|
|
finally:
|
|
|
|
|
|
|
|
if session_id is not None:
|
|
|
|
|
|
|
|
push_queue.put(None) # Stop thread for get_push_task
|
|
|
|
|
|
|
|
del self._session_queues[session_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
|
|
|
async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
|
|
|
"""Directly push activation tensors from one server to another"""
|
|
|
|
"""Directly push activation tensors from one server to another"""
|
|
|
@ -312,8 +358,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
metadata = MSGPackSerializer.loads(request.metadata)
|
|
|
|
metadata = MSGPackSerializer.loads(request.metadata)
|
|
|
|
session_id = metadata["session_id"]
|
|
|
|
session_id = metadata["session_id"]
|
|
|
|
self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}")
|
|
|
|
self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}")
|
|
|
|
|
|
|
|
self._put_into_session_queue(session_id, request)
|
|
|
|
self._session_queues[session_id].put(request)
|
|
|
|
|
|
|
|
return runtime_pb2.ExpertResponse()
|
|
|
|
return runtime_pb2.ExpertResponse()
|
|
|
|
|
|
|
|
|
|
|
|
async def _push_outputs(
|
|
|
|
async def _push_outputs(
|
|
|
|