@ -1,54 +1,23 @@
from dataclasses import dataclass
from typing import Dict , List , Optional , Tuple
from typing import Dict , List
import numpy as np
from hivemind import PeerID , get_logger
from petals . data_structures import RemoteModuleInfo , ServerState
__all__ = [ " choose_best_blocks " , " should_choose_other_blocks " ]
from petals . data_structures import RemoteModuleInfo , RemoteSpanInfo , ServerState
from petals . utils . dht import compute_spans
logger = get_logger ( __name__ )
@dataclass
class Span :
start : int
end : int
throughput : float
state : ServerState
@property
def length ( self ) :
return self . end - self . start
def move_to ( self , new_start : int ) - > None :
self . start , self . end = new_start , new_start + self . length
def compute_spans ( module_infos : List [ Optional [ RemoteModuleInfo ] ] ) - > Tuple [ Dict [ PeerID , Span ] , np . ndarray ] :
spans = { }
throughputs = np . zeros ( len ( module_infos ) )
for block , module in enumerate ( module_infos ) :
if module is None :
continue
# We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
# If the order were not defined, we would get slightly different values due to floating point errors,
# which may cause excess block replacements.
for peer_id , server in sorted ( module . servers . items ( ) ) :
if server . state == ServerState . OFFLINE :
continue
def compute_throughputs ( spans : Dict [ PeerID , RemoteSpanInfo ] , * , total_blocks : int ) - > np . ndarray :
# We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
# If the order were not defined, we would get slightly different values due to floating point errors,
# which may cause excess block replacements.
if peer_id in spans :
spans [ peer_id ] . start = min ( spans [ peer_id ] . start , block )
spans [ peer_id ] . end = max ( spans [ peer_id ] . start , block + 1 )
else :
spans [ peer_id ] = Span ( start = block , end = block + 1 , throughput = server . throughput , state = server . state )
throughputs [ block ] + = server . throughput
return spans , throughputs
throughputs = np . zeros ( total_blocks )
for span in sorted ( spans . values ( ) , key = lambda span : span . peer_id ) :
throughputs [ span . start : span . end ] + = span . throughput
return throughputs
def _choose_best_start ( throughputs : np . ndarray , num_blocks : int ) - > int :
@ -56,19 +25,26 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
return min ( options ) [ - 1 ]
def choose_best_blocks ( num_blocks : int , module_infos : List [ Optional [ RemoteModuleInfo ] ] ) - > List [ int ] :
_ , throughputs = compute_spans ( module_infos )
def choose_best_blocks ( num_blocks : int , module_infos : List [ RemoteModuleInfo ] ) - > List [ int ] :
spans = compute_spans ( module_infos , min_state = ServerState . JOINING )
throughputs = compute_throughputs ( spans , total_blocks = len ( module_infos ) )
start = _choose_best_start ( throughputs , num_blocks )
return list ( range ( start , start + num_blocks ) )
def _move_span ( span : RemoteSpanInfo , new_start : int ) :
span . start , span . end = new_start , new_start + span . length
def should_choose_other_blocks (
local_peer_id : PeerID , module_infos : List [ Optional [ RemoteModuleInfo ] ] , balance_quality : float
local_peer_id : PeerID , module_infos : List [ RemoteModuleInfo] , balance_quality : float
) - > bool :
if balance_quality > 1.0 :
return True # Forces rebalancing on each check (may be used for debugging purposes)
spans , throughputs = compute_spans ( module_infos )
spans = compute_spans ( module_infos , min_state = ServerState . JOINING )
throughputs = compute_throughputs ( spans , total_blocks = len ( module_infos ) )
initial_throughput = throughputs . min ( )
eps = 1e-3
@ -88,7 +64,7 @@ def should_choose_other_blocks(
return False # This server is on its best place already
throughputs [ local_span . start : local_span . end ] + = local_span . throughput * eps
local_span. move_to ( new_start )
_move_span( local_span , new_start )
throughputs [ local_span . start : local_span . end ] + = local_span . throughput
moved = True
@ -105,7 +81,7 @@ def should_choose_other_blocks(
throughputs [ span . start : span . end ] + = span . throughput * eps
if span . start != new_start :
span. move_to ( new_start )
_move_span( span , new_start )
moved = True
throughputs [ span . start : span . end ] + = span . throughput