Prefer longer servers for fine-tuning, exclude unreachable (#448)

We choose longer servers to minimize the number of hops but leave some randomization to distribute the load. We also exclude servers known to be unreachable.
amd-gpus^2
Alexander Borzunov 9 months ago committed by GitHub
parent 00d48dcbe1
commit 2a150770a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -50,7 +50,7 @@ class SequenceManagerConfig:
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
max_pinged: int = 5 # max servers to ping from each sequence side, per update
max_pinged: int = 3 # max servers to ping from each sequence side, per update
ping_timeout: float = 2 # max time to wait for pings, per update
@ -293,6 +293,8 @@ class RemoteSequenceManager:
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
client_server_rtts = self.ping_aggregator.to_dict()
span_sequence = []
current_index = start_index
while current_index < end_index:
@ -300,7 +302,13 @@ class RemoteSequenceManager:
if not candidate_spans:
raise MissingBlocksError(current_index)
span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
# We choose longer servers to minimize the number of hops but leave some randomization
# to distribute the load. We also exclude servers known to be unreachable.
eps = 1e-6
span_weights = np.array(
[span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
dtype=np.float64,
)
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
assert chosen_span.start <= current_index < chosen_span.end
@ -361,9 +369,13 @@ class RemoteSequenceManager:
self.state.sequence_info.update_(new_block_infos)
first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
middle_servers = [
span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans
]
last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
pinged_servers = set(sample_up_to(middle_servers, self.config.max_pinged))
pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)

Loading…
Cancel
Save