@@ -56,7 +56,7 @@ class Server:
         revision: str = "main",
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        attn_cache_size: Optional[int] = None,
+        attn_cache_tokens: int = 8192,
         alloc_timeout: float = 60,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
@@ -148,9 +148,7 @@ class Server:
             device = torch.device(device.type, index=0)
         self.device = device

-        if isinstance(torch_dtype, str):
-            torch_dtype = DTYPE_MAP[torch_dtype]
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        torch_dtype = DTYPE_MAP[torch_dtype]
         self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)

         if tensor_parallel_devices is None:
@@ -165,6 +163,9 @@ class Server:
         self.load_in_8bit = load_in_8bit
         logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")

+        max_values_in_cache = 2 * self.block_config.hidden_size * attn_cache_tokens
+        self._cache_bytes_per_block = max_values_in_cache * torch.finfo(self.torch_dtype).bits // 8
+
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
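Note (not part of the diff): the new per-block estimate counts key and value entries for the token budget and converts them to bytes via the dtype width. A rough sanity check for bigscience/bloom-petals served in float16, assuming hidden_size = 14336 and the new default attn_cache_tokens = 8192 (illustrative numbers only):

    hidden_size = 14336              # bigscience/bloom-petals
    attn_cache_tokens = 8192         # new default token budget
    dtype_bits = 16                  # torch.float16
    max_values_in_cache = 2 * hidden_size * attn_cache_tokens      # keys + values
    cache_bytes_per_block = max_values_in_cache * dtype_bits // 8  # 469,762,048 bytes, about 0.44 GiB

This lands close to the old hard-coded 0.5 GiB-per-block allowance while adapting to the dtype actually served.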
@@ -179,13 +180,10 @@ class Server:
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks

         gib = 1024**3
-        if attn_cache_size is None:
-            # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
-            attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
-
-        self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
-        logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
+        self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
+        logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")

+        self.alloc_timeout = alloc_timeout
         if cache_dir is None:
             cache_dir = DEFAULT_CACHE_DIR
         self.cache_dir = cache_dir
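Note (illustrative, assuming the float16 figures above): the old default scaled a 0.5 GiB-per-block allowance by hidden_size / 14336, regardless of dtype or token budget. With num_blocks = 10 blocks of bigscience/bloom-petals, the old heuristic logs 0.5 GiB * 10 * 14336 / 14336 = 5.00 GiB, while the new value is 10 * 469,762,048 bytes, about 4.38 GiB. The reported total stays in the same ballpark but now tracks --attn_cache_tokens and the serving dtype.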
@@ -236,10 +234,9 @@ class Server:

         # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models
         gib = 1024**3
-        attn_cache_per_block = 0.5 * gib * num_devices  # TODO: This does not account for manually set --attn_cache_size
         autograd_memory = 2 * gib * num_devices  # GPU memory used for intermediate tensors in rpc_backward

-        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block))
+        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"

         logger.info(
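Note (hypothetical single-GPU example, all figures assumed for illustration): with the cache term now taken from _cache_bytes_per_block, the estimate is floor((total_memory - autograd_memory) / (block_size + cache_bytes_per_block)). For an 80 GiB device, a block_size of roughly 4.6 GiB for a bloom-petals block in float16, a cache of about 0.44 GiB per block, and autograd_memory = 2 GiB, this gives floor(78 / 5.04) = 15 blocks.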
@@ -256,7 +253,7 @@ class Server:
             prefix=self.prefix,
             converted_model_name_or_path=self.converted_model_name_or_path,
             block_config=self.block_config,
-            attn_cache_size=self.attn_cache_size,
+            attn_cache_bytes=self.attn_cache_bytes,
             alloc_timeout=self.alloc_timeout,
             throughput=self.throughput,
             block_indices=block_indices,
@@ -356,7 +353,7 @@ class ModuleContainer(threading.Thread):
         prefix: str,
         converted_model_name_or_path: str,
         block_config: BloomConfig,
-        attn_cache_size: int,
+        attn_cache_bytes: int,
         alloc_timeout: float,
         throughput: float,
         block_indices: List[int],
@@ -390,7 +387,7 @@ class ModuleContainer(threading.Thread):
         assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)

-        memory_cache = MemoryCache(attn_cache_size, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):