Add `allowed_servers`, `max_retries` options to the client, improve logs (#235)

Alexander Borzunov authored 1 year ago, committed by GitHub
parent 3c523ab0d2
commit 9954cb84fe

@@ -307,10 +307,11 @@ class InferenceSession:
             except Exception as e:
                 if span is not None:
                     self._sequence_manager.on_request_failure(span.peer_id)
+                if attempt_no + 1 == self._sequence_manager.max_retries:
+                    raise
                 delay = self._sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
-                    f"Caught exception when running inference from block {block_idx} "
-                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                    f"Caught exception when running inference via {span} (retry in {delay:.0f} sec): {repr(e)}"
                 )
                 maybe_log_traceback(e)
                 time.sleep(delay)

@@ -1,6 +1,6 @@
 import os
 from contextlib import contextmanager
-from typing import List, Optional, Union
+from typing import Collection, List, Optional, Union
 
 import hivemind
 import torch
@@ -35,6 +35,10 @@ class DistributedBloomConfig(BloomConfig):
     daemon_startup_timeout: int = 30
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
     request_timeout: int = 30  # a number of seconds for waiting result from each node
+    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    allowed_servers: Optional[
+        Collection[Union[str, hivemind.PeerID]]
+    ] = None  # if defined, send requests only to these servers
 
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
     tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
@@ -112,7 +116,11 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
             )
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
-        self.h = RemoteSequential(config, dht, config.dht_prefix, request_timeout=config.request_timeout)
+        self.h = RemoteSequential(
+            config,
+            dht,
+            config.dht_prefix,
+        )
 
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
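
Usage note: with the two hunks above, both new options can be set directly on the config before the model is built. A minimal sketch (the import path and checkpoint name are assumptions, not taken from this commit):

# Illustrative sketch only; the import path and checkpoint name below are placeholders.
from petals.client import DistributedBloomConfig

config = DistributedBloomConfig.from_pretrained("bigscience/bloom-petals")
config.max_retries = 3  # raise after 3 failed attempts; None (the default) keeps retrying forever
config.allowed_servers = ["<base58 peer id>", "<another base58 peer id>"]  # send requests only to these servers
# The config is then passed to the Distributed* model class as usual (construction details omitted).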

@@ -41,7 +41,16 @@ class RemoteSequential(nn.Module):
         block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks))
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, start=True, **kwargs)
+            self.sequence_manager = RemoteSequenceManager(
+                dht,
+                block_uids,
+                self.p2p,
+                request_timeout=config.request_timeout,
+                max_retries=config.max_retries,
+                allowed_servers=config.allowed_servers,
+                start=True,
+                **kwargs,
+            )
             self.is_subsequence = False
         else:
             logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")

@@ -6,7 +6,7 @@ import logging
 import random
 import threading
 import time
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Collection, Dict, List, Optional, Sequence, Union
 from weakref import WeakMethod
 
 import numpy as np
@@ -40,9 +40,10 @@ class RemoteSequenceManager:
     :param update_period: by default, refresh DHT information once in this many seconds
     :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
     :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
+    :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
     :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
     :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
-    :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
+    :param allowed_servers: if defined, send requests only to these servers
     :param start: start the background thread (see the note below). If false, you will need to start it manually.
     :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
         running redundant sequence managers for the same set of layers.
@@ -56,21 +57,30 @@
         p2p: P2P,
         update_period: float = 30,
         request_timeout: float = 30,
+        max_retries: Optional[int] = None,
         min_backoff: float = 1,
         ban_timeout: float = 15,
         sequence_info: Optional[RemoteSequenceInfo] = None,
         rpc_info: Optional[dict] = None,
+        allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None,
         banned_peers: Optional[Blacklist] = None,
         *,  # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
         start: bool,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
         self.dht, self.p2p = dht, p2p
-        self.request_timeout, self.ban_timeout, self.min_backoff = request_timeout, ban_timeout, min_backoff
+        self.request_timeout, self.max_retries = request_timeout, max_retries
+        self.ban_timeout, self.min_backoff = ban_timeout, min_backoff
         self.lock_changes = threading.Lock()
         self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
         self.policy = NoSpendingPolicy()
         self._rpc_info = rpc_info
+
+        if allowed_servers is not None:
+            allowed_servers = {
+                PeerID.from_base58(peer_id) if isinstance(peer_id, str) else peer_id for peer_id in allowed_servers
+            }
+        self.allowed_servers = allowed_servers
         self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers
 
         if sequence_info is None:
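
Note on the normalization above: the DHT-provided server maps are keyed by hivemind.PeerID, so base58 strings are converted up front and the whitelist check in the _update hunk below (`peer_id in self.allowed_servers`) compares PeerID objects directly. A standalone sketch of that conversion (illustrative, not the library code itself):

# Sketch of the allowed_servers normalization performed in the constructor.
from typing import Collection, Optional, Set, Union

from hivemind import PeerID


def normalize_allowed_servers(
    allowed_servers: Optional[Collection[Union[str, PeerID]]]
) -> Optional[Set[PeerID]]:
    if allowed_servers is None:
        return None  # no whitelist: requests may be routed to any server
    # Accept both base58 strings and ready-made PeerID objects, as the constructor does.
    return {PeerID.from_base58(p) if isinstance(p, str) else p for p in allowed_servers}
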
@@ -148,6 +158,7 @@ class RemoteSequenceManager:
             min_backoff=self.min_backoff,
             sequence_info=self.sequence_info[ix],
             rpc_info=self._rpc_info,
+            allowed_servers=self.allowed_servers,
             banned_peers=self.banned_peers,
             start=True,
         )
@@ -169,6 +180,16 @@
         for block_info in new_block_infos:
             if not block_info:
                 continue
+
+            # Apply whitelist, if defined
+            if self.allowed_servers is not None:
+                block_info.servers = {
+                    peer_id: server_info
+                    for peer_id, server_info in block_info.servers.items()
+                    if peer_id in self.allowed_servers
+                }
+
+            # Remove temporarily banned peers, unless there are no peers left
             valid_servers = {
                 peer_id: server_info
                 for peer_id, server_info in block_info.servers.items()
@@ -260,6 +281,8 @@
             except Exception as e:
                 if peer_id is not None and not isinstance(e, P2PHandlerError):
                     self.on_request_failure(peer_id)
+                if attempt_no + 1 == self.max_retries:
+                    raise
                 delay = self.get_retry_delay(attempt_no)
                 logger.warning(
                     f"Caught exception when gathering information from peer {peer_id} "

@@ -95,10 +95,11 @@ async def sequential_forward(
         except Exception as e:
             if span is not None:
                 sequence_manager.on_request_failure(span.peer_id)
+            if attempt_no + 1 == sequence_manager.max_retries:
+                raise
             delay = sequence_manager.get_retry_delay(attempt_no)
             logger.warning(
-                f"Caught exception when running forward from block {block_idx} "
-                f"(retry in {delay:.0f} sec): {repr(e)}"
+                f"Caught exception when running forward via {span} (retry in {delay:.0f} sec): {repr(e)}"
             )
             maybe_log_traceback(e)
             await asyncio.sleep(delay)
@@ -172,10 +173,11 @@ async def sequential_backward(
         except Exception as e:
             if span is not None:
                 sequence_manager.on_request_failure(span.peer_id)
+            if attempt_no + 1 == sequence_manager.max_retries:
+                raise
             delay = sequence_manager.get_retry_delay(attempt_no)
             logger.warning(
-                f"Caught exception when running backward between blocks {span.start}-{span.end} "
-                f"(retry in {delay:.0f} sec): {repr(e)}"
+                f"Caught exception when running backward via {span} (retry in {delay:.0f} sec): {repr(e)}"
             )
             maybe_log_traceback(e)
             await asyncio.sleep(delay)
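
All four retry loops touched by this commit (inference, RPC-info fetching, sequential forward, sequential backward) now share the same give-up rule. A standalone sketch of that rule together with the backoff described in the `min_backoff` docstring (the helper below is illustrative, not library code):

# Sketch of the retry policy: max_retries=None reproduces the old behavior,
# since `attempt_no + 1 == None` is never True and the loop retries forever.
import itertools
import time
from typing import Callable, Optional


def call_with_retries(fn: Callable[[], object], max_retries: Optional[int] = None, min_backoff: float = 1.0):
    for attempt_no in itertools.count():
        try:
            return fn()
        except Exception:
            if attempt_no + 1 == max_retries:
                raise  # give up after max_retries failed attempts
            # Backoff per the min_backoff docstring: min_backoff * 2 ** (num_failures - 1),
            # where num_failures == attempt_no + 1 at this point.
            time.sleep(min_backoff * 2 ** attempt_no)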

@@ -16,6 +16,7 @@ class Span:
     start: int
     end: int
     throughput: float
+    state: ServerState
 
     @property
     def length(self):
@@ -43,7 +44,7 @@ def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[
                 spans[peer_id].start = min(spans[peer_id].start, block)
                 spans[peer_id].end = max(spans[peer_id].start, block + 1)
             else:
-                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)
+                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
 
             throughputs[block] += server.throughput
 
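
The last two hunks make each routing Span carry the state of the server it describes. A condensed sketch of the resulting structure (the ServerState values and the body of `length` are assumptions; only the annotated fields come from the diff):

# Condensed sketch; assumes Span is a dataclass, consistent with the field-only declarations above.
from dataclasses import dataclass
from enum import Enum


class ServerState(Enum):  # values assumed, not shown in this diff
    OFFLINE = 0
    JOINING = 1
    ONLINE = 2


@dataclass
class Span:
    start: int
    end: int
    throughput: float
    state: ServerState  # new field added by this commit

    @property
    def length(self):
        return self.end - self.start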
