Add `blocked_servers` argument (#462)

Should be used as:

```python
model = AutoDistributedModelForCausalLM(model_name, blocked_servers=[peer_id1, peer_id2])
```
pull/445/head^2
Alexander Borzunov 9 months ago committed by GitHub
parent 722c4dc496
commit 329f7d31e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,7 @@ import logging
import random
import threading
import time
from typing import Any, Collection, Dict, List, Optional, Sequence, Union
from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Union
from weakref import WeakMethod
import dijkstar
@ -38,6 +38,7 @@ class SequenceManagerConfig:
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
blocked_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, do not use these servers
use_server_to_server: bool = True # Use direct server-to-server communication
connect_timeout: float = 5 # timeout for opening a connection
@ -116,6 +117,9 @@ class RemoteSequenceManager:
self._thread_start_lock = threading.Lock()
self.policy = NoSpendingPolicy()
self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)
self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)
self.ping_aggregator = PingAggregator(dht)
if state.banned_peers is None:
@ -128,6 +132,23 @@ class RemoteSequenceManager:
self._thread.ready.set() # no need to await the first dht fetch
self._need_latest_infos = True
@staticmethod
def _peer_ids_to_set(peer_ids: Optional[Collection[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
if peer_ids is None:
return None
result = set()
for peer_id in peer_ids:
if isinstance(peer_id, PeerID):
result.add(peer_id)
elif isinstance(peer_id, str):
result.add(PeerID.from_base58(peer_id))
else:
raise TypeError(
f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}"
)
return result
def make_sequence(
self,
start_index: int = 0,
@ -341,13 +362,13 @@ class RemoteSequenceManager:
if not block_info:
continue
# Apply whitelist, if defined
if self.config.allowed_servers is not None:
block_info.servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
}
# Apply allow and block lists
block_info.servers = {
peer_id: server_info
for peer_id, server_info in block_info.servers.items()
if (self.allowed_servers is None or peer_id in self.allowed_servers)
and (self.blocked_servers is None or peer_id not in self.blocked_servers)
}
# Remove temporarily banned peers, unless there are no peers left
valid_servers = {

@ -43,7 +43,7 @@ def test_remote_sequential():
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
(second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad, atol=1e-2)
assert torch.allclose(test_inputs.grad, full_grad, atol=3e-2)
# test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]

Loading…
Cancel
Save