@ -9,6 +9,7 @@ import time
from typing import Any , Dict , List , Optional , Sequence , Union
from weakref import WeakMethod
import numpy as np
from hivemind import DHT , P2P , MSGPackSerializer , PeerID
from hivemind . dht . node import Blacklist
from hivemind . moe . client . remote_expert_worker import RemoteExpertWorker
@ -92,12 +93,15 @@ class RemoteSequenceManager:
if await_ready :
self . _thread . ready . wait ( timeout )
def make_sequence ( self , start_index : int = 0 , end_index : Optional [ int ] = None ) - > List [ RemoteSpanInfo ] :
def make_sequence (
self , start_index : int = 0 , end_index : Optional [ int ] = None , mode : str = " random "
) - > List [ RemoteSpanInfo ] :
"""
Form a sequence of remote servers that collectively serve all consecutive layers
: param start_index : optional index of the first module in a sequence , default = the first of block_uids
: param end_index : optional index of the last module ( non - inclusive ) , default = after last of block uids
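: param mode : span selection strategy , either ' random ' ( pick uniformly among the spans that contain the current block ) or ' fastest ' ( pick at random with probability proportional to the number of blocks a span covers from the current index to its end , which favours longer spans ) ; any other value raises RuntimeError during span selection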
"""
if not self . is_alive ( ) :
logger . error ( " Using a sequence manager that is not running: it has either crashed or never started " )
@ -110,7 +114,14 @@ class RemoteSequenceManager:
current_index = start_index
while current_index < end_index :
candidate_spans = self . sequence_info . spans_containing_block [ current_index ]
chosen_span = random . choice ( candidate_spans ) # TODO this should be replaced with proper load balancing
if mode == " random " :
chosen_span = random . choice ( candidate_spans ) # TODO this should be replaced with proper load balancing
elif mode == " fastest " :
# note: this too is a heuristic that will be replaced once we integrate fastest wall time routing
span_weights = np . array ( [ span . end - current_index for span in candidate_spans ] , dtype = np . float64 )
chosen_span = np . random . choice ( candidate_spans , p = span_weights / span_weights . sum ( ) )
else :
raise RuntimeError ( f " Unexpected mode { mode } " )
assert chosen_span . start < = current_index < chosen_span . end
span_sequence . append ( RemoteSpanInfo ( start = current_index , end = chosen_span . end , peer_id = chosen_span . peer_id ) )