Implement block selection on servers (#20)

pull/22/head
Alexander Borzunov 2 years ago committed by GitHub
parent f055135b08
commit aba43f1308
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -43,7 +43,7 @@ python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6
# - give each server a unique --identity_path (or remote --identity_path arg when debugging)
# - if running multiple servers on the same machine, give each a unique port (last integer in --host_maddrs, 0 means random port)
# - when running over the internet, change --host_maddrs according to https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet
# - each server except first should have --initial_peers pointing to one of pre-existing servers
# - each server except first should have --initial_peers pointing to one of pre-existing servers
```
Then open a python notebook or console and run:
@ -66,7 +66,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
loss.backward()
# test inference, one block
with layer3.begin_inference_session() as sess:
with layer3.inference_session() as sess:
for i in range(10):
res = sess.step(torch.ones(1, 1, 4096))
```

@ -41,6 +41,8 @@ def main():
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--throughput', type=float, default=1.0,
help='Expected server throughput')
parser.add_argument('--update_period', type=float, required=False, default=30,
help='Server will report experts to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,

@ -25,7 +25,7 @@ class RemoteTransformerBlock(RemoteExpert):
"""A class that interacts with a remote module on a specific server for forward/backward or inference"""
def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids))) # TODO replace this
peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys()))) # TODO replace this
super().__init__(peer_info, p2p)
@property

@ -1,15 +1,13 @@
from __future__ import annotations
import dataclasses
import threading
from functools import partial
from typing import List, NamedTuple, Optional, Sequence, Tuple
from hivemind import DHT, PeerID
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src.data_structures import ModuleUID, RemoteModuleInfo
from src.dht_utils import _get_remote_module_infos
from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState
from src.dht_utils import get_remote_module_infos
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@ -18,21 +16,20 @@ logger = get_logger(__file__)
Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
@dataclasses.dataclass(frozen=False, init=False) # TODO[borzunov@] eto ne dataclass
class RemoteSequenceInfo:
"""Keeps and updates the meta-information about which peers host which blocks"""
dht: DHT
block_uids: List[ModuleUID, ...]
block_infos: List[Optional[RemoteModuleInfo], ...]
block_uids: List[ModuleUID]
block_infos: List[Optional[RemoteModuleInfo]]
spans_by_priority: List[Span] # sorted from best to worst
spans_containing_block: Tuple[List[Span], ...]
spans_containing_block: Tuple[List[Span]]
lock_changes: threading.Lock
def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
self.dht = dht
self.block_uids = list(block_uids)
self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
self.block_infos = [None] * len(self.block_uids)
self.spans_by_priority = []
self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
self.lock_changes = threading.Lock()
@ -48,21 +45,17 @@ class RemoteSequenceInfo:
self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
def update_block_infos_(self):
new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
)
new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
assert len(new_block_infos) == len(self.block_uids)
for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
if info is None:
logger.warning(f"Found no block info for block {uid}")
if not isinstance(info, RemoteModuleInfo):
logger.warning(f"Unexpected dht entry type for {uid}: {info}")
if not info.peer_ids:
if not info.servers:
logger.warning(f"Found no active peers for block {uid}")
if info.uid != uid:
logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
if not isinstance(info.peer_ids, set):
logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
self.block_infos[block_index] = info
@staticmethod
@ -70,14 +63,20 @@ class RemoteSequenceInfo:
closed_spans = []
active_spans = {}
for block_index, info in enumerate(block_infos):
for peer_id in info.peer_ids:
for peer_id, server in info.servers.items():
if server.state != ServerState.ONLINE:
continue
if peer_id not in active_spans:
active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
else: # peer_id in active_spans
active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
for peer_id in list(active_spans.keys()):
if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
if (
peer_id not in info.servers or
info.servers[peer_id].state != ServerState.ONLINE or
block_index == len(block_infos) - 1
):
closed_spans.append(active_spans.pop(peer_id))
assert not active_spans

