black-isort

fix-auth-token
justheuristic 2 years ago
parent 0e7afea026
commit 01b9bced78

@@ -1,3 +1,3 @@
 from .bloom import *
-from .dht_utils import get_remote_module, declare_active_modules
 from .client import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
+from .dht_utils import declare_active_modules, get_remote_module

@@ -6,14 +6,14 @@ from typing import Any, AsyncIterator, Dict, Optional
 
 import torch
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.moe.client.expert import RemoteExpertWorker, RemoteExpert
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
 from hivemind.utils import anext, nested_flatten
 
-from src.dht_utils import ModuleUID
 from src.data_structures import RemoteModuleInfo
+from src.dht_utils import ModuleUID
 from src.server.handler import TransformerConnectionHandler
@@ -21,7 +21,7 @@ class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
     def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids))) #TODO replace this
+        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids)))  # TODO replace this
         super().__init__(peer_info, p2p)
 
     @property
@@ -37,9 +37,7 @@ class RemoteTransformerBlock(RemoteExpert):
 class RemoteTransformerBlockInferenceSession:
     """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
 
-    def __init__(
-        self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator
-    ):
+    def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
         self.uid, self.info = uid, info
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
@@ -119,4 +117,3 @@ class RemoteTransformerBlockInferenceSession:
 
     def __exit__(self, *exc_details):
         self.close()
-

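For orientation, here is a minimal client-side sketch of how these pieces fit together. The initial peer multiaddr, the module UID "bloom.3", and the tensor shape are placeholders, and the single-tensor forward call relies on RemoteExpert's ordinary call semantics rather than anything shown in this hunk:

import torch
from hivemind import DHT

from src.dht_utils import get_remote_module

# Join an existing swarm as a lightweight client (the multiaddr below is a placeholder).
dht = DHT(initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/QmExamplePeer"], client_mode=True, start=True)

# Resolve one transformer block by its module UID; an entry is None if the UID is not in the DHT.
(block,) = get_remote_module(dht, ["bloom.3"])
assert block is not None

# RemoteTransformerBlock inherits from RemoteExpert, so a forward pass is a regular call;
# the shape (batch=1, seq=8, hidden=4096) is only an example.
hidden_states = torch.randn(1, 8, 4096)
outputs = block(hidden_states)

RemoteTransformerBlockInferenceSession is the context-manager counterpart for multi-step inference against the same server; its close() is invoked on __exit__ as shown above.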
@@ -8,25 +8,29 @@ from typing import Dict, List, Optional, Sequence, Union
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, use_hivemind_log_handler, get_logger
 from hivemind.p2p import P2P, PeerID
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
 import src
-from src.data_structures import RemoteModuleInfo, ModuleUID, UID_DELIMITER
+from src.data_structures import UID_DELIMITER, ModuleUID, RemoteModuleInfo
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
 def declare_active_modules(
-    dht: DHT, uids: Sequence[ModuleUID], expiration_time: DHTExpiration, wait: bool = True
+    dht: DHT,
+    uids: Sequence[ModuleUID],
+    expiration_time: DHTExpiration,
+    throughput: Optional[float] = None,
+    wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
     Declare that your node serves the specified modules; update timestamps if declared previously
 
     :param uids: a list of module ids to declare
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
+    :param throughput: optionally specify your performance in terms of compute throughput
     :param expiration_time: declared modules will be visible for this many seconds
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
@@ -37,25 +41,33 @@ def declare_active_modules(
     for uid in uids:
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid
     return dht.run_coroutine(
-        partial(_declare_active_modules, uids=uids, expiration_time=expiration_time), return_future=not wait
+        partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
+        return_future=not wait,
     )
 
 
 async def _declare_active_modules(
-    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: DHTExpiration
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    expiration_time: DHTExpiration,
+    throughput: Optional[float] = None,
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[None] * len(uids),
+        values=[throughput] * len(uids),
         expiration_time=expiration_time,
-        num_workers=num_workers
+        num_workers=num_workers,
     )
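
A quick usage sketch of the new throughput argument as it flows into store_many above. The DHT instance and UIDs are hypothetical (a UID only needs to contain UID_DELIMITER to pass the assert), and the numeric throughput value is arbitrary:

from hivemind import DHT
from hivemind.utils import get_dht_time

from src.dht_utils import declare_active_modules

dht = DHT(start=True)  # a server-side DHT node; initial_peers omitted here for brevity

statuses = declare_active_modules(
    dht,
    ["bloom.0", "bloom.1"],  # placeholder module UIDs containing UID_DELIMITER
    expiration_time=get_dht_time() + 60,  # visible for roughly the next 60 seconds
    throughput=1.5,  # stored as each peer's subkey value instead of the previous None
    wait=True,  # block until the store completes, returning per-UID success flags
)

Storing throughput as the subkey value (rather than None) presumably lets clients read a server's advertised performance back out of the same DHT records, e.g. when choosing where to route requests.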
 def get_remote_module(
-    dht: DHT, uid_or_uids: Union[ModuleUID, List[ModuleUID]], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    expiration_time: Optional[DHTExpiration] = None,
+    return_future: bool = False,
 ) -> Union[List[Optional["src.RemoteTransformerBlock"]], MPFuture[List[Optional["src.RemoteTransformerBlock"]]]]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
@@ -70,6 +82,7 @@ def get_remote_module(
     )
     if return_future:
+
         async def _unpack(infos_future: MPFuture, dht: DHT):
             p2p = await dht.replicate_p2p()
             modules = _create_remote_modules_from_infos(await infos_future, p2p)
@@ -82,7 +95,7 @@ def get_remote_module(
 
 async def _get_remote_module_infos(
-    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
+    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
 ) -> List[Optional[RemoteModuleInfo]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
@@ -115,4 +128,4 @@ def _create_remote_modules_from_infos(
             modules.append(src.RemoteTransformerBlock(info, p2p))
         else:
             modules.append(None)
-    return modules
\ No newline at end of file
+    return modules

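And the lookup path, to round out dht_utils: a hedged sketch of resolving UIDs back into RemoteTransformerBlock instances, blocking or via an MPFuture. The DHT setup and UIDs are again placeholders:

from hivemind import DHT

from src.dht_utils import get_remote_module

dht = DHT(start=True)  # in practice, pass initial_peers for the swarm you want to join

# Blocking form: one entry per UID, None where nothing was found before expiration_time.
blocks = get_remote_module(dht, ["bloom.0", "bloom.1"])

# Non-blocking form: an MPFuture that resolves to the same list in the background.
future = get_remote_module(dht, ["bloom.0"], return_future=True)
blocks_later = future.result()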
@@ -5,7 +5,7 @@ import threading
 from typing import Dict, Optional, Sequence, Union
 
 import torch
-from hivemind import DHT, BatchTensorDescriptor, MAX_DHT_TIME_DISCREPANCY_SECONDS, get_dht_time
+from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.dht_handler import DHTHandlerThread
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
@@ -43,7 +43,9 @@ class Server(threading.Thread):
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
-        self.dht_handler_thread = ModuleAnnouncerThread(self.module_backends, dht, update_period, expiration, daemon=True)
+        self.dht_handler_thread = ModuleAnnouncerThread(
+            self.module_backends, dht, update_period, expiration, daemon=True
+        )
         self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
 
         if start:
@@ -217,6 +219,7 @@ class Server(threading.Thread):
 
 class ModuleAnnouncerThread(threading.Thread):
     """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
+
     def __init__(
         self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
     ):
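
The hunk above only touches the announcer's constructor; its periodic loop is not part of this diff. As a rough, speculative sketch of how such a thread could drive declare_active_modules (backend keys, the expiration fallback, and the throughput value are all assumptions, not code from this commit):

import threading
import time
from typing import Optional

from hivemind import DHT
from hivemind.utils import get_dht_time

from src.dht_utils import declare_active_modules


class ModuleAnnouncerThread(threading.Thread):
    """Illustrative reconstruction only; the real run() body is not shown in this commit"""

    def __init__(self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs):
        super().__init__(**kwargs)
        self.module_backends, self.dht = module_backends, dht
        self.update_period, self.expiration = update_period, expiration

    def run(self) -> None:
        while True:
            declare_active_modules(
                self.dht,
                list(self.module_backends.keys()),  # assumes backends are keyed by module UID
                expiration_time=get_dht_time() + (self.expiration or 2 * self.update_period),
                throughput=None,  # a real server could report its measured throughput here
            )
            time.sleep(self.update_period)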
