@@ -64,14 +64,14 @@ class Server:
expiration : Optional [ float ] = None ,
expiration : Optional [ float ] = None ,
request_timeout : float = 3 * 60 ,
request_timeout : float = 3 * 60 ,
session_timeout : float = 30 * 60 ,
session_timeout : float = 30 * 60 ,
step_timeout : float = 5 * 60,
step_timeout : float = 60,
prefetch_batches : int = 1 ,
prefetch_batches : int = 1 ,
sender_threads : int = 1 ,
sender_threads : int = 1 ,
balance_quality : float = 0.75 ,
balance_quality : float = 0.75 ,
mean_balance_check_period : float = 60 ,
mean_balance_check_period : float = 60 ,
mean_block_selection_delay : float = 0.5 ,
mean_block_selection_delay : float = 0.5 ,
use_auth_token : Optional [ str ] = None ,
use_auth_token : Optional [ str ] = None ,
load_in_8bit : bool = False ,
load_in_8bit : Optional [ bool ] = None ,
* * kwargs ,
* * kwargs ,
) :
) :
""" Create a server with one or more bloom blocks. See run_server.py for documentation. """
""" Create a server with one or more bloom blocks. See run_server.py for documentation. """
@@ -81,12 +81,10 @@ class Server:
self . min_batch_size , self . max_batch_size = min_batch_size , max_batch_size
self . min_batch_size , self . max_batch_size = min_batch_size , max_batch_size
self . inference_max_length = inference_max_length
self . inference_max_length = inference_max_length
self . cache_dir = cache_dir
self . cache_dir = cache_dir
self . attn_cache_size = attn_cache_size
self . compression = compression
self . compression = compression
self . stats_report_interval , self . update_period = stats_report_interval , update_period
self . stats_report_interval , self . update_period = stats_report_interval , update_period
self . prefetch_batches , self . sender_threads = prefetch_batches , sender_threads
self . prefetch_batches , self . sender_threads = prefetch_batches , sender_threads
self . use_auth_token = use_auth_token
self . use_auth_token = use_auth_token
self . load_in_8bit = load_in_8bit
if custom_module_path is not None :
if custom_module_path is not None :
add_custom_models_from_file ( custom_module_path )
add_custom_models_from_file ( custom_module_path )
@@ -114,15 +112,16 @@ class Server:
else :
else :
logger . info ( f " Running DHT node on { visible_maddrs_str } , initial peers = { initial_peers } " )
logger . info ( f " Running DHT node on { visible_maddrs_str } , initial peers = { initial_peers } " )
device = device or ( " cuda " if torch . cuda . is_available ( ) else " cpu " )
if device is None :
device = " cuda " if torch . cuda . is_available ( ) else " cpu "
device = torch . device ( device )
self . device = device
self . device = device
self . memory_cache = MemoryCache ( device , attn_cache_size , alloc_timeout )
if load_in_8bit is None :
load_in_8bit = device . type == " cuda "
if isinstance ( torch_dtype , str ) :
if load_in_8bit :
torch_dtype = DTYPE_MAP [ torch_dtype ]
logger . info ( " Model weights will be loaded in 8-bit format " )
assert torch_dtype in DTYPE_MAP . values ( ) , f " torch_dtype must be one of { list ( DTYPE_MAP . values ( ) ) } "
self . load_in_8bit = load_in_8bit
self . torch_dtype = torch_dtype
self . block_config = BloomConfig . from_pretrained (
self . block_config = BloomConfig . from_pretrained (
converted_model_name_or_path ,
converted_model_name_or_path ,
@@ -131,13 +130,6 @@ class Server:
)
)
self . module_uids = [ f " { self . prefix } . { block_index } " for block_index in range ( self . block_config . n_layer ) ]
self . module_uids = [ f " { self . prefix } . { block_index } " for block_index in range ( self . block_config . n_layer ) ]
assert isinstance ( throughput , float ) or throughput in [ " auto " , " eval " ]
if throughput in [ " auto " , " eval " ] :
throughput = get_host_throughput (
self . block_config , device , torch_dtype , load_in_8bit = load_in_8bit , force_eval = ( throughput == " eval " )
)
self . throughput = throughput
assert ( block_indices is None ) != ( num_blocks is None ) , " please specify num_blocks or block_indices, not both "
assert ( block_indices is None ) != ( num_blocks is None ) , " please specify num_blocks or block_indices, not both "
if block_indices is not None :
if block_indices is not None :
try :
try :
@@ -147,7 +139,28 @@ class Server:
logger . error ( f " Failed to parse --block_indices ( { e } ), must be start:end (e.g. 0:18) " )
logger . error ( f " Failed to parse --block_indices ( { e } ), must be start:end (e.g. 0:18) " )
raise
raise
block_indices = range ( first_block_index , last_block_index )
block_indices = range ( first_block_index , last_block_index )
num_blocks = len ( block_indices )
self . strict_block_indices , self . num_blocks = block_indices , num_blocks
self . strict_block_indices , self . num_blocks = block_indices , num_blocks
gib = 1024 * * 3
if attn_cache_size is None :
# Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
attn_cache_size = 0.5 * gib * num_blocks * self . block_config . hidden_size / 14336
logger . info ( f " Attention cache for all blocks will consume up to { attn_cache_size / gib : .2f } GiB " )
self . memory_cache = MemoryCache ( device , attn_cache_size , alloc_timeout )
if isinstance ( torch_dtype , str ) :
torch_dtype = DTYPE_MAP [ torch_dtype ]
assert torch_dtype in DTYPE_MAP . values ( ) , f " torch_dtype must be one of { list ( DTYPE_MAP . values ( ) ) } "
self . torch_dtype = torch_dtype
assert isinstance ( throughput , float ) or throughput in [ " auto " , " eval " ]
if throughput in [ " auto " , " eval " ] :
throughput = get_host_throughput (
self . block_config , device , torch_dtype , load_in_8bit = load_in_8bit , force_eval = ( throughput == " eval " )
)
self . throughput = throughput
self . balance_quality = balance_quality
self . balance_quality = balance_quality
self . mean_balance_check_period = mean_balance_check_period
self . mean_balance_check_period = mean_balance_check_period
self . mean_block_selection_delay = mean_block_selection_delay
self . mean_block_selection_delay = mean_block_selection_delay
@@ -213,7 +226,6 @@ class Server:
def _choose_blocks ( self ) - > List [ int ] :
def _choose_blocks ( self ) - > List [ int ] :
if self . strict_block_indices is not None :
if self . strict_block_indices is not None :
return self . strict_block_indices
return self . strict_block_indices
assert self . num_blocks is not None
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
# this delay decreases the probability of a race condition while choosing the best blocks to serve.