|
|
|
@ -1,38 +1,45 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Union
|
|
|
|
|
import random
|
|
|
|
|
from typing import Any, AsyncIterator, Dict, Optional
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
|
|
|
|
|
from hivemind.dht import DHT, DHTNode, DHTValue
|
|
|
|
|
from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
|
|
|
|
|
from hivemind.moe.expert_uid import ExpertInfo as RemoteModuleInfo
|
|
|
|
|
from hivemind.moe.expert_uid import ExpertUID
|
|
|
|
|
from hivemind.p2p import P2P, PeerID, StubBase
|
|
|
|
|
from hivemind.moe.client.expert import RemoteExpertWorker, RemoteExpert
|
|
|
|
|
from hivemind.moe.expert_uid import ExpertInfo
|
|
|
|
|
from hivemind.p2p import P2P, StubBase
|
|
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
|
|
from hivemind.utils import DHTExpiration, MPFuture, anext, as_aiter, get_dht_time, nested_flatten
|
|
|
|
|
from hivemind.utils import anext, nested_flatten
|
|
|
|
|
|
|
|
|
|
from src.dht_utils import ModuleUID
|
|
|
|
|
from src.data_structures import RemoteModuleInfo
|
|
|
|
|
from src.server.handler import TransformerConnectionHandler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RemoteTransformerBlock(RemoteExpert):
    """Client-side handle to a remote module on one specific server (forward/backward or inference)."""

    def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
        """
        :param peers_info: module info; its ``uid`` and ``peer_ids`` attributes are read here
        :param p2p: P2P instance passed through to the RemoteExpert constructor
        """
        module_uid = peers_info.uid
        # TODO replace this: a single peer is currently drawn at random from peers_info.peer_ids
        chosen_peer = random.choice(list(peers_info.peer_ids))
        super().__init__(ExpertInfo(module_uid, chosen_peer), p2p)

    @property
    def stub(self) -> StubBase:
        """Stub returned by TransformerConnectionHandler.get_stub for this block's p2p and peer_id."""
        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)

    def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
        """Initialize a new inference session with the specified remote server"""
        # Touch the property up front so that _info is created here: the built-in property
        # will not work inside RemoteExpertWorker.
        _ = self.info
        session_coroutine = RemoteTransformerBlockInferenceSession._create(self)
        return RemoteExpertWorker.run_coroutine(session_coroutine)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RemoteTransformerBlockInferenceSession:
|
|
|
|
|
"""An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, uid: ExpertUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
|
|
|
|
|
def __init__(
|
|
|
|
|
self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator
|
|
|
|
|
):
|
|
|
|
|
self.uid, self.info = uid, info
|
|
|
|
|
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
|
|
|
|
|
# using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
|
|
|
|
@ -113,55 +120,3 @@ class RemoteTransformerBlockInferenceSession:
|
|
|
|
|
def __exit__(self, *exc_details):
    """Context-manager exit hook: close this session; the exception details are not inspected."""
    # No explicit return value, so an exception raised inside the ``with`` body is not suppressed.
    self.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_remote_module(
    dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
    """
    Look up remote modules by uid and wrap every module that was found in a RemoteTransformerBlock.

    :param dht: DHT instance; the lookup runs through dht.run_coroutine and p2p comes from dht.replicate_p2p
    :param uids: find experts with these ids from across the DHT
    :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
    :returns: a list of [RemoteTransformerBlock if found else None]
    """
    assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
    lookup = partial(_get_remote_module_infos, uids=list(uids), expiration_time=expiration_time)
    infos = dht.run_coroutine(lookup, return_future)

    if not return_future:
        # Blocking mode: infos is already the resolved list, so build the blocks right away.
        p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
        return _create_remote_modules_from_infos(infos, p2p)

    async def _build_when_ready(pending_infos: MPFuture, source_dht: DHT):
        # Background mode: obtain p2p first, then await the pending lookup result.
        replicated_p2p = await source_dht.replicate_p2p()
        return _create_remote_modules_from_infos(await pending_infos, replicated_p2p)

    return RemoteExpertWorker.run_coroutine(_build_when_ready(infos, dht), return_future)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _get_remote_module_infos(
    dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
) -> List[Optional[RemoteModuleInfo]]:
    """
    Fetch the DHT entries for ``uids`` and convert each usable entry into a RemoteModuleInfo.

    :param dht: only ``dht.num_workers`` is read, to cap the number of workers given to get_many
    :param node: its ``get_many`` coroutine performs the lookup
    :param uids: keys to look up; the output list follows the same order
    :param expiration_time: passed to get_many; when None, get_dht_time() is used instead
    :returns: one item per uid - a RemoteModuleInfo when the entry is present and its value is a str, else None
    """
    if expiration_time is None:
        expiration_time = get_dht_time()

    worker_limit = dht.num_workers
    num_workers = len(uids) if worker_limit is None else min(len(uids), worker_limit)
    found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

    module_infos: List[Optional[RemoteModuleInfo]] = []
    for uid in uids:
        entry = found[uid]
        if entry is None or not isinstance(entry.value, str):
            # Missing entry or non-string value: report this uid as not found.
            module_infos.append(None)
            continue
        # The stored string is parsed as a base58 peer id.
        module_infos.append(RemoteModuleInfo(uid, PeerID.from_base58(entry.value)))
    return module_infos
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_remote_modules_from_infos(
    infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
) -> List[Optional[RemoteTransformerBlock]]:
    """
    Turn module infos into RemoteTransformerBlock instances, position by position.

    :param infos: module infos; None entries stay None in the output
    :param p2p: P2P instance handed to every RemoteTransformerBlock that gets created
    :returns: a new list with one item per entry of ``infos``, in the same order
    """
    return [None if info is None else RemoteTransformerBlock(info, p2p) for info in infos]
|
|
|
|
|