@@ -10,6 +10,7 @@ import time
from typing import Any, Collection, Dict, List, Optional, Sequence, Union
from weakref import WeakMethod

import dijkstar
import numpy as np
from hivemind import DHT, P2P, MSGPackSerializer, PeerID
from hivemind.dht.node import Blacklist
@@ -23,6 +24,8 @@ from petals.client.routing.spending_policy import NoSpendingPolicy
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
from petals.server.handler import TransformerConnectionHandler
from petals.utils.ping import PingAggregator
from petals.utils.random import sample_up_to

logger = get_logger(__name__)
@@ -33,6 +36,7 @@ class SequenceManagerConfig:
    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
    show_route: Union[str, bool] = "inference"  # show chosen route through servers. one of [False, "inference", True]
    allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
    use_server_to_server: bool = True  # Use direct server-to-server communication
@@ -43,7 +47,10 @@ class SequenceManagerConfig:
    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
    max_backoff: float = 60  # limit maximal sleep time between retries to this value
    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
    active_adapter: Optional[str] = None
    active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
    max_pinged: int = 5  # max servers to ping from each sequence side, per update
    ping_timeout: float = 2  # max time to wait for pings, per update
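
    # Illustrative sketch (not part of the diff): the new routing knobs can be tuned when
    # constructing a config, e.g. for a client that pings at most 3 candidate servers per
    # side and waits up to 1 s for replies:
    #   config = SequenceManagerConfig(initial_peers=PUBLIC_INITIAL_PEERS,
    #                                  dht_prefix="bigscience/bloom-petals",  # hypothetical prefix
    #                                  max_pinged=3, ping_timeout=1.0)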
@dataclasses.dataclass
@@ -79,7 +86,6 @@ class RemoteSequenceManager:
        *,
        dht: Optional[DHT] = None,
        state: Optional[SequenceManagerState] = None,
        active_adapter: Optional[str] = None,
    ):
        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@@ -94,7 +100,7 @@ class RemoteSequenceManager:
            dht = DHT(
                initial_peers=config.initial_peers,
                client_mode=True,
                num_workers=config.num_hidden_layers,
                num_workers=32,
                startup_timeout=config.daemon_startup_timeout,
                start=True,
            )
@@ -109,25 +115,25 @@
        self._thread_start_lock = threading.Lock()
        self.policy = NoSpendingPolicy()

        self.ping_aggregator = PingAggregator(dht)

        if state.banned_peers is None:
            state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
        if state.sequence_info is None:
            state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)

        if state.sequence_info.last_updated_time is None:
            # Pre-fetch module infos from the DHT in parallel with .from_pretrained(), then use the cached
            # records in the first _update() instead of the latest ones. This makes the first .update() faster.
            petals.dht_utils.get_remote_module_infos(
                self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True
            )
            self._need_latest_infos = False
        else:
        if state.sequence_info.last_updated_time is not None:
            assert block_uids == state.sequence_info.block_uids
            self._thread.ready.set()  # no need to await the first dht fetch
            self._need_latest_infos = True

    def make_sequence(
        self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str
        self,
        start_index: int = 0,
        end_index: Optional[int] = None,
        *,
        mode: str,
        cache_tokens_needed: Optional[int] = None,
    ) -> List[RemoteSpanInfo]:
        """
        Form a sequence of remote servers that collectively serve all consecutive layers
@@ -143,6 +149,150 @@ class RemoteSequenceManager:
        self.update(wait=True)  # this will await an existing update or trigger a new one (if not updating)
        end_index = end_index if end_index is not None else len(self)

        if mode == "min_latency":
            span_sequence = self._make_sequence_with_min_latency(
                start_index, end_index, cache_tokens_needed=cache_tokens_needed
            )
        elif mode == "max_throughput":
            span_sequence = self._make_sequence_with_max_throughput(start_index, end_index)
        else:
            raise RuntimeError(f"Unexpected mode {mode}")

        if self.config.show_route is True or (mode == "min_latency" and self.config.show_route == "inference"):
            route_repr = " => ".join(
                [f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence]
            )
            logger.info(f"Route found: {route_repr}")
        return span_sequence
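
    # Usage sketch (illustrative; assumes `manager` is a ready RemoteSequenceManager):
    #   spans = manager.make_sequence(mode="min_latency", cache_tokens_needed=1024)  # inference
    #   spans = manager.make_sequence(mode="max_throughput")  # parallel forward/backward passes
    # Each returned RemoteSpanInfo covers a contiguous [start, end) range of blocks on one server.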

    def _make_sequence_with_min_latency(
        self, start_index: int, end_index: int, *, cache_tokens_needed: Optional[int]
    ) -> List[RemoteSpanInfo]:
        if start_index == end_index:
            return []

        with self.lock_changes:
            missing_blocks = [
                block_idx
                for block_idx in range(start_index, end_index)
                if not self.state.sequence_info.spans_containing_block[block_idx]
            ]
            if missing_blocks:
                raise MissingBlocksError(missing_blocks)
            server_infos = {
                span.peer_id: span.server_info
                for block_idx in range(start_index, end_index)
                for span in self.state.sequence_info.spans_containing_block[block_idx]
            }

            graph = self._build_inference_graph(start_index, end_index, cache_tokens_needed=cache_tokens_needed)

        path = dijkstar.find_path(graph, "start", "end")
        logger.debug(f"Path info: {path}")
        if start_index == 0 and end_index == len(self):
            logger.debug(f"Expected speed: {1 / path.total_cost:.1f} steps/sec")

        span_sequence = []
        for peer_id, block_idx in path.nodes[1:-1]:
            if not span_sequence or span_sequence[-1].peer_id != peer_id:
                span_sequence.append(RemoteSpanInfo(peer_id, block_idx, block_idx, server_infos[peer_id]))
            else:
                span_sequence[-1].end = block_idx

        # Remove empty spans that can appear if we don't force the route to go to the end of each server
        # and network delays do not obey the triangle inequality (delay(A, B) + delay(B, C) < delay(A, C))
        # due to measurement errors
        span_sequence = [span for span in span_sequence if span.length > 0]

        return span_sequence
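
    # Worked example (hypothetical peers A and B): if the cheapest path visits
    #   "start", (A, 0), (A, 1), (A, 2), (B, 2), (B, 3), (B, 4), "end",
    # the loop above merges same-peer nodes into [RemoteSpanInfo(A, start=0, end=2),
    # RemoteSpanInfo(B, start=2, end=4)]. A zero-length span appears only if the path
    # switches peers without computing a block in between, e.g. "start", (A, 0), (B, 0), ...;
    # the final filter drops such spans.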

    def _build_inference_graph(
        self,
        start_index: int,
        end_index: int,
        *,
        cache_tokens_needed: Optional[int],
        overhead_coeff: float = 1.82,  # Backend overhead (empirically measured)
        overhead_delay: float = 0.018,  # Serialization overhead (empirically measured)
        default_inference_rps: float = 300,  # If inference RPS unknown
        alloc_delay: float = 10,  # If not enough cache is left, we penalize the edge
    ) -> dijkstar.Graph:
        missing_blocks = [
            block_idx
            for block_idx in range(start_index, end_index)
            if not self.state.sequence_info.spans_containing_block[block_idx]
        ]
        if missing_blocks:
            raise MissingBlocksError(missing_blocks)

        client_server_rtts = self.ping_aggregator.to_dict()

        graph = dijkstar.Graph()

        # Client -> server network delays
        for span in self.state.sequence_info.spans_containing_block[start_index]:
            delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
            delay += overhead_delay
            if not self._has_cache_for(span, cache_tokens_needed):
                delay += alloc_delay
            graph.add_edge("start", (span.peer_id, start_index), delay)

        # Server -> client network delays
        for span in self.state.sequence_info.spans_containing_block[end_index - 1]:
            delay = self._rtt_to_delay(client_server_rtts.get(span.peer_id))
            graph.add_edge((span.peer_id, end_index), "end", delay)

        # Server -> server network delays
        for block_idx in range(start_index + 1, end_index):
            for cur_span in self.state.sequence_info.spans_containing_block[block_idx - 1]:
                if cur_span.end != block_idx:
                    # If we choose a server, we force the route to go to its end before switching
                    # to a new one, to avoid O(N^2) graphs for N servers
                    continue

                for next_span in self.state.sequence_info.spans_containing_block[block_idx]:
                    rtt = None
                    if cur_span.server_info.next_pings is not None:
                        rtt = cur_span.server_info.next_pings.get(next_span.peer_id.to_base58())
                    delay = self._rtt_to_delay(rtt)
                    delay += overhead_delay
                    if not self._has_cache_for(next_span, cache_tokens_needed):
                        delay += alloc_delay
                    graph.add_edge((cur_span.peer_id, block_idx), (next_span.peer_id, block_idx), delay)

        # Compute delays
        for span in self.state.sequence_info.spans_by_priority:
            for block_idx in range(max(span.start, start_index), min(span.end, end_index)):
                inference_rps = span.server_info.inference_rps
                if inference_rps is None:
                    inference_rps = default_inference_rps
                graph.add_edge((span.peer_id, block_idx), (span.peer_id, block_idx + 1), overhead_coeff / inference_rps)

        return graph
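
    # Minimal sketch of the graph shape above (hypothetical peer "A" serving blocks [0, 2)):
    #   g = dijkstar.Graph()
    #   g.add_edge("start", ("A", 0), 0.05)   # client -> A network delay
    #   g.add_edge(("A", 0), ("A", 1), 0.01)  # A computes block 0
    #   g.add_edge(("A", 1), ("A", 2), 0.01)  # A computes block 1
    #   g.add_edge(("A", 2), "end", 0.05)     # A -> client network delay
    #   dijkstar.find_path(g, "start", "end").total_cost  # == 0.12
    # find_path() returns a PathInfo(nodes, edges, costs, total_cost) tuple;
    # _make_sequence_with_min_latency() walks its .nodes to build RemoteSpanInfo spans.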

    @staticmethod
    def _rtt_to_delay(
        rtt: float,
        *,
        default_delay: float = 0.15,  # If network delay unknown
        max_delay: float = 5,  # If unreachable, we don't want to discard the edge completely
    ) -> float:
        if rtt is None:
            return default_delay
        return min(rtt / 2, max_delay)
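
    # E.g., a measured RTT of 0.2 s yields a one-way delay estimate of 0.1 s; an unknown RTT
    # falls back to default_delay=0.15 s, and an unreachable peer is capped at max_delay=5 s
    # so its edges remain usable (if expensive) rather than disappearing from the graph.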

    @staticmethod
    def _has_cache_for(span: RemoteSpanInfo, cache_tokens_needed: Optional[int] = None) -> bool:
        if cache_tokens_needed is None or span.server_info.cache_tokens_left is None:
            return True

        # Here, `span` contains all blocks hosted by a server - but we won't necessarily run all of them through
        # this particular server in our path. It is difficult to estimate how many blocks we'll use at this stage,
        # so we assume that we'll use all of them (the worst case for the cache size) and get a pessimistic estimate.
        # This is okay since false positives are more costly than false negatives here.
        return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
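
    # Worked example: cache_tokens_needed=1024 and a span of length 8 requires
    # 1024 * 2 * 8 = 16384 tokens of cache, so the check passes only if the server
    # reports cache_tokens_left >= 16384.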

    def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
        span_sequence = []
        current_index = start_index
        while current_index < end_index:
@@ -150,20 +300,12 @@ class RemoteSequenceManager:
            if not candidate_spans:
                raise MissingBlocksError(current_index)

            if mode == "max_throughput":
                span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
            elif mode == "min_latency":
                span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
            else:
                raise RuntimeError(f"Unexpected mode {mode}")
            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
            chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())

            assert chosen_span.start <= current_index < chosen_span.end
            span_sequence.append(dataclasses.replace(chosen_span, start=current_index))
            current_index = chosen_span.end

        route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence])
        logger.debug(f"Route found: {route_repr}")
        return span_sequence
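
    # Illustrative draw: given three candidate spans with throughputs 300, 100, and 100,
    # np.random.choice() picks them with probabilities 0.6, 0.2, and 0.2, so faster servers
    # are preferred while the randomization still spreads concurrent clients across the swarm.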

    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
@@ -182,10 +324,10 @@ class RemoteSequenceManager:
    def _update(self):
        """Perform an immediate and synchronous refresh; may take time"""
        new_block_infos = petals.dht_utils.get_remote_module_infos(
            self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos
            self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=True
        )
        self._need_latest_infos = True  # All future _update() calls should use the latest infos

        for block_info in new_block_infos:
            if not block_info:
@@ -217,6 +359,14 @@ class RemoteSequenceManager:
        with self.lock_changes:
            self.state.sequence_info.update_(new_block_infos)

            first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
            last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]

        pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
        pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
        self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)

        self.ready.set()
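
        # E.g., with max_pinged=5, up to 5 random holders of the first block and up to 5 holders
        # of the last block are pinged, waiting at most ping_timeout=2 s (sample_up_to() is
        # assumed here to return a random subset of at most that size); the measured RTTs feed the
        # client <-> server edges in _build_inference_graph() on the next make_sequence() call.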

    def on_request_failure(self, peer_id: Optional[PeerID]):