Improve block size calculations (#149)

pull/150/head
Alexander Borzunov 1 year ago committed by GitHub
parent f42e559c77
commit 83d9493b6c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,48 @@
from typing import Optional, Union
import torch
from accelerate import init_empty_weights
from petals.bloom import BloomBlock, BloomConfig
def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
"""If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
if dtype == "auto" or dtype is None:
dtype = config.torch_dtype
if dtype == "auto" or dtype is None:
dtype = torch.float32
return dtype
def get_block_size(
config: BloomConfig,
location: str,
*,
dtype: Optional[Union[str, torch.dtype]] = None,
load_in_8bit: Optional[bool] = None,
layer_index: int = 0,
eps: float = 0.01, # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
) -> int:
if location == "memory":
assert (
dtype is not None and load_in_8bit is not None
), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
with init_empty_weights():
block = BloomBlock(config, layer_index)
n_params = sum(param.numel() for param in block.parameters())
if location == "memory" and load_in_8bit:
# Note: We may need a larger eps here for models of size < 1B
return n_params * (1 + eps)
if location == "memory":
dtype = resolve_block_dtype(config, dtype)
elif location == "disk":
dtype = resolve_block_dtype(config, "auto")
else:
raise ValueError('get_block_size() expects location to be "memory" or "disk"')
return round(n_params * torch.finfo(dtype).bits // 8 * (1 + eps))

@ -24,6 +24,7 @@ from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
from petals.dht_utils import declare_active_modules, get_remote_module_infos
from petals.server import block_selection
from petals.server.backend import TransformerBackend
from petals.server.block_utils import get_block_size
from petals.server.handler import TransformerConnectionHandler
from petals.server.memory_cache import MemoryCache
from petals.server.throughput import get_host_throughput
@ -125,6 +126,11 @@ class Server:
device = torch.device(device)
self.device = device
if isinstance(torch_dtype, str):
torch_dtype = DTYPE_MAP[torch_dtype]
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
self.torch_dtype = torch_dtype
if load_in_8bit is None:
load_in_8bit = device.type == "cuda"
if load_in_8bit:
@ -152,11 +158,6 @@ class Server:
logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
if isinstance(torch_dtype, str):
torch_dtype = DTYPE_MAP[torch_dtype]
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
self.torch_dtype = torch_dtype
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
throughput = get_host_throughput(
@ -181,19 +182,19 @@ class Server:
), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
total_memory = torch.cuda.get_device_properties(self.device).total_memory
block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
gib = 1024**3
total_memory_gib = torch.cuda.get_device_properties(self.device).total_memory / gib
block_size_gib = 176 / 70 + 0.5
if not self.load_in_8bit:
block_size_gib *= 2 if self.dtype in (torch.float16, torch.bfloat16) else 4
num_blocks = math.floor((total_memory_gib - 2) / block_size_gib)
attn_cache_per_block = 0.5 * gib # TODO: This does not account for manually set --attn_cache_size
num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block))
assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
logger.info(
f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
)
return num_blocks
return min(num_blocks, self.block_config.n_layer)
def run(self):
while True:
@ -231,10 +232,13 @@ class Server:
while True:
timeout = random.random() * 2 * self.mean_balance_check_period
# TODO: Follow ModuleContainer status (to restart/stop if it crashes)
if self.stop.wait(timeout):
return
if not self.module_container.handlers_alive:
logger.warning("One of connection handlers crashed, restarting the server")
break
if self._should_choose_other_blocks():
logger.info("Swarm is imbalanced, server will load other blocks")
break # Stop serving this set of modules
@ -466,6 +470,10 @@ class ModuleContainer(threading.Thread):
"""
return self.runtime.ready # mp.Event that is true if self is ready to process batches
@property
def handlers_alive(self) -> bool:
return all(handler.is_alive() for handler in self.conn_handlers)
def shutdown(self):
"""
Gracefully terminate the container, process-safe.

@ -13,6 +13,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from petals.bloom.block import BloomBlock
from petals.bloom.model import BloomConfig
from petals.bloom.ops import build_alibi_tensor
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_8bit import replace_8bit_linear
from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@ -29,11 +30,7 @@ def get_host_throughput(
force_eval: bool = False,
cache_dir: Optional[str] = None,
) -> float:
# Resolve default dtypes
if dtype == "auto" or dtype is None:
dtype = config.torch_dtype
if dtype == "auto" or dtype is None:
dtype = torch.float32
dtype = resolve_block_dtype(config, dtype)
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR

Loading…
Cancel
Save