@ -1,8 +1,27 @@
from typing import Collection, NamedTuple
from dataclasses import dataclass
from enum import Enum
from typing import Dict
from hivemind import PeerID
ModuleUID = str
UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])
class ServerState(Enum):
OFFLINE = 0
JOINING = 1
ONLINE = 2
@dataclass
class ServerInfo:
state: ServerState
throughput: float
@dataclass
class RemoteModuleInfo:
uid: ModuleUID
servers: Dict[PeerID, ServerInfo]

@ -12,7 +12,7 @@ from hivemind.p2p import P2P, PeerID
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
import src
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
@ -22,7 +22,8 @@ def declare_active_modules(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: DHTExpiration,
throughput: Optional[float] = None,
state: ServerState,
throughput: float,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
@ -30,7 +31,7 @@ def declare_active_modules(
:param uids: a list of module ids to declare
:param wait: if True, awaits for declaration to finish, otherwise runs in background
:param throughput: optionally specify your performance in terms of compute throughput
:param throughput: specify your performance in terms of compute throughput
:param expiration_time: declated modules will be visible for this many seconds
:returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
"""
@ -41,7 +42,13 @@ def declare_active_modules(
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
partial(
_declare_active_modules,
uids=uids,
expiration_time=expiration_time,
state=state,
throughput=throughput,
),
return_future=not wait,
)
@ -51,13 +58,14 @@ async def _declare_active_modules(
node: DHTNode,
uids: List[ModuleUID],
expiration_time: DHTExpiration,
throughput: Optional[float] = None,
state: ServerState,
throughput: float,
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[throughput] * len(uids),
values=[(state.value, throughput)] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
@ -94,6 +102,19 @@ def get_remote_module(
return modules[0] if single_uid else modules
def get_remote_module_infos(
dht: DHT,
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
expiration_time: Optional[DHTExpiration] = None,
) -> List[Optional[RemoteModuleInfo]]:
single_uid = isinstance(uid_or_uids, ModuleUID)
uids = [uid_or_uids] if single_uid else uid_or_uids
infos = dht.run_coroutine(
partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future=False
)
return infos[0] if single_uid else infos
async def _get_remote_module_infos(
dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
) -> List[Optional[RemoteModuleInfo]]:
@ -109,14 +130,20 @@ async def _get_remote_module_infos(
if metadata is not None:
logger.error(f"Incorrect metadata for {uid}: {metadata}")
continue
valid_entries = set()
for maybe_peer_id, _unused_value in metadata.value.items():
servers = {}
for peer_id, server_info in metadata.value.items():
try:
valid_entries.add(PeerID.from_base58(maybe_peer_id))
except:
logger.error(f"Incorrect peer entry for {uid}: {maybe_peer_id}")
if valid_entries:
modules[i] = RemoteModuleInfo(uid, valid_entries)
peer_id = PeerID.from_base58(peer_id)
server_info = server_info.value
if not (isinstance(server_info, tuple) and len(server_info) == 2 and
isinstance(server_info[0], int) and isinstance(server_info[1], float)):
raise ValueError(f"Invalid server info for uid={uid}, peer_id={peer_id}: {server_info}")
state, throughput = server_info
servers[peer_id] = ServerInfo(ServerState(state), throughput)
except (TypeError, ValueError) as e:
logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
if servers:
modules[i] = RemoteModuleInfo(uid, servers)
return modules

@ -0,0 +1,18 @@
from typing import List, Optional
from src.data_structures import RemoteModuleInfo, ServerState
def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
throughputs = []
for module in remote_module_infos:
if module is None:
throughputs.append(0)
continue
throughputs.append(sum(server.throughput for server in module.servers.values()
if server.state != ServerState.OFFLINE))
options = [(throughputs[i:i + num_blocks], i)
for i in range(0, len(throughputs) - num_blocks + 1)]
best_start = min(options)[1]
return list(range(best_start, best_start + num_blocks))

@ -13,8 +13,10 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from src import declare_active_modules, BloomConfig
from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from src.dht_utils import get_remote_module_infos
from src.server.backend import TransformerBackend
from src.server.block_selection import choose_best_blocks
from src.server.cache import MemoryCache
from src.server.handler import TransformerConnectionHandler
@ -32,19 +34,26 @@ class Server(threading.Thread):
*,
device: torch.device,
num_connection_handlers: int = 8,
throughput: float,
update_period: float = 30,
expiration: Optional[float] = None,
start: bool,
**kwargs,
):
threading.Thread.__init__(self)
self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
self.dht, self.module_backends = dht, module_backends
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
self.conn_handlers = [
TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
]
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
self.dht_handler_thread = ModuleAnnouncerThread(
self.module_backends, dht, update_period, expiration, daemon=True
self.module_backends,
dht,
throughput=throughput,
update_period=update_period,
expiration=expiration,
daemon=True,
)
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
@ -86,6 +95,7 @@ class Server(threading.Thread):
cls,
prefix: Optional[str],
converted_model_name_or_path: str,
throughput: float,
num_blocks: Optional[int] = None,
block_indices: Optional[str] = None,
num_handlers: Optional[int] = None,
@ -116,6 +126,9 @@ class Server(threading.Thread):
)
logger.info(f"Automatic dht prefix: {prefix}")
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
if expiration is None:
expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
@ -127,6 +140,10 @@ class Server(threading.Thread):
torch_dtype = DTYPE_MAP[torch_dtype]
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
block_config = BloomConfig.from_pretrained(
converted_model_name_or_path, use_auth_token=use_auth_token
)
if block_indices is not None:
try:
first_block_index, last_block_index = block_indices.split(":")
@ -137,16 +154,22 @@ class Server(threading.Thread):
block_indices = range(first_block_index, last_block_index)
else:
assert num_blocks is not None
block_indices = range(num_blocks) # TODO replace with proper load balancing
uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
block_indices = choose_best_blocks(num_blocks, module_infos)
block_config = BloomConfig.from_pretrained(
converted_model_name_or_path, use_auth_token=use_auth_token
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
declare_active_modules(
dht,
module_uids,
expiration_time=get_dht_time() + expiration,
state=ServerState.JOINING,
throughput=throughput,
)
logger.info(f"Announced that blocks {block_indices} are joining")
# initialize modules
blocks = {}
for block_index in block_indices:
module_uid = f"{prefix}.{block_index}"
for module_uid, block_index in zip(module_uids, block_indices):
block = load_pretrained_block(
converted_model_name_or_path,
block_index,
@ -173,6 +196,7 @@ class Server(threading.Thread):
return cls(
dht,
blocks,
throughput=throughput,
num_connection_handlers=num_handlers,
device=device,
stats_report_interval=stats_report_interval,
@ -209,6 +233,16 @@ class Server(threading.Thread):
Please note that terminating server otherwise (e.g. by killing processes) may result in zombie processes.
If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
"""
if self.module_backends:
declare_active_modules(
self.dht,
self.module_backends.keys(),
expiration_time=get_dht_time() + self.expiration,
state=ServerState.OFFLINE,
throughput=self.throughput,
)
logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")
self.ready.clear()
for process in self.conn_handlers:
@ -230,25 +264,38 @@ class Server(threading.Thread):
logger.debug(f"Shutting down runtime")
self.runtime.shutdown()
logger.info("Server shutdown succesfully")
logger.info("Server shut down succesfully")
class ModuleAnnouncerThread(threading.Thread):
"""Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
def __init__(
self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
self,
module_backends: Dict[str, TransformerBackend],
dht: DHT,
*,
throughput: float,
update_period: float = 30,
expiration: float,
**kwargs
):
super().__init__(**kwargs)
if expiration is None:
expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
self.module_backends = module_backends
self.dht = dht
self.throughput = throughput
self.update_period = update_period
self.expiration = expiration
self.stop = threading.Event()
def run(self) -> None:
declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
while not self.stop.wait(self.update_period):
declare_active_modules(self.dht, self.module_backends.keys(), get_dht_time() + self.expiration)
while True:
declare_active_modules(
self.dht,
self.module_backends.keys(),
expiration_time=get_dht_time() + self.expiration,
state=ServerState.ONLINE,
throughput=self.throughput,
)
if self.stop.wait(self.update_period):
break

Loading…
Cancel
Save