Use start_block, end_block if present

pull/510/head
Aleksandr Borzunov 9 months ago
parent 145377c4cc
commit 0f527a0788

@ -79,12 +79,12 @@ class RemoteSequenceInfo:
if peer_id not in active_spans:
active_spans[peer_id] = RemoteSpanInfo(
peer_id=peer_id,
start=block_index,
end=block_index + 1,
start=server_info.get("start_block", block_index),
end=server_info.get("end_block", block_index + 1),
server_info=server_info,
)
else: # peer_id in active_spans
active_spans[peer_id].end = block_index + 1
active_spans[peer_id].end = max(active_spans[peer_id].end, block_index + 1)
for peer_id in list(active_spans.keys()):
if (

@ -55,6 +55,10 @@ class ServerInfo:
cache_tokens_left: Optional[pydantic.conint(ge=0, strict=True)] = None
next_pings: Optional[Dict[str, pydantic.confloat(ge=0, strict=True)]] = None
def get(self, name: str, default: Any = None) -> Any:
value = getattr(self, name)
return value if value is not None else default
def to_tuple(self) -> Tuple[int, float, dict]:
extra_info = dataclasses.asdict(self)
del extra_info["state"], extra_info["throughput"]

@ -42,9 +42,14 @@ def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[
if peer_id in spans:
spans[peer_id].start = min(spans[peer_id].start, block)
spans[peer_id].end = max(spans[peer_id].start, block + 1)
spans[peer_id].end = max(spans[peer_id].end, block + 1)
else:
spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput, state=server.state)
spans[peer_id] = Span(
start=server.get("start_block", block),
end=server.get("end_block", block + 1),
throughput=server.throughput,
state=server.state,
)
throughputs[block] += server.throughput

Loading…
Cancel
Save