@@ -50,18 +50,19 @@ class Server:
         initial_peers: List[str],
         dht_prefix: Optional[str],
         converted_model_name_or_path: str,
+        public_name: Optional[str] = None,
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: int = 8,
+        inference_max_length: Optional[int] = None,
         min_batch_size: int = 1,
-        max_batch_size: int = 2048,
-        inference_max_length: int = 2048,
+        max_batch_size: Optional[int] = None,
+        attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        attn_cache_tokens: int = 8192,
         alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
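
Annotation: the batching and cache knobs above move from fixed integer defaults to the None-sentinel pattern, so the effective default can be chosen per model later in __init__ (see the third hunk). A minimal sketch of the pattern, with hypothetical names:

from typing import Optional

# Hypothetical helper illustrating the None-sentinel pattern:
# None means "not set"; the effective default depends on the model.
def resolve_max_batch_size(max_batch_size: Optional[int] = None, *, is_multiquery_attn: bool = False) -> int:
    if max_batch_size is None:
        # Mirrors the defaults chosen later in __init__
        return 8192 if is_multiquery_attn else 2048
    return max_batch_size

assert resolve_max_batch_size() == 2048
assert resolve_max_batch_size(is_multiquery_attn=True) == 8192
assert resolve_max_batch_size(4096) == 4096
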
@@ -93,8 +94,6 @@ class Server:
         self.converted_model_name_or_path = converted_model_name_or_path

         self.num_handlers = num_handlers
-        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
-        self.inference_max_length = inference_max_length
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
@@ -177,8 +176,19 @@ class Server:
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")

+        is_multiquery_attn = self.block_config.num_key_value_groups > 1
+        if max_batch_size is None:
+            max_batch_size = 8192 if is_multiquery_attn else 2048
+        if inference_max_length is None:
+            inference_max_length = 8192 if is_multiquery_attn else 2048
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.inference_max_length = inference_max_length
+
         # For attention cache in GPU or RAM
+        if attn_cache_tokens is None:
+            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
+        cache_values_per_block //= self.block_config.num_key_value_groups
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8

         # For disk cache
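
Annotation: a back-of-the-envelope check of the cache arithmetic above, using an assumed GQA configuration (the model numbers below are illustrative, not taken from this diff):

# Assumed Llama-2-70B-like GQA config (illustrative only)
hidden_size = 8192            # model hidden dimension
num_key_value_groups = 8      # 8 query heads share each key/value head
attn_cache_tokens = 32768     # the multi-query/GQA default chosen above
dtype_bits = 16               # float16 / bfloat16

# 2x for keys and values; KV-head sharing divides the stored values
cache_values_per_block = 2 * hidden_size * attn_cache_tokens // num_key_value_groups
cache_bytes_per_block = cache_values_per_block * dtype_bits // 8
print(cache_bytes_per_block / 1024**2)  # 128.0 (MiB per transformer block)

This is why the MQA/GQA branch can afford a 4x larger attn_cache_tokens default: sharing key/value heads shrinks the per-token cache by a factor of num_key_value_groups.
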
@@ -222,8 +232,9 @@ class Server:
         throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(
             state=ServerState.JOINING,
-            adapters=tuple(adapters),
+            public_name=public_name,
             version=petals.__version__,
+            adapters=tuple(adapters),
             torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
             using_relay=self.dht.client_mode,
@@ -642,7 +653,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.dht = dht
         self.server_info = server_info
         self.memory_cache = memory_cache

+        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token //= block_config.num_key_value_groups
+
         self.update_period = update_period
         self.expiration = expiration
         self.trigger = threading.Event()
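
Annotation: the two added lines mirror the per-block cache math from Server.__init__, but count one key-or-value tensor (no 2x factor). A small sketch with the same assumed config as the earlier one; presumably this lets the announcer convert the cache bytes it has left into a token count, though that usage is outside this diff:

import torch

# Same assumed GQA config as the earlier sketch (illustrative only)
hidden_size = 8192
num_key_value_groups = 8
torch_dtype = torch.float16

# Cache bytes needed per token, per block, for one K or V tensor
bytes_per_token = hidden_size * torch.finfo(torch_dtype).bits // 8
bytes_per_token //= num_key_value_groups
print(bytes_per_token)  # 2048

# Assumed usage: tokens of attention cache still free on this server
bytes_left = 8 * 1024**3  # e.g., 8 GiB free in the memory cache
print(bytes_left // bytes_per_token)  # 4194304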