@ -9,7 +9,6 @@ import time
from typing import Dict , List , Optional , Sequence , Union
from typing import Dict , List , Optional , Sequence , Union
import numpy as np
import numpy as np
import psutil
import torch
import torch
from hivemind import DHT , MAX_DHT_TIME_DISCREPANCY_SECONDS , BatchTensorDescriptor , get_dht_time
from hivemind import DHT , MAX_DHT_TIME_DISCREPANCY_SECONDS , BatchTensorDescriptor , get_dht_time
from hivemind . moe . server . layers import add_custom_models_from_file
from hivemind . moe . server . layers import add_custom_models_from_file
@ -28,7 +27,7 @@ from petals.server.block_utils import get_block_size
from petals . server . handler import TransformerConnectionHandler
from petals . server . handler import TransformerConnectionHandler
from petals . server . memory_cache import MemoryCache
from petals . server . memory_cache import MemoryCache
from petals . server . reachability import check_reachability
from petals . server . reachability import check_reachability
from petals . server . throughput import get_ host_throughput
from petals . server . throughput import get_ dtype_name, get_ host_throughput
from petals . utils . convert_block import check_device_balance , convert_block
from petals . utils . convert_block import check_device_balance , convert_block
from petals . utils . disk_cache import DEFAULT_CACHE_DIR
from petals . utils . disk_cache import DEFAULT_CACHE_DIR
@ -146,12 +145,6 @@ class Server:
assert torch_dtype in DTYPE_MAP . values ( ) , f " torch_dtype must be one of { list ( DTYPE_MAP . values ( ) ) } "
assert torch_dtype in DTYPE_MAP . values ( ) , f " torch_dtype must be one of { list ( DTYPE_MAP . values ( ) ) } "
self . torch_dtype = torch_dtype
self . torch_dtype = torch_dtype
if load_in_8bit is None :
load_in_8bit = device . type == " cuda "
if load_in_8bit :
logger . info ( " Model weights will be loaded in 8-bit format " )
self . load_in_8bit = load_in_8bit
if tensor_parallel_devices is None :
if tensor_parallel_devices is None :
tensor_parallel_devices = ( device , )
tensor_parallel_devices = ( device , )
self . tensor_parallel_devices = tuple ( map ( torch . device , tensor_parallel_devices ) )
self . tensor_parallel_devices = tuple ( map ( torch . device , tensor_parallel_devices ) )
@ -159,6 +152,17 @@ class Server:
logger . info ( f " Model weights will be split between { ' , ' . join ( tensor_parallel_devices ) } " )
logger . info ( f " Model weights will be split between { ' , ' . join ( tensor_parallel_devices ) } " )
check_device_balance ( self . tensor_parallel_devices )
check_device_balance ( self . tensor_parallel_devices )
if load_in_8bit is None :
load_in_8bit = device . type == " cuda "
if load_in_8bit and len ( self . tensor_parallel_devices ) > 1 :
load_in_8bit = False
logger . warning (
" Tensor parallelism doesn ' t work properly with 8-bit weights yet, loading weights in 16-bit. "
" You can explicitly set `--load_in_8bit True` to override this "
)
self . load_in_8bit = load_in_8bit
logger . info ( f " Model weights will be loaded in { get_dtype_name ( torch_dtype , load_in_8bit ) } format " )
assert num_blocks is None or block_indices is None , " Please specify num_blocks or block_indices, not both "
assert num_blocks is None or block_indices is None , " Please specify num_blocks or block_indices, not both "
if num_blocks is None and block_indices is None :
if num_blocks is None and block_indices is None :
num_blocks = self . _choose_num_blocks ( )
num_blocks = self . _choose_num_blocks ( )
@ -167,8 +171,7 @@ class Server:
first_block_index , last_block_index = block_indices . split ( " : " )
first_block_index , last_block_index = block_indices . split ( " : " )
first_block_index , last_block_index = map ( int , map ( str . strip , ( first_block_index , last_block_index ) ) )
first_block_index , last_block_index = map ( int , map ( str . strip , ( first_block_index , last_block_index ) ) )
except Exception as e :
except Exception as e :
logger . error ( f " Failed to parse --block_indices ( { e } ), must be start:end (e.g. 0:18) " )
raise ValueError ( f " Failed to parse `--block_indices { block_indices } `, must be start:end (e.g. 0:18) " )
raise
block_indices = range ( first_block_index , last_block_index )
block_indices = range ( first_block_index , last_block_index )
num_blocks = len ( block_indices )
num_blocks = len ( block_indices )
self . strict_block_indices , self . num_blocks = block_indices , num_blocks
self . strict_block_indices , self . num_blocks = block_indices , num_blocks
@ -301,10 +304,6 @@ class Server:
del self . module_container
del self . module_container
gc . collect ( ) # In particular, this closes unused file descriptors
gc . collect ( ) # In particular, this closes unused file descriptors
cur_proc = psutil . Process ( )
num_fds = [ proc . num_fds ( ) for proc in [ cur_proc ] + cur_proc . children ( recursive = True ) ]
logger . info ( f " Cleaning up, left { sum ( num_fds ) } open file descriptors " )
if self . device . type == " cuda " :
if self . device . type == " cuda " :
torch . cuda . empty_cache ( )
torch . cuda . empty_cache ( )