@ -83,7 +83,7 @@ class Server:
quant_type : Optional [ QuantType ] = None ,
tensor_parallel_devices : Optional [ Sequence [ torch . device ] ] = None ,
skip_reachability_check : bool = False ,
dht_client_mode : Optional [ bool ] = None ,
reachable_via_relay : Optional [ bool ] = None ,
use_relay : bool = True ,
use_auto_relay : bool = True ,
adapters : Sequence [ str ] = ( ) ,
@ -129,20 +129,20 @@ class Server:
for block_index in range ( self . block_config . num_hidden_layers )
]
if dht_client_mode is None :
if reachable_via_relay is None :
is_reachable = check_direct_reachability ( initial_peers = initial_peers , use_relay = False , * * kwargs )
dht_client_mode = is_reachable is False # if could no t check reachability (returns None), run a full peer
logger . info ( f " This server is accessible { ' via relays ' if dht_client_mode else ' directly ' } " )
reachable_via_relay = is_reachable is False # if can' t check reachability (returns None), run a full peer
logger . info ( f " This server is accessible { ' via relays ' if reachable_via_relay else ' directly ' } " )
self . dht = DHT (
initial_peers = initial_peers ,
start = True ,
num_workers = self . block_config . num_hidden_layers ,
use_relay = use_relay ,
use_auto_relay = use_auto_relay ,
client_mode = dht_client_mode ,
client_mode = reachable_via_relay ,
* * kwargs ,
)
self . reachability_protocol = ReachabilityProtocol . attach_to_dht ( self . dht ) if not dht_client_mode else None
self . reachability_protocol = ReachabilityProtocol . attach_to_dht ( self . dht ) if not reachable_via_relay else None
visible_maddrs_str = [ str ( a ) for a in self . dht . get_visible_maddrs ( ) ]
if initial_peers == PUBLIC_INITIAL_PEERS :
@ -227,6 +227,7 @@ class Server:
num_blocks = num_blocks ,
quant_type = quant_type ,
tensor_parallel_devices = self . tensor_parallel_devices ,
reachable_via_relay = reachable_via_relay ,
force_eval = ( throughput == " eval " ) ,
cache_dir = cache_dir ,
)
@ -239,7 +240,7 @@ class Server:
adapters = tuple ( adapters ) ,
torch_dtype = str ( torch_dtype ) . replace ( " torch. " , " " ) ,
quant_type = quant_type . name . lower ( ) ,
using_relay = self . dht . client_mode ,
using_relay = reachable_via_relay ,
* * throughput_info ,
)