|
|
|
@@ -2,6 +2,7 @@ from concurrent.futures import Future
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import List, Optional, Union, Sequence
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind.moe.client import RemoteExpert
|
|
|
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
|
|
from hivemind.moe.expert_uid import ExpertUID
|
|
|
|
@@ -10,18 +11,30 @@ from hivemind.p2p import StubBase, P2P
|
|
|
|
|
from hivemind.proto.runtime_pb2 import ExpertInfo
|
|
|
|
|
from hivemind.dht import DHT
|
|
|
|
|
from hivemind.utils import MPFuture, DHTExpiration
|
|
|
|
|
|
|
|
|
|
from src import DistributedBloomConfig
|
|
|
|
|
from src.server.backend import MAX_LENGTH
|
|
|
|
|
from src.server.handler import TransformerConnectionHandler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RemoteTransformerBlockSession(RemoteExpert):
    """A class that interacts with a specific remote server for forward/backward or inference.

    NOTE(review): this region of the file contained two adjacent class headers
    (`RemoteTransformerBlock` and `RemoteTransformerBlockSession`) left over from a
    mangled diff; the post-image (`...Session`) variant is kept here — confirm the
    intended rename against version control.
    """

    def __init__(self, config: DistributedBloomConfig, info: ExpertInfo, p2p: P2P):
        """
        :param config: model configuration; `hidden_size` and `dtype` size the input cache
        :param info: identity of the remote expert this session communicates with
        :param p2p: an active P2P instance used to reach the remote server
        """
        super().__init__(info, p2p)
        self._config = config
        # Pre-allocated input buffer, capped at the server-side MAX_LENGTH
        # (imported from src.server.backend).
        self._inputs_cache = torch.empty(1, MAX_LENGTH, config.hidden_size, dtype=config.dtype)
        # The currently open inference stream, if any; None when idle.
        self._active_stream: Optional[RemoteTransformerStream] = None

    @property
    def stub(self) -> StubBase:
        # Use the transformer-specific connection handler rather than the
        # generic expert stub that RemoteExpert would otherwise provide.
        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_remote_module(
|
|
|
|
|
dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
|
|
|
|
|
) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
|
|
|
|
|
) -> Union[List[Optional[RemoteTransformerBlockSession]], MPFuture[List[Optional[RemoteTransformerBlockSession]]]]:
|
|
|
|
|
"""
|
|
|
|
|
:param uids: find experts with these ids from across the DHT
|
|
|
|
|
:param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
|
|
|
|
@@ -35,7 +48,7 @@ def get_remote_module(
|
|
|
|
|
|
|
|
|
|
def create_remote_module(
|
|
|
|
|
infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
|
|
|
|
|
) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
|
|
|
|
|
) -> Union[List[Optional[RemoteTransformerBlockSession]], Future]:
|
|
|
|
|
if return_future:
|
|
|
|
|
|
|
|
|
|
async def _unpack(infos_future: MPFuture, dht: DHT):
|
|
|
|
@@ -48,10 +61,10 @@ def create_remote_module(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
|
|
|
|
|
experts: List[Optional[RemoteTransformerBlock]] = []
|
|
|
|
|
experts: List[Optional[RemoteTransformerBlockSession]] = []
|
|
|
|
|
for info in infos:
|
|
|
|
|
if info is not None:
|
|
|
|
|
experts.append(RemoteTransformerBlock(info, p2p))
|
|
|
|
|
experts.append(RemoteTransformerBlockSession(info, p2p))
|
|
|
|
|
else:
|
|
|
|
|
experts.append(None)
|
|
|
|
|
return experts
|
|
|
|
|