|
|
@ -16,6 +16,7 @@ from hivemind.moe.server.connection_handler import ConnectionHandler
|
|
|
|
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
|
|
|
|
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
|
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
|
from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
|
|
|
|
from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
|
|
|
|
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
from hivemind.utils.streaming import split_for_streaming
|
|
|
|
from hivemind.utils.streaming import split_for_streaming
|
|
|
|
|
|
|
|
|
|
|
|
from src.data_structures import CHAIN_DELIMITER, ModuleUID
|
|
|
|
from src.data_structures import CHAIN_DELIMITER, ModuleUID
|
|
|
@ -24,6 +25,8 @@ from src.server.task_pool import PrioritizedTaskPool
|
|
|
|
from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
|
|
|
|
from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
|
|
|
|
from src.utils.misc import DUMMY, is_dummy
|
|
|
|
from src.utils.misc import DUMMY, is_dummy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
"""Handles three request types: forward, backward and forward-incremental (inference)"""
|
|
|
|
"""Handles three request types: forward, backward and forward-incremental (inference)"""
|
|
|
@ -73,7 +76,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
|
|
|
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
|
|
|
|
"""Compute a single step of inference using attention cache; update attention cache accordingly."""
|
|
|
|
"""Compute a single step of inference using attention cache; update attention cache accordingly."""
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
print("OPENED RPC_INFERENCE")
|
|
|
|
logger.debug("Opened rpc_inference()")
|
|
|
|
request = await anext(requests)
|
|
|
|
request = await anext(requests)
|
|
|
|
requested_uids = self._check_uids(request.uid)
|
|
|
|
requested_uids = self._check_uids(request.uid)
|
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
|
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
|
|
|
@ -164,7 +167,7 @@ class TransformerConnectionHandler(ConnectionHandler):
|
|
|
|
prefix_length += hidden_states.shape[1]
|
|
|
|
prefix_length += hidden_states.shape[1]
|
|
|
|
request = await (anext(requests))
|
|
|
|
request = await (anext(requests))
|
|
|
|
finally:
|
|
|
|
finally:
|
|
|
|
print("CLOSED RPC_INFERENCE")
|
|
|
|
logger.debug("Closed rpc_inference()")
|
|
|
|
|
|
|
|
|
|
|
|
async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
|
|
|
async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
|
|
|
|
# Parse request and prepare backends
|
|
|
|
# Parse request and prepare backends
|
|
|
|