@@ -28,7 +28,7 @@ from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
-from petals.utils.convert_block import check_device_balance, convert_block
+from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 from petals.utils.version import get_compatible_model_repo
 
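
For context, QuantType is the new enum that replaces the load_in_8bit boolean throughout this diff. A minimal sketch of its shape, assuming only what the diff itself evidences (the NONE and INT8 members used below; the real enum in petals.utils.convert_block may define further variants):

    from enum import Enum

    class QuantType(Enum):
        NONE = 0  # serve blocks in the plain torch_dtype, no quantization
        INT8 = 1  # 8-bit quantization via bitsandbytes (what load_in_8bit=True used to mean)
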
@@ -75,7 +75,7 @@ class Server:
         mean_balance_check_period: float = 120,
         mean_block_selection_delay: float = 2.5,
         use_auth_token: Optional[str] = None,
-        load_in_8bit: Optional[bool] = None,
+        quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
         dht_client_mode: Optional[bool] = None,
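
For callers, the migration is mechanical: the boolean flag becomes an enum value. A hypothetical before/after sketch (Server's other constructor arguments are elided, so this is not a runnable call as written):

    # before this change:
    #   server = Server(model_name, load_in_8bit=True, ...)
    # after:
    server = Server(
        model_name,                 # hypothetical; remaining required arguments elided
        quant_type=QuantType.INT8,  # or QuantType.NONE to keep native weights
    )
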
@@ -154,8 +154,8 @@ class Server:
             device = torch.device(device.type, index=0)
         self.device = device
 
-        torch_dtype = DTYPE_MAP[torch_dtype]
-        self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)
+        torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
+        self.torch_dtype = torch_dtype
 
         if tensor_parallel_devices is None:
             tensor_parallel_devices = (device,)
@@ -164,10 +164,10 @@ class Server:
             logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
             check_device_balance(self.tensor_parallel_devices)
 
-        if load_in_8bit is None:
-            load_in_8bit = device.type == "cuda"
-        self.load_in_8bit = load_in_8bit
-        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+        if quant_type is None:
+            quant_type = QuantType.INT8 if device.type == "cuda" else QuantType.NONE
+        self.quant_type = quant_type
+        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
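
To make the cache accounting concrete, the same arithmetic as a standalone snippet. Both inputs are assumptions, not values from this diff: hidden_size=14336 matches the BLOOM reference constant used further down, and attn_cache_tokens=8192 is a hypothetical setting.

    import torch

    hidden_size = 14336       # assumption: BLOOM-176B-like model
    attn_cache_tokens = 8192  # hypothetical attention cache budget
    torch_dtype = torch.float16

    cache_values_per_block = 2 * hidden_size * attn_cache_tokens  # keys + values
    cache_bytes_per_block = cache_values_per_block * torch.finfo(torch_dtype).bits // 8
    print(cache_bytes_per_block / 1024**2)  # 448.0 (MiB per served block)
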
@@ -203,7 +203,7 @@ class Server:
                 device,
                 torch_dtype,
                 num_blocks=num_blocks,
-                load_in_8bit=load_in_8bit,
+                quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
@@ -237,11 +237,11 @@ class Server:
         else:
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
 
-        # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models
         gib = 1024**3
-        autograd_memory = 2 * gib * num_devices  # GPU memory used for intermediate tensors in rpc_backward
+        # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
+        autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
 
         num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
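
The rewritten autograd_memory line scales the 2 GiB BLOOM estimate by hidden_size / 14336, so smaller models reserve proportionally less headroom for rpc_backward. A worked example of the full block-count formula; every input here is an assumption rather than a value from this diff:

    import math

    gib = 1024**3
    num_devices = 1
    hidden_size = 8192                     # assumption: a LLaMA-65B-sized model
    total_memory = 24 * gib                # assumption: one 24 GiB GPU
    block_size = 4 * gib                   # hypothetical per-block weight footprint
    cache_bytes_per_block = 448 * 1024**2  # figure from the cache example above

    # 2 GiB for BLOOM (hidden_size 14336), scaled proportionally for this model
    autograd_memory = 2 * gib * num_devices / 14336 * hidden_size

    num_blocks = math.floor((total_memory - autograd_memory) / (block_size + cache_bytes_per_block))
    print(num_blocks)  # 5 under these assumptions
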
@@ -284,7 +284,7 @@ class Server:
             sender_threads=self.sender_threads,
             revision=self.revision,
             use_auth_token=self.use_auth_token,
-            load_in_8bit=self.load_in_8bit,
+            quant_type=self.quant_type,
             tensor_parallel_devices=self.tensor_parallel_devices,
             should_validate_reachability=self.should_validate_reachability,
             start=True,
@@ -377,7 +377,7 @@ class ModuleContainer(threading.Thread):
         expiration: Optional[float],
         revision: Optional[str],
         use_auth_token: Optional[str],
-        load_in_8bit: bool,
+        quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
         **kwargs,
@@ -411,7 +411,7 @@ class ModuleContainer(threading.Thread):
                 cache_dir=cache_dir,
                 max_disk_space=max_disk_space,
             )
-            block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
+            block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True)
             blocks[module_uid] = TransformerBackend(
                 module_uid,
                 block,
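
Note that quant_type is passed to convert_block positionally, in the slot load_in_8bit used to occupy, so any out-of-tree call sites need the same one-token change. A hypothetical call site before and after:

    # before: block = convert_block(block, block_config, devices, device, True, freeze=True)
    # after (the enum makes the intent explicit and leaves room for more schemes):
    block = convert_block(block, block_config, tensor_parallel_devices, device, QuantType.INT8, freeze=True)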