DO NOT MERGE UNDER ANY CIRCUMSTANCES

This commit is contained in:
Aleksandr Borzunov 2023-07-19 04:27:45 +03:00 committed by Your Name
parent c735dd7ba3
commit 74c086ea35

View File

@ -129,6 +129,7 @@ class TransformerConnectionHandler(ConnectionHandler):
context: P2PContext, context: P2PContext,
) -> AsyncIterator[runtime_pb2.ExpertResponse]: ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
"""Compute a single step of inference using attention cache; update attention cache accordingly.""" """Compute a single step of inference using attention cache; update attention cache accordingly."""
import os, psutil; print(f"handler rpc_inference {os.getpid()} : {psutil.Process().memory_info().rss / 1024 / 1024} mb")
async with timeout(self.session_timeout): async with timeout(self.session_timeout):
try: try:
@ -304,18 +305,21 @@ class TransformerConnectionHandler(ConnectionHandler):
if session_id is not None: if session_id is not None:
push_queue.put(None) # Stop thread for get_push_task push_queue.put(None) # Stop thread for get_push_task
del self._session_queues[session_id] del self._session_queues[session_id]
print("DELETED SESSION", session_id, flush=True)
async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
"""Directly push activation tensors from one server to another""" """Directly push activation tensors from one server to another"""
import os, psutil; print(f"handler rpc_push {os.getpid()} : {psutil.Process().memory_info().rss / 1024 / 1024} mb")
try:
requested_uids = self._check_uids(request.uid) requested_uids = self._check_uids(request.uid)
metadata = MSGPackSerializer.loads(request.metadata) metadata = MSGPackSerializer.loads(request.metadata)
session_id = metadata["session_id"] session_id = metadata["session_id"]
self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}") self._log_request("rpc_push", requested_uids, context, warning=f"session_id={session_id}")
self._session_queues[session_id].put(request) self._session_queues[session_id].put(request)
return runtime_pb2.ExpertResponse() return runtime_pb2.ExpertResponse()
except Exception as e:
logger.exception(e)
async def _push_outputs( async def _push_outputs(
self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict
) -> None: ) -> None: