Add `blocked_servers` argument (#462)

Should be used as:

```python
model = AutoDistributedModelForCausalLM(model_name, blocked_servers=[peer_id1, peer_id2])
```
pull/445/head^2
Alexander Borzunov 9 months ago committed by GitHub
parent 722c4dc496
commit 329f7d31e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,7 @@ import logging
import random import random
import threading import threading
import time import time
from typing import Any, Collection, Dict, List, Optional, Sequence, Set, Union
from weakref import WeakMethod from weakref import WeakMethod
import dijkstar import dijkstar
@ -38,6 +38,7 @@ class SequenceManagerConfig:
show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True] show_route: Union[str, bool] = "inference" # show chosen route through servers. one of [False, "inference", True]
allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers allowed_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, send requests only to these servers
blocked_servers: Optional[Collection[Union[PeerID, str]]] = None # if defined, do not use these servers
use_server_to_server: bool = True # Use direct server-to-server communication use_server_to_server: bool = True # Use direct server-to-server communication
connect_timeout: float = 5 # timeout for opening a connection connect_timeout: float = 5 # timeout for opening a connection
@ -116,6 +117,9 @@ class RemoteSequenceManager:
self._thread_start_lock = threading.Lock() self._thread_start_lock = threading.Lock()
self.policy = NoSpendingPolicy() self.policy = NoSpendingPolicy()
self.allowed_servers = self._peer_ids_to_set(config.allowed_servers)
self.blocked_servers = self._peer_ids_to_set(config.blocked_servers)
self.ping_aggregator = PingAggregator(dht) self.ping_aggregator = PingAggregator(dht)
if state.banned_peers is None: if state.banned_peers is None:
@ -128,6 +132,23 @@ class RemoteSequenceManager:
self._thread.ready.set() # no need to await the first dht fetch self._thread.ready.set() # no need to await the first dht fetch
self._need_latest_infos = True self._need_latest_infos = True
@staticmethod
def _peer_ids_to_set(peer_ids: Optional[Collection[Union[PeerID, str]]]) -> Optional[Set[PeerID]]:
if peer_ids is None:
return None
result = set()
for peer_id in peer_ids:
if isinstance(peer_id, PeerID):
result.add(peer_id)
elif isinstance(peer_id, str):
result.add(PeerID.from_base58(peer_id))
else:
raise TypeError(
f"`allowed_servers` and `blocked_servers` have to contain only PeerIDs or strings, but got {type(peer_id)}"
)
return result
def make_sequence( def make_sequence(
self, self,
start_index: int = 0, start_index: int = 0,
@ -341,13 +362,13 @@ class RemoteSequenceManager:
if not block_info: if not block_info:
continue continue
# Apply allow and block lists
block_info.servers = {
    peer_id: server_info
    for peer_id, server_info in block_info.servers.items()
    if (self.allowed_servers is None or peer_id in self.allowed_servers)
    and (self.blocked_servers is None or peer_id not in self.blocked_servers)
}
# Remove temporarily banned peers, unless there are no peers left # Remove temporarily banned peers, unless there are no peers left
valid_servers = { valid_servers = {

@ -43,7 +43,7 @@ def test_remote_sequential():
assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3) assert torch.allclose(second_half_outputs, full_outputs, atol=1e-3)
(second_half_outputs * grad_proj).sum().backward() (second_half_outputs * grad_proj).sum().backward()
assert torch.allclose(test_inputs.grad, full_grad, atol=3e-2)
# test RemoteSequential with lossy compression # test RemoteSequential with lossy compression
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)] block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]

Loading…
Cancel
Save