@ -10,7 +10,6 @@ from typing import Dict, List, Optional, Sequence, Union
import numpy as np
import numpy as np
import psutil
import psutil
import requests
import torch
import torch
from hivemind import DHT , MAX_DHT_TIME_DISCREPANCY_SECONDS , BatchTensorDescriptor , get_dht_time
from hivemind import DHT , MAX_DHT_TIME_DISCREPANCY_SECONDS , BatchTensorDescriptor , get_dht_time
from hivemind . moe . server . layers import add_custom_models_from_file
from hivemind . moe . server . layers import add_custom_models_from_file
@ -28,6 +27,7 @@ from petals.server.backend import TransformerBackend
from petals . server . block_utils import get_block_size
from petals . server . block_utils import get_block_size
from petals . server . handler import TransformerConnectionHandler
from petals . server . handler import TransformerConnectionHandler
from petals . server . memory_cache import MemoryCache
from petals . server . memory_cache import MemoryCache
from petals . server . reachability import check_reachability
from petals . server . throughput import get_host_throughput
from petals . server . throughput import get_host_throughput
from petals . utils . convert_block import check_device_balance , convert_block
from petals . utils . convert_block import check_device_balance , convert_block
from petals . utils . disk_cache import DEFAULT_CACHE_DIR
from petals . utils . disk_cache import DEFAULT_CACHE_DIR
@ -78,6 +78,8 @@ class Server:
load_in_8bit : Optional [ bool ] = None ,
load_in_8bit : Optional [ bool ] = None ,
tensor_parallel_devices : Optional [ Sequence [ torch . device ] ] = None ,
tensor_parallel_devices : Optional [ Sequence [ torch . device ] ] = None ,
skip_reachability_check : bool = False ,
skip_reachability_check : bool = False ,
use_relay : bool = True ,
use_auto_relay : bool = True ,
* * kwargs ,
* * kwargs ,
) :
) :
""" Create a server with one or more bloom blocks. See run_server.py for documentation. """
""" Create a server with one or more bloom blocks. See run_server.py for documentation. """
@ -117,14 +119,20 @@ class Server:
)
)
self . module_uids = [ f " { self . prefix } . { block_index } " for block_index in range ( self . block_config . n_layer ) ]
self . module_uids = [ f " { self . prefix } . { block_index } " for block_index in range ( self . block_config . n_layer ) ]
self . dht = DHT ( initial_peers = initial_peers , start = True , num_workers = self . block_config . n_layer , * * kwargs )
self . dht = DHT (
initial_peers = initial_peers ,
start = True ,
num_workers = self . block_config . n_layer ,
use_relay = use_relay ,
use_auto_relay = use_auto_relay ,
* * kwargs ,
)
visible_maddrs_str = [ str ( a ) for a in self . dht . get_visible_maddrs ( ) ]
visible_maddrs_str = [ str ( a ) for a in self . dht . get_visible_maddrs ( ) ]
if initial_peers == PUBLIC_INITIAL_PEERS :
if initial_peers == PUBLIC_INITIAL_PEERS :
logger . info ( f " Connecting to the public swarm, peer_id = { self . dht . peer_id } " )
logger . info ( f " Connecting to the public swarm, peer_id = { self . dht . peer_id } " )
if not skip_reachability_check :
self . _check_reachability ( )
else :
else :
logger . info ( f " Running DHT node on { visible_maddrs_str } , initial peers = { initial_peers } " )
logger . info ( f " Running DHT node on { visible_maddrs_str } , initial peers = { initial_peers } " )
self . need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
if device is None :
if device is None :
device = " cuda " if torch . cuda . is_available ( ) else " cpu "
device = " cuda " if torch . cuda . is_available ( ) else " cpu "
@ -196,35 +204,14 @@ class Server:
self . stop = threading . Event ( )
self . stop = threading . Event ( )
def _check_reachability(self):
    """Ask the health.petals.ml monitor whether this server's DHT peer is reachable from the Internet.

    The check is best-effort: if the monitor cannot be queried, returns a non-success HTTP status,
    or replies with something other than the expected JSON object, a warning is logged and the
    check is skipped without raising.

    Raises:
        RuntimeError: if the monitor answered and explicitly reported this peer as unreachable.
    """
    try:
        r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10)
        r.raise_for_status()
        response = r.json()
        # Read the verdict inside the try block: a malformed reply (non-dict JSON, missing key)
        # must be treated like an unavailable monitor instead of crashing server startup
        is_reachable = response["success"]
    except Exception as e:
        logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
        return

    if not is_reachable:
        # This happens only if health.petals.ml is up and explicitly told us that we are unreachable
        # The fields below are only used to build the error text, so tolerate their absence
        details = response.get("message", "(no details provided)")
        public_ip = response.get("your_ip", "<your_public_ip>")
        raise RuntimeError(
            f"Server is not reachable from the Internet:\n\n"
            f"{details}\n\n"
            f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
            f"1. Choose a specific port for the Petals server, for example, 31337.\n"
            f"2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
            f"3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
            f"python -m petals.cli.run_server ... --public_ip {public_ip} --port 31337\n"
            f"4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
        )

    logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon")
def _choose_num_blocks ( self ) - > int :
def _choose_num_blocks ( self ) - > int :
assert (
assert (
self . converted_model_name_or_path == " bigscience/bloom-petals "
self . converted_model_name_or_path == " bigscience/bloom-petals "
) , " If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually "
) , " If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually "
assert self . device . type == " cuda " , " If you run a non-GPU server, please specify --num_blocks manually "
assert self . device . type == " cuda " , (
" GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
" CPU-only servers in the public swarm are discouraged since they are much slower "
)
num_devices = len ( self . tensor_parallel_devices ) if self . tensor_parallel_devices else 1
num_devices = len ( self . tensor_parallel_devices ) if self . tensor_parallel_devices else 1
if num_devices > 1 :
if num_devices > 1 :
@ -287,6 +274,7 @@ class Server:
use_auth_token = self . use_auth_token ,
use_auth_token = self . use_auth_token ,
load_in_8bit = self . load_in_8bit ,
load_in_8bit = self . load_in_8bit ,
tensor_parallel_devices = self . tensor_parallel_devices ,
tensor_parallel_devices = self . tensor_parallel_devices ,
need_reachability_check = self . need_reachability_check ,
start = True ,
start = True ,
)
)
try :
try :
@ -380,6 +368,7 @@ class ModuleContainer(threading.Thread):
use_auth_token : Optional [ str ] ,
use_auth_token : Optional [ str ] ,
load_in_8bit : bool ,
load_in_8bit : bool ,
tensor_parallel_devices : Sequence [ torch . device ] ,
tensor_parallel_devices : Sequence [ torch . device ] ,
need_reachability_check : bool ,
* * kwargs ,
* * kwargs ,
) - > ModuleContainer :
) - > ModuleContainer :
module_uids = [ f " { prefix } . { block_index } " for block_index in block_indices ]
module_uids = [ f " { prefix } . { block_index } " for block_index in block_indices ]
@ -433,6 +422,9 @@ class ModuleContainer(threading.Thread):
min_batch_size = min_batch_size ,
min_batch_size = min_batch_size ,
max_batch_size = max_batch_size ,
max_batch_size = max_batch_size ,
)
)
if need_reachability_check :
check_reachability ( dht . peer_id )
except :
except :
logger . debug ( " Shutting down backends " )
logger . debug ( " Shutting down backends " )
for backend in blocks . values ( ) :
for backend in blocks . values ( ) :