From fc6722576bd745e6b3d413047504dad5d9c67dba Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Fri, 2 Dec 2022 23:17:44 +0400
Subject: [PATCH] Choose --num_blocks for bigscience/bloom-petals automatically
 (#119)

---
 src/petals/server/server.py | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index a817851..7527b76 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import gc
+import math
 import multiprocessing as mp
 import random
 import threading
@@ -130,7 +131,9 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
-        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
+        assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
+        if num_blocks is None and block_indices is None:
+            num_blocks = self._choose_num_blocks()
         if block_indices is not None:
             try:
                 first_block_index, last_block_index = block_indices.split(":")
@@ -167,6 +170,26 @@ class Server:
 
         self.stop = threading.Event()
 
+    def _choose_num_blocks(self) -> int:
+        assert (
+            self.converted_model_name_or_path == "bigscience/bloom-petals"
+        ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
+        assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
+
+        gib = 1024**3
+        total_memory_gib = torch.cuda.get_device_properties(self.device).total_memory / gib
+        block_size_gib = 176 / 70 + 0.5
+        if not self.load_in_8bit:
+            block_size_gib *= 2 if self.dtype in (torch.float16, torch.bfloat16) else 4
+        num_blocks = math.floor((total_memory_gib - 2) / block_size_gib)
+        assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
+
+        logger.info(
+            f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
+            f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
+        )
+        return num_blocks
+
     def run(self):
         while True:
             block_indices = self._choose_blocks()
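
For reference, a minimal standalone sketch of the memory heuristic this patch
adds. The function name estimate_num_blocks and its defaults are hypothetical,
not part of the patch; the constants come straight from _choose_num_blocks
above (176B parameters spread over 70 BLOOM blocks, ~0.5 GiB of per-block
overhead, ~2 GiB reserved for everything else on the GPU):

    import math

    def estimate_num_blocks(total_memory_gib: float, dtype_bytes: int = 2,
                            load_in_8bit: bool = False) -> int:
        # Per-block weight estimate: 176 (billion params) / 70 (blocks),
        # roughly GiB per block at 1 byte/param, plus ~0.5 GiB overhead.
        block_size_gib = 176 / 70 + 0.5
        if not load_in_8bit:
            block_size_gib *= dtype_bytes  # 2 for fp16/bf16, 4 for fp32
        # Keep ~2 GiB free for the CUDA context and non-block allocations.
        return math.floor((total_memory_gib - 2) / block_size_gib)

    # Example: an 80 GiB GPU serving fp16 blocks without 8-bit quantization
    # gets floor((80 - 2) / ((176 / 70 + 0.5) * 2)) = 12 blocks.
    print(estimate_num_blocks(80.0))  # -> 12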