From 74c086ea3546a9f0ed9bb4d7305ff80bf34ebf15 Mon Sep 17 00:00:00 2001 From: Aleksandr Borzunov Date: Wed, 19 Jul 2023 04:27:45 +0300 Subject: [PATCH] DO NOT MERGE UNDER ANY CIRCUMSTANCES --- src/petals/server/handler.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index d0531de..c3ed9cb 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -129,6 +129,7 @@ class TransformerConnectionHandler(ConnectionHandler): context: P2PContext, ) -> AsyncIterator[runtime_pb2.ExpertResponse]: """Compute a single step of inference using attention cache; update attention cache accordingly.""" + import os, psutil; print(f"handler rpc_inference {os.getpid()} : {psutil.Process().memory_info().rss / 1024 / 1024} mb") async with timeout(self.session_timeout): try: @@ -304,18 +305,21 @@ class TransformerConnectionHandler(ConnectionHandler): if session_id is not None: push_queue.put(None) # Stop thread for get_push_task del self._session_queues[session_id] + print("DELETED SESSION", session_id, flush=True) async def rpc_push(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse: """Directly push activation tensors from one server to another""" + import os, psutil; print(f"handler rpc_push {os.getpid()} : {psutil.Process().memory_info().rss / 1024 / 1024} mb") + try: + requested_uids = self._check_uids(request.uid) + metadata = MSGPackSerializer.loads(request.metadata) + session_id = metadata["session_id"] + self._log_request("rpc_push", requested_uids, context, warning=f"session_id={session_id}") - requested_uids = self._check_uids(request.uid) - metadata = MSGPackSerializer.loads(request.metadata) - session_id = metadata["session_id"] - self._log_request("rpc_push", requested_uids, context, debug=f"session_id={session_id}") - - self._session_queues[session_id].put(request) - return runtime_pb2.ExpertResponse() - + self._session_queues[session_id].put(request) + return runtime_pb2.ExpertResponse() + except Exception as e: + logger.exception(e) async def _push_outputs( self, request: runtime_pb2.ExpertRequest, serialized_outputs: runtime_pb2.Tensor, metadata: dict ) -> None: