@ -1,17 +1,15 @@
import dataclasses
import time
from typing import Iterable , List , Optional , Sequence, Tuple, Type , TypeVar
from typing import Iterable , List , Optional , Tuple
from hivemind import get_logger
from petals . data_structures import ModuleUID , RemoteModuleInfo , RemoteSpanInfo , ServerState
from petals . utils . dht import compute_spans
logger = get_logger ( __name__ )
T = TypeVar ( " T " )
@dataclasses.dataclass
class RemoteSequenceInfo :
"""
@ -30,7 +28,7 @@ class RemoteSequenceInfo:
last_updated_time : Optional [ float ]
@classmethod
def make_empty ( cls : Type [ T ] , block_uids : Iterable [ ModuleUID ] ) - > T :
def make_empty ( cls , block_uids : Iterable [ ModuleUID ] ) - > " RemoteSequenceInfo " :
block_uids = tuple ( block_uids )
empty_block_infos = tuple ( RemoteModuleInfo ( uid , { } ) for uid in block_uids )
empty_spans = tuple ( [ ] for _ in range ( len ( block_uids ) ) )
@ -39,7 +37,7 @@ class RemoteSequenceInfo:
def __getitem__(self, ix: slice):
    """
    Return a new RemoteSequenceInfo restricted to the blocks selected by ``ix``.

    :param ix: a slice over block positions; any other index type fails the assertion below
    :returns: a RemoteSequenceInfo holding the sliced ``block_uids`` and ``block_infos``, span indices
        rebuilt by ``_sort_spans`` from the sliced infos, and this object's ``last_updated_time``
    """
    assert isinstance(ix, slice)
    block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
    # Span indices are rebuilt from the sliced infos instead of being copied from self.
    spans_by_priority, spans_containing_block = self._sort_spans(block_infos)
    return RemoteSequenceInfo(
        block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
    )
@ -47,60 +45,23 @@ class RemoteSequenceInfo:
def __len__ ( self ) :
return len ( self . block_uids )
def update_ ( self , new_block_infos : List [ Optional[ RemoteModuleInfo] ] ) :
def update_ ( self , new_block_infos : List [ RemoteModuleInfo] ) :
assert len ( new_block_infos ) == len ( self . block_uids )
for block_index , ( uid , info ) in enumerate ( zip ( self . block_uids , new_block_infos ) ) :
if info is None :
logger . debug ( f " Found no block info for block { uid } " )
continue
if not isinstance ( info , RemoteModuleInfo ) :
logger . warning ( f " Unexpected dht entry type for { uid } : { info } " )
continue
if not info . servers :
logger . debug ( f " Found no active peers for block { uid } " )
continue
if info . uid != uid :
logger . warning ( f " The DHT entry for { uid } actually points to { info . uid } " )
continue
assert uid == info . uid , f " The DHT entry for { uid } actually points to { info . uid } "
self . block_infos [ block_index ] . servers = info . servers
self . spans_by_priority , self . spans_containing_block = self . compute _spans( self . block_infos )
self . spans_by_priority , self . spans_containing_block = self . _sort_spans ( self . block_infos )
self . last_updated_time = time . perf_counter ( )
@staticmethod
def _sort_spans(block_infos: List[RemoteModuleInfo]):
    """
    Build the two span indices that RemoteSequenceInfo stores for the given blocks.

    :param block_infos: per-block server info, in sequence order
    :returns: a pair ``(spans_by_priority, spans_containing_block)``: ``spans_by_priority`` is a list of
        spans sorted by ``span.length``, longest first; ``spans_containing_block`` is a tuple with one
        list per block, holding every span whose ``[start, end)`` range covers that block, in the same
        priority order
    """
    # compute_spans is the module-level import from petals.utils.dht; only the values of its result are used.
    # NOTE(review): presumably the result is keyed by peer id and min_state drops servers below ONLINE - confirm.
    spans_by_priority = list(compute_spans(block_infos, min_state=ServerState.ONLINE).values())
    spans_by_priority.sort(key=lambda span: span.length, reverse=True)

    spans_containing_block = tuple([] for _ in range(len(block_infos)))
    for span in spans_by_priority:
        for block_index in range(span.start, span.end):
            spans_containing_block[block_index].append(span)

    return spans_by_priority, spans_containing_block