@ -50,7 +50,7 @@ class SequenceManagerConfig:
ban_timeout : float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
ban_timeout : float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter : Optional [ str ] = None # name of active LoRA adapter (usually, Hugging Face repo)
active_adapter : Optional [ str ] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged : int = 5 # max servers to ping from each sequence side, per update
max_pinged : int = 3 # max servers to ping from each sequence side, per update
ping_timeout : float = 2 # max time to wait for pings, per update
ping_timeout : float = 2 # max time to wait for pings, per update
@ -293,6 +293,8 @@ class RemoteSequenceManager:
return cache_tokens_needed * 2 * span . length < = span . server_info . cache_tokens_left
return cache_tokens_needed * 2 * span . length < = span . server_info . cache_tokens_left
def _make_sequence_with_max_throughput ( self , start_index : int , end_index : int ) - > List [ RemoteSpanInfo ] :
def _make_sequence_with_max_throughput ( self , start_index : int , end_index : int ) - > List [ RemoteSpanInfo ] :
client_server_rtts = self . ping_aggregator . to_dict ( )
span_sequence = [ ]
span_sequence = [ ]
current_index = start_index
current_index = start_index
while current_index < end_index :
while current_index < end_index :
@ -300,7 +302,13 @@ class RemoteSequenceManager:
if not candidate_spans :
if not candidate_spans :
raise MissingBlocksError ( current_index )
raise MissingBlocksError ( current_index )
span_weights = np . array ( [ span . server_info . throughput for span in candidate_spans ] , dtype = np . float64 )
# We choose longer servers to minimize the number of hops but leave some randomization
# to distribute the load. We also exclude servers known to be unreachable.
eps = 1e-6
span_weights = np . array (
[ span . length if client_server_rtts . get ( span . peer_id ) != np . inf else eps for span in candidate_spans ] ,
dtype = np . float64 ,
)
chosen_span = np . random . choice ( candidate_spans , p = span_weights / span_weights . sum ( ) )
chosen_span = np . random . choice ( candidate_spans , p = span_weights / span_weights . sum ( ) )
assert chosen_span . start < = current_index < chosen_span . end
assert chosen_span . start < = current_index < chosen_span . end
@ -361,9 +369,13 @@ class RemoteSequenceManager:
self . state . sequence_info . update_ ( new_block_infos )
self . state . sequence_info . update_ ( new_block_infos )
first_servers = [ span . peer_id for span in self . state . sequence_info . spans_containing_block [ 0 ] ]
first_servers = [ span . peer_id for span in self . state . sequence_info . spans_containing_block [ 0 ] ]
middle_servers = [
span . peer_id for spans in self . state . sequence_info . spans_containing_block [ 1 : - 1 ] for span in spans
]
last_servers = [ span . peer_id for span in self . state . sequence_info . spans_containing_block [ - 1 ] ]
last_servers = [ span . peer_id for span in self . state . sequence_info . spans_containing_block [ - 1 ] ]
pinged_servers = set ( sample_up_to ( first_servers , self . config . max_pinged ) )
pinged_servers = set ( sample_up_to ( first_servers , self . config . max_pinged ) )
# Accumulate (|=) rather than reassign: a plain `=` here would discard the sample of
# first-block servers collected just above, so those servers would never be pinged.
pinged_servers |= set(sample_up_to(middle_servers, self.config.max_pinged))
pinged_servers | = set ( sample_up_to ( last_servers , self . config . max_pinged ) )
pinged_servers | = set ( sample_up_to ( last_servers , self . config . max_pinged ) )
self . ping_aggregator . ping ( list ( pinged_servers ) , wait_timeout = self . config . ping_timeout )
self . ping_aggregator . ping ( list ( pinged_servers ) , wait_timeout = self . config . ping_timeout )