|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import gc
|
|
|
|
|
import math
|
|
|
|
|
import multiprocessing as mp
|
|
|
|
|
import random
|
|
|
|
|
import threading
|
|
|
|
@ -130,7 +131,9 @@ class Server:
|
|
|
|
|
)
|
|
|
|
|
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|
|
|
|
|
|
|
|
|
|
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
|
|
|
|
|
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
|
|
|
|
|
if num_blocks is None and block_indices is None:
|
|
|
|
|
num_blocks = self._choose_num_blocks()
|
|
|
|
|
if block_indices is not None:
|
|
|
|
|
try:
|
|
|
|
|
first_block_index, last_block_index = block_indices.split(":")
|
|
|
|
@ -167,6 +170,26 @@ class Server:
|
|
|
|
|
|
|
|
|
|
self.stop = threading.Event()
|
|
|
|
|
|
|
|
|
|
def _choose_num_blocks(self) -> int:
    """Estimate how many transformer blocks fit into this server's GPU memory.

    Reads ``self.converted_model_name_or_path``, ``self.device``,
    ``self.load_in_8bit`` and ``self.dtype``; queries the total memory of
    ``self.device`` through ``torch.cuda.get_device_properties``.

    Returns:
        The largest whole number of blocks whose estimated size fits into the
        device's total memory after 2 GiB has been set aside.

    Raises:
        AssertionError: if the model is not ``bigscience/bloom-petals``, if the
            device is not a CUDA device, or if not even one block fits.
    """
    # The per-block size estimate below is hard-coded, so automatic selection
    # is only allowed for the one model it was written for.
    assert (
        self.converted_model_name_or_path == "bigscience/bloom-petals"
    ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
    assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"

    gib = 1024**3
    total_memory_gib = torch.cuda.get_device_properties(self.device).total_memory / gib

    # NOTE(review): presumably ~176 GiB of 8-bit weights spread over 70 blocks,
    # plus 0.5 GiB of per-block overhead -- confirm where these numbers come from.
    block_size_gib = 176 / 70 + 0.5
    if not self.load_in_8bit:
        # Scale the 8-bit estimate: x2 for 16-bit dtypes, x4 for anything else.
        block_size_gib *= 2 if self.dtype in (torch.float16, torch.bfloat16) else 4

    # 2 GiB is subtracted before dividing (presumably headroom for everything
    # that is not block weights -- TODO confirm).
    num_blocks = math.floor((total_memory_gib - 2) / block_size_gib)
    assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"

    logger.info(
        f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
        "If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
    )
    return num_blocks
|
|
|
|
|
|
|
|
|
|
def run(self):
|
|
|
|
|
while True:
|
|
|
|
|
block_indices = self._choose_blocks()
|
|
|
|
|