support hosting multiple instances of the same block

multiple-experts
justheuristic 2 years ago
parent 14b6d04b0f
commit 2eb47cbedd

@@ -49,7 +49,7 @@ Then open a python notebook or console and run:
```python
import torch
import hivemind
from src.client.remote_block import get_remote_module
from src import get_remote_module
dht = hivemind.DHT(
initial_peers=["/ip4/127.0.0.1/COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS"],

@@ -1 +1,3 @@
from .bloom import *
from .dht_utils import get_remote_module, declare_active_modules
from .client import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession

@@ -1 +1 @@
from src.client.remote_block import RemoteTransformerBlock
from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession

@@ -1,29 +0,0 @@
from collections import defaultdict
from typing import Sequence
import torch
from hivemind import DHT
from torch import nn
from src import DistributedBloomConfig
from src.server.backend import MAX_LENGTH
class RemoteInferenceChain(nn.Module):
"""An auxiliary class that manages distributed inference in a chain of one or more remote transformer modules"""
def __init__(self, dht: DHT, config: DistributedBloomConfig, block_names: Sequence[str]):
super().__init__()
self.dht = dht
self.config, self.block_names = config, block_names
self.block_caches = {name: torch.zeros(1, MAX_LENGTH, config.hidden_size) for name in block_names}
self.current_position = 0
def step(self, hidden_states: torch.Tensor):
pass
# plan:
# - run inference STUB from a jupyter notebook
# - extend to run actual inference
# - extend to run multiple layers at a time

@@ -1,38 +1,45 @@
from __future__ import annotations
import asyncio
from functools import partial
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Union
import random
from typing import Any, AsyncIterator, Dict, Optional
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertInfo as RemoteModuleInfo
from hivemind.moe.expert_uid import ExpertUID
from hivemind.p2p import P2P, PeerID, StubBase
from hivemind.moe.client.expert import RemoteExpertWorker, RemoteExpert
from hivemind.moe.expert_uid import ExpertInfo
from hivemind.p2p import P2P, StubBase
from hivemind.proto import runtime_pb2
from hivemind.utils import DHTExpiration, MPFuture, anext, as_aiter, get_dht_time, nested_flatten
from hivemind.utils import anext, nested_flatten
from src.dht_utils import ModuleUID
from src.data_structures import RemoteModuleInfo
from src.server.handler import TransformerConnectionHandler
class RemoteTransformerBlock(RemoteExpert):
"""A class that interacts with a remote module on a specific server for forward/backward or inference"""
def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids)))  # TODO: replace this
super().__init__(peer_info, p2p)
@property
def stub(self) -> StubBase:
return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
"""Initialize a new inference session with the specified remote server"""
_ = self.info # create _info manually since the built-in property will not work inside RemoteExpertWorker
return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
class RemoteTransformerBlockInferenceSession:
"""An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
def __init__(self, uid: ExpertUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
def __init__(
self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator
):
self.uid, self.info = uid, info
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
# using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
@@ -113,55 +120,3 @@ class RemoteTransformerBlockInferenceSession:
def __exit__(self, *exc_details):
self.close()
def get_remote_module(
dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
"""
:param uids: find experts with these ids from across the DHT
:param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
:param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
:returns: a list of [RemoteTransformerBlock if found else None]
"""
assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
infos = dht.run_coroutine(
partial(_get_remote_module_infos, uids=list(uids), expiration_time=expiration_time), return_future
)
if return_future:
async def _unpack(infos_future: MPFuture, dht: DHT):
p2p = await dht.replicate_p2p()
return _create_remote_modules_from_infos(await infos_future, p2p)
return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
return _create_remote_modules_from_infos(infos, p2p)
async def _get_remote_module_infos(
dht: DHT, node: DHTNode, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration]
) -> List[Optional[RemoteModuleInfo]]:
if expiration_time is None:
expiration_time = get_dht_time()
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ExpertUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
experts: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
for i, uid in enumerate(uids):
server_peer_id = found[uid]
if server_peer_id is not None and isinstance(server_peer_id.value, str):
experts[i] = RemoteModuleInfo(uid, PeerID.from_base58(server_peer_id.value))
return experts
def _create_remote_modules_from_infos(
infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
) -> List[Optional[RemoteTransformerBlock]]:
experts: List[Optional[RemoteTransformerBlock]] = []
for info in infos:
if info is not None:
experts.append(RemoteTransformerBlock(info, p2p))
else:
experts.append(None)
return experts
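For orientation, a minimal usage sketch of the classes above: it assumes a `remote_block` already obtained via `get_remote_module`, that the inference session is used as a context manager (its `__exit__` is shown above, `__enter__` is elided in this diff), and that the elided session body exposes a `step()` method taking hidden states of shape `[batch, seq_len, hidden_size]`. The hidden size below is a placeholder.

```python
# Hedged sketch only: remote_block is assumed to be a RemoteTransformerBlock
# returned by get_remote_module; session.step() is assumed (not shown in this diff).
import torch

hidden_size = 4096  # placeholder; read this from the model config in practice
with remote_block.begin_inference_session() as session:
    for _ in range(3):  # a few single-token inference steps
        hidden_states = torch.randn(1, 1, hidden_size)
        hidden_states = session.step(hidden_states)
```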

@@ -0,0 +1,8 @@
from typing import NamedTuple, Collection
from hivemind import PeerID
ModuleUID = str
UID_DELIMITER = '.'
RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])
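This NamedTuple is the core of the change in this commit: a module UID now maps to a collection of peer ids instead of a single server. Below is a hedged sketch of how it is consumed, mirroring the random peer choice in `RemoteTransformerBlock.__init__` above; the UID and peer values are placeholders (real entries hold `PeerID` objects collected from the DHT).

```python
# Hedged sketch: plain strings stand in for hivemind.PeerID objects here.
import random
from src.data_structures import RemoteModuleInfo

info = RemoteModuleInfo(uid="bloom.3", peer_ids={"peer_A", "peer_B"})  # placeholder values
chosen_peer = random.choice(list(info.peer_ids))  # same selection as the TODO in RemoteTransformerBlock
```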

@@ -0,0 +1,118 @@
"""
Utilities for declaring and retrieving active model layers using a shared DHT.
"""
from __future__ import annotations
from functools import partial
from typing import Dict, List, Optional, Sequence, Union
from hivemind.dht import DHT, DHTNode, DHTValue
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, use_hivemind_log_handler, get_logger
from hivemind.p2p import P2P, PeerID
import src
from src.data_structures import RemoteModuleInfo, ModuleUID, UID_DELIMITER
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
def declare_active_modules(
dht: DHT, uids: Sequence[ModuleUID], expiration_time: DHTExpiration, wait: bool = True
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
Declare that your node serves the specified modules; update timestamps if declared previously
:param uids: a list of module ids to declare
:param wait: if True, waits for the declaration to finish, otherwise runs in background
:param expiration_time: declared modules will be visible until this DHT timestamp (see get_dht_time)
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
"""
if isinstance(uids, str):
uids = [uids]
if not isinstance(uids, list):
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, expiration_time=expiration_time), return_future=not wait
)
async def _declare_active_modules(
dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: DHTExpiration
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[None] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers
)
def get_remote_module(
dht: DHT, uid_or_uids: Union[ModuleUID, List[ModuleUID]], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
) -> Union[List[Optional["src.RemoteTransformerBlock"]], MPFuture[List[Optional["src.RemoteTransformerBlock"]]]]:
"""
:param uid_or_uids: find one or more modules with these ids from across the DHT
:param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
:param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
:returns: a list of [RemoteTransformerBlock if found else None]
"""
single_uid = isinstance(uid_or_uids, ModuleUID)
uids = [uid_or_uids] if single_uid else uid_or_uids
infos = dht.run_coroutine(
partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
)
if return_future:
async def _unpack(infos_future: MPFuture, dht: DHT):
p2p = await dht.replicate_p2p()
modules = _create_remote_modules_from_infos(await infos_future, p2p)
return modules[0] if single_uid else modules
return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
modules = _create_remote_modules_from_infos(infos, p2p)
return modules[0] if single_uid else modules
async def _get_remote_module_infos(
dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
) -> List[Optional[RemoteModuleInfo]]:
if expiration_time is None:
expiration_time = get_dht_time()
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
for i, uid in enumerate(uids):
metadata = found[uid]
if metadata is None or not isinstance(metadata.value, dict):
logger.error(f"Incorrect metadata for {uid}: {metadata}")
continue
valid_entries = set()
for maybe_peer_id, _unused_value in metadata.value.items():
try:
valid_entries.add(PeerID.from_base58(maybe_peer_id))
except Exception:
logger.error(f"Incorrect peer entry for {uid}: {maybe_peer_id}")
if valid_entries:
modules[i] = RemoteModuleInfo(uid, valid_entries)
return modules
def _create_remote_modules_from_infos(
infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
) -> List[Optional[src.RemoteTransformerBlock]]:
modules: List[Optional[src.RemoteTransformerBlock]] = []
for info in infos:
if info is not None:
modules.append(src.RemoteTransformerBlock(info, p2p))
else:
modules.append(None)
return modules
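A hedged end-to-end sketch of the two entry points above: a node declares the modules it serves, and a client resolves them back into `RemoteTransformerBlock`s. Note that passing a single UID now returns a single module (or `None`) rather than a one-element list. The UID is a placeholder, and actually running the returned block still requires a server that hosts it.

```python
# Hedged sketch, assuming a reachable hivemind swarm; "bloom.3" is a placeholder UID.
import hivemind
from hivemind.utils import get_dht_time
from src import declare_active_modules, get_remote_module

dht = hivemind.DHT(start=True)

# server side: announce this peer as a host of "bloom.3" for the next 60 seconds
declare_active_modules(dht, ["bloom.3"], expiration_time=get_dht_time() + 60)

# client side: a single UID returns one RemoteTransformerBlock (or None if not found)
block = get_remote_module(dht, "bloom.3")
```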

@@ -5,13 +5,14 @@ import threading
from typing import Dict, Optional, Sequence, Union
import torch
from hivemind import DHT, BatchTensorDescriptor
from hivemind import DHT, BatchTensorDescriptor, MAX_DHT_TIME_DISCREPANCY_SECONDS, get_dht_time
from hivemind.moe.server.dht_handler import DHTHandlerThread
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import declare_active_modules
from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
from src.server.backend import TransformerBackend
from src.server.cache import MemoryCache
@@ -42,7 +43,7 @@ class Server(threading.Thread):
TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
]
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
self.dht_handler_thread = ModuleAnnouncerThread(self.module_backends, dht, update_period, expiration, daemon=True)
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
if start:
@@ -212,3 +213,23 @@ class Server(threading.Thread):
self.runtime.shutdown()
logger.info("Server shutdown succesfully")
class ModuleAnnouncerThread(threading.Thread):
"""Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
def __init__(
self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
):
super().__init__(**kwargs)
if expiration is None:
expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
self.module_backends = module_backends
self.dht = dht
self.update_period = update_period
self.expiration = expiration
self.stop = threading.Event()
def run(self) -> None:
declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
while not self.stop.wait(self.update_period):
declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
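A hedged sketch of driving the announcer thread outside of `Server`; the backends dict is a placeholder, since only its keys (module UIDs) are used. Setting the `stop` event ends the loop once the current `wait()` returns.

```python
# Hedged sketch: module_backends is a stand-in dict; only .keys() matters here.
import hivemind

dht = hivemind.DHT(start=True)
module_backends = {"bloom.3": None, "bloom.4": None}  # placeholder UIDs

announcer = ModuleAnnouncerThread(module_backends, dht, update_period=30, daemon=True)
announcer.start()
# ... serve requests ...
announcer.stop.set()  # run() exits after the current wait(update_period)
```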

@@ -4,7 +4,8 @@ import hivemind
import torch
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock, get_remote_module
from src.client.remote_block import RemoteTransformerBlock
from src.dht_utils import get_remote_module
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
@@ -22,7 +23,7 @@ REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1]))
def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
(remote_block,) = get_remote_module(dht, [BLOCK_UID])
remote_block = get_remote_module(dht, BLOCK_UID)
assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
