List[Optional[RemoteModuleInfo]] -> List[RemoteModuleInfo] everywhere

pull/510/head
Aleksandr Borzunov 9 months ago
parent 2608b5cf8d
commit dcd4641b82
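
Context for the change: `get_remote_module_infos` now returns one `RemoteModuleInfo` per requested UID, with an empty `servers` dict standing in for "no data", so downstream code can drop its `Optional`/`None` checks. A minimal sketch of the new contract, using a simplified stand-in class rather than the real petals one:

    # Minimal sketch of the new contract. RemoteModuleInfo here is a simplified stand-in
    # (the real class carries richer ServerInfo values): "no data for this block" is an
    # empty servers dict, never None.
    from dataclasses import dataclass, field
    from typing import Dict


    @dataclass
    class RemoteModuleInfo:
        uid: str
        servers: Dict[str, str] = field(default_factory=dict)


    def fetch_infos(uids, dht_records):
        # Before: a List[Optional[RemoteModuleInfo]] with None holes; every caller needed `if info is None`.
        # After: one placeholder per UID, filled in only if the DHT returned something for it.
        modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]
        for module_info in modules:
            module_info.servers.update(dht_records.get(module_info.uid, {}))
        return modules


    infos = fetch_infos(["block0", "block1"], {"block0": {"peerA": "online"}})
    assert [len(info.servers) for info in infos] == [1, 0]  # block1 simply has no servers

Blocks with no servers then fall out naturally downstream (iterating an empty dict is a no-op), which is why the explicit checks in `update_` below could be removed.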

@@ -47,50 +47,38 @@ class RemoteSequenceInfo:
     def __len__(self):
         return len(self.block_uids)
 
-    def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
+    def update_(self, new_block_infos: List[RemoteModuleInfo]):
         assert len(new_block_infos) == len(self.block_uids)
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
-            if info is None:
-                logger.debug(f"Found no block info for block {uid}")
-                continue
-            if not isinstance(info, RemoteModuleInfo):
-                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-                continue
-            if not info.servers:
-                logger.debug(f"Found no active peers for block {uid}")
-                continue
-            if info.uid != uid:
-                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-                continue
+            assert uid == info.uid, f"The DHT entry for {uid} actually points to {info.uid}"
             self.block_infos[block_index].servers = info.servers
 
         self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
         self.last_updated_time = time.perf_counter()
 
     @staticmethod
-    def compute_spans(block_infos: Sequence[Optional[RemoteModuleInfo]]):
+    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
         block_offset = parse_uid(block_infos[0].uid)[1] if block_infos else 0
         num_blocks = len(block_infos)
 
         closed_spans = []
         active_spans = {}
         for block_index, info in enumerate(block_infos):
-            if info is not None:
-                for peer_id, server_info in info.servers.items():
-                    if server_info.state != ServerState.ONLINE:
-                        continue
-
-                    if peer_id not in active_spans:
-                        active_spans[peer_id] = RemoteSpanInfo(
-                            peer_id=peer_id,
-                            start=block_index,
-                            end=block_index + 1,
-                            server_info=server_info,
-                        )
-                        if server_info.start_block is not None and server_info.end_block is not None:
-                            active_spans[peer_id].start = max(server_info.start_block - block_offset, 0)
-                            active_spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
-                    else:  # peer_id in active_spans
-                        active_spans[peer_id].end = max(active_spans[peer_id].end, block_index + 1)
+            for peer_id, server_info in info.servers.items():
+                if server_info.state != ServerState.ONLINE:
+                    continue
+
+                if peer_id not in active_spans:
+                    active_spans[peer_id] = RemoteSpanInfo(
+                        peer_id=peer_id,
+                        start=block_index,
+                        end=block_index + 1,
+                        server_info=server_info,
+                    )
+                    if server_info.start_block is not None and server_info.end_block is not None:
+                        active_spans[peer_id].start = max(server_info.start_block - block_offset, 0)
+                        active_spans[peer_id].end = min(server_info.end_block - block_offset, num_blocks)
+                else:  # peer_id in active_spans
+                    active_spans[peer_id].end = max(active_spans[peer_id].end, block_index + 1)
 
             for peer_id in list(active_spans.keys()):
                 if (
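
For reference, a toy illustration of the span bookkeeping above, under an assumed simplification: plain (start, end) pairs instead of `RemoteSpanInfo`, and no `start_block`/`end_block` hints. Contiguous blocks announced by the same online peer merge into one span, mirroring the `ServerState.ONLINE` check:

    def toy_compute_spans(block_servers):
        """block_servers: list indexed by block; each item maps peer_id -> 'online'/'offline'."""
        active, closed = {}, []
        for block_index, servers in enumerate(block_servers):
            for peer_id, state in servers.items():
                if state != "online":
                    continue
                if peer_id not in active:
                    active[peer_id] = [block_index, block_index + 1]  # open a new span [start, end)
                else:
                    active[peer_id][1] = max(active[peer_id][1], block_index + 1)  # extend it
            # close spans of peers that stopped serving at this block
            for peer_id in list(active):
                if servers.get(peer_id) != "online":
                    closed.append((peer_id, tuple(active.pop(peer_id))))
        closed.extend((peer_id, tuple(span)) for peer_id, span in active.items())
        return closed


    spans = toy_compute_spans([{"A": "online"}, {"A": "online", "B": "offline"}, {"A": "online"}])
    assert spans == [("A", (0, 3))]  # A's contiguous blocks 0..2 form a single span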

@@ -346,9 +346,6 @@ class RemoteSequenceManager:
         )
 
         for block_info in new_block_infos:
-            if not block_info:
-                continue
-
             # Apply allow and block lists
             block_info.servers = {
                 peer_id: server_info
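
Since `new_block_infos` can no longer contain `None`, the loop applies the allow and block lists directly. A hedged sketch of that filtering step as a standalone function; the `allowed`/`blocked` parameter names are placeholders for the sequence manager's actual configuration, not its real attributes:

    def filter_servers(servers, allowed=None, blocked=frozenset()):
        """Keep peers that pass the allow list (if one is set) and are not block-listed."""
        return {
            peer_id: server_info
            for peer_id, server_info in servers.items()
            if (allowed is None or peer_id in allowed) and peer_id not in blocked
        }


    servers = {"peerA": "infoA", "peerB": "infoB"}
    assert filter_servers(servers, blocked={"peerB"}) == {"peerA": "infoA"}
    assert filter_servers(servers, allowed={"peerB"}) == {"peerB": "infoB"}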

@@ -26,16 +26,13 @@ class Span:
         self.start, self.end = new_start, new_start + self.length
 
 
-def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
+def compute_spans(module_infos: List[RemoteModuleInfo]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
     block_offset = parse_uid(module_infos[0].uid)[1] if module_infos else 0
     num_blocks = len(module_infos)
 
     spans = {}
     throughputs = np.zeros(num_blocks)
     for block, module in enumerate(module_infos):
-        if module is None:
-            continue
-
         # We sort servers here to ensure that we get exactly the same throughputs for a given set of servers.
         # If the order were not defined, we would get slightly different values due to floating point errors,
         # which may cause excess block replacements.
@@ -62,14 +59,14 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
     return min(options)[-1]
 
 
-def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
+def choose_best_blocks(num_blocks: int, module_infos: List[RemoteModuleInfo]) -> List[int]:
     _, throughputs = compute_spans(module_infos)
     start = _choose_best_start(throughputs, num_blocks)
     return list(range(start, start + num_blocks))
 
 
 def should_choose_other_blocks(
-    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
+    local_peer_id: PeerID, module_infos: List[RemoteModuleInfo], balance_quality: float
 ) -> bool:
     if balance_quality > 1.0:
         return True  # Forces rebalancing on each check (may be used for debugging purposes)
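
The selection logic these signatures feed into is unchanged: `_choose_best_start` (its tail is visible above as `return min(options)[-1]`) compares the sorted per-block throughputs of every contiguous window of `num_blocks` and picks the smallest, i.e. the least-served region. An illustrative reimplementation under that reading:

    import numpy as np


    def toy_choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
        # For each candidate start, sort the window's per-block throughputs and take the
        # lexicographically smallest window: the region most in need of another server.
        options = (
            (sorted(throughputs[i : i + num_blocks]), i)
            for i in range(0, len(throughputs) - num_blocks + 1)
        )
        return min(options)[-1]


    throughputs = np.array([5.0, 1.0, 0.5, 4.0, 6.0])
    assert toy_choose_best_start(throughputs, num_blocks=2) == 1  # blocks 1..2 are least served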

@@ -759,12 +759,11 @@ class ModuleAnnouncerThread(threading.Thread):
 
     def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
         module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
-        middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
+        middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers}
         pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
         pinged_servers.discard(self.dht.peer_id)
-        if module_infos[-1] is not None:
-            # Sample servers hosting the block after the last one (most likely continuations) separately
-            pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
+        # Sample servers hosting the block after the last one (most likely continuations) separately
+        pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
         self.ping_aggregator.ping(list(pinged_servers))
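
Why the `None` checks could go away here: a block nobody serves now yields an empty `servers` dict, so both the set comprehension and the final `module_infos[-1].servers` access are always safe. A self-contained sketch of the candidate selection; `sample_up_to` below is a stand-in helper (assumed semantics: random sample of at most k items), not the petals implementation:

    import random
    from types import SimpleNamespace


    def sample_up_to(population, k):
        population = list(population)
        return random.sample(population, min(k, len(population)))


    def pick_ping_candidates(module_infos, own_peer_id, max_pinged=5):
        middle_servers = {peer_id for info in module_infos[:-1] for peer_id in info.servers}
        pinged = set(sample_up_to(middle_servers, max_pinged))
        pinged.discard(own_peer_id)
        # Servers hosting the block after the last one are likely continuations: sample them separately
        pinged |= set(sample_up_to(module_infos[-1].servers, max_pinged))
        return pinged


    infos = [SimpleNamespace(servers={"A": 1, "B": 2}), SimpleNamespace(servers={}), SimpleNamespace(servers={"C": 3})]
    assert "C" in pick_ping_candidates(infos, own_peer_id="A", max_pinged=2)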

@@ -70,7 +70,7 @@ def get_remote_module_infos(
     *,
     latest: bool = False,
     return_future: bool = False,
-) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
+) -> Union[List[RemoteModuleInfo], MPFuture]:
     return dht.run_coroutine(
         partial(
             _get_remote_module_infos,
@@ -90,7 +90,7 @@ async def _get_remote_module_infos(
     active_adapter: Optional[str],
     expiration_time: Optional[DHTExpiration],
     latest: bool,
-) -> List[Optional[RemoteModuleInfo]]:
+) -> List[RemoteModuleInfo]:
     if latest:
         assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
         expiration_time = math.inf
@@ -99,14 +99,14 @@ async def _get_remote_module_infos(
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)
 
-    modules: List[Optional[RemoteModuleInfo]] = [None] * len(uids)
-    for i, uid in enumerate(uids):
-        metadata = found[uid]
+    modules = [RemoteModuleInfo(uid=uid, servers={}) for uid in uids]
+    for module_info in modules:
+        metadata = found[module_info.uid]
         if metadata is None or not isinstance(metadata.value, dict):
             if metadata is not None:
-                logger.warning(f"Incorrect metadata for {uid}: {metadata}")
+                logger.warning(f"Incorrect metadata for {module_info.uid}: {metadata}")
             continue
-        servers = {}
+
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
@@ -116,9 +116,7 @@ async def _get_remote_module_infos(
                     logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
                     continue
 
-                servers[peer_id] = server_info
+                module_info.servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
-                logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
-        if servers:
-            modules[i] = RemoteModuleInfo(uid, servers)
+                logger.warning(f"Incorrect peer entry for uid={module_info.uid}, peer_id={peer_id}: {e}")
     return modules
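
The producer side keeps one invariant: exactly one entry per requested UID, with malformed metadata or peer entries skipped with a warning rather than dropping the whole entry. A simplified sketch of that parsing pattern (names and value shapes here are stand-ins, not the petals types):

    import logging
    from dataclasses import dataclass, field
    from typing import Dict

    logger = logging.getLogger(__name__)


    @dataclass
    class ModuleRecord:  # stand-in for RemoteModuleInfo
        uid: str
        servers: Dict[str, dict] = field(default_factory=dict)


    def parse_found(uids, found):
        # Placeholders are created up front; each peer entry either populates
        # module_info.servers or is skipped with a warning.
        modules = [ModuleRecord(uid=uid, servers={}) for uid in uids]
        for module_info in modules:
            metadata = found.get(module_info.uid)
            if not isinstance(metadata, dict):
                if metadata is not None:
                    logger.warning("Incorrect metadata for %s: %r", module_info.uid, metadata)
                continue
            for peer_id, server_info in metadata.items():
                try:
                    module_info.servers[str(peer_id)] = dict(server_info)  # may raise TypeError/ValueError
                except (TypeError, ValueError) as e:
                    logger.warning("Incorrect peer entry for uid=%s, peer_id=%s: %s", module_info.uid, peer_id, e)
        return modules


    mods = parse_found(["b0", "b1"], {"b0": {"peerA": {"state": "online"}}, "b1": "garbage"})
    assert len(mods) == 2 and mods[0].servers and not mods[1].servers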
