diff --git a/src/petals/server/block_utils.py b/src/petals/server/block_utils.py new file mode 100644 index 0000000..ce5f678 --- /dev/null +++ b/src/petals/server/block_utils.py @@ -0,0 +1,48 @@ +from typing import Optional, Union + +import torch +from accelerate import init_empty_weights + +from petals.bloom import BloomBlock, BloomConfig + + +def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]: + """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise.""" + + if dtype == "auto" or dtype is None: + dtype = config.torch_dtype + if dtype == "auto" or dtype is None: + dtype = torch.float32 + return dtype + + +def get_block_size( + config: BloomConfig, + location: str, + *, + dtype: Optional[Union[str, torch.dtype]] = None, + load_in_8bit: Optional[bool] = None, + layer_index: int = 0, + eps: float = 0.01, # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc. +) -> int: + if location == "memory": + assert ( + dtype is not None and load_in_8bit is not None + ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations' + + with init_empty_weights(): + block = BloomBlock(config, layer_index) + n_params = sum(param.numel() for param in block.parameters()) + + if location == "memory" and load_in_8bit: + # Note: We may need a larger eps here for models of size < 1B + return n_params * (1 + eps) + + if location == "memory": + dtype = resolve_block_dtype(config, dtype) + elif location == "disk": + dtype = resolve_block_dtype(config, "auto") + else: + raise ValueError('get_block_size() expects location to be "memory" or "disk"') + + return round(n_params * torch.finfo(dtype).bits // 8 * (1 + eps)) diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 7c12f48..97b03e0 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -24,6 +24,7 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState from petals.dht_utils import declare_active_modules, get_remote_module_infos from petals.server import block_selection from petals.server.backend import TransformerBackend +from petals.server.block_utils import get_block_size from petals.server.handler import TransformerConnectionHandler from petals.server.memory_cache import MemoryCache from petals.server.throughput import get_host_throughput @@ -125,6 +126,11 @@ class Server: device = torch.device(device) self.device = device + if isinstance(torch_dtype, str): + torch_dtype = DTYPE_MAP[torch_dtype] + assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" + self.torch_dtype = torch_dtype + if load_in_8bit is None: load_in_8bit = device.type == "cuda" if load_in_8bit: @@ -152,11 +158,6 @@ class Server: logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB") self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout) - if isinstance(torch_dtype, str): - torch_dtype = DTYPE_MAP[torch_dtype] - assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}" - self.torch_dtype = torch_dtype - assert isinstance(throughput, float) or throughput in ["auto", "eval"] if throughput in ["auto", "eval"]: throughput = get_host_throughput( @@ -181,19 +182,19 @@ class Server: ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually" assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually" + total_memory = torch.cuda.get_device_properties(self.device).total_memory + block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit) gib = 1024**3 - total_memory_gib = torch.cuda.get_device_properties(self.device).total_memory / gib - block_size_gib = 176 / 70 + 0.5 - if not self.load_in_8bit: - block_size_gib *= 2 if self.dtype in (torch.float16, torch.bfloat16) else 4 - num_blocks = math.floor((total_memory_gib - 2) / block_size_gib) + attn_cache_per_block = 0.5 * gib # TODO: This does not account for manually set --attn_cache_size + + num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block)) assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block" logger.info( f"Server will fill all your GPU memory with {num_blocks} transformer blocks. " f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually" ) - return num_blocks + return min(num_blocks, self.block_config.n_layer) def run(self): while True: @@ -231,10 +232,13 @@ class Server: while True: timeout = random.random() * 2 * self.mean_balance_check_period - # TODO: Follow ModuleContainer status (to restart/stop if it crashes) if self.stop.wait(timeout): return + if not self.module_container.handlers_alive: + logger.warning("One of connection handlers crashed, restarting the server") + break + if self._should_choose_other_blocks(): logger.info("Swarm is imbalanced, server will load other blocks") break # Stop serving this set of modules @@ -466,6 +470,10 @@ class ModuleContainer(threading.Thread): """ return self.runtime.ready # mp.Event that is true if self is ready to process batches + @property + def handlers_alive(self) -> bool: + return all(handler.is_alive() for handler in self.conn_handlers) + def shutdown(self): """ Gracefully terminate the container, process-safe. diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index fc08eba..6d22bb5 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -13,6 +13,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler from petals.bloom.block import BloomBlock from petals.bloom.model import BloomConfig from petals.bloom.ops import build_alibi_tensor +from petals.server.block_utils import resolve_block_dtype from petals.utils.convert_8bit import replace_8bit_linear from petals.utils.disk_cache import DEFAULT_CACHE_DIR @@ -29,11 +30,7 @@ def get_host_throughput( force_eval: bool = False, cache_dir: Optional[str] = None, ) -> float: - # Resolve default dtypes - if dtype == "auto" or dtype is None: - dtype = config.torch_dtype - if dtype == "auto" or dtype is None: - dtype = torch.float32 + dtype = resolve_block_dtype(config, dtype) if cache_dir is None: cache_dir = DEFAULT_CACHE_DIR