Make server use smart defaults (#115)

Summary:

```python
parser.add_argument('--attn_cache_size', type=str, default=None,
                    help='The size of GPU memory allocated for storing past attention keys/values between inference steps. '
                         'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. '
                         'Default: 0.5GiB * num_blocks * hidden_size / 14336. '
                         'The latter is the hidden size of the bigscience/bloom-petals model.')

parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
                    help='Timeout (in seconds) for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
                    help='Timeout (in seconds) for the whole inference session')
parser.add_argument('--step_timeout', type=float, required=False, default=60,
                    help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")

parser.add_argument('--load_in_8bit', type=bool, default=None,
                    help="Convert the loaded model into mixed-8bit quantized model. Default: True if GPU is available")
```

Co-authored-by: justheuristic <justheuristic@gmail.com>
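For illustration, here is how the new defaults resolve in practice. The helper functions below are hypothetical and only mirror the formulas stated in the help strings (bigscience/bloom-petals has hidden_size 14336):

```python
import torch

GIB = 1024**3

def default_attn_cache_size(num_blocks: int, hidden_size: int) -> float:
    # Roughly 0.5 GiB per served block at the bloom-petals hidden size (14336),
    # scaled proportionally for models with a different hidden size.
    return 0.5 * GIB * num_blocks * hidden_size / 14336

def default_load_in_8bit() -> bool:
    # 8-bit weights are enabled by default only when a GPU is available.
    return torch.cuda.is_available()

# A server hosting 10 bloom-petals blocks gets a 5.00 GiB attention cache by default
print(f"{default_attn_cache_size(num_blocks=10, hidden_size=14336) / GIB:.2f} GiB")
```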
Authored by Alexander Borzunov, committed via GitHub: commit 643a054170 (parent 9e11f73242)

@@ -5,8 +5,5 @@ export HIVEMIND_COLORS=true
while true; do
pkill -f p2p
pkill -f run_server
python -m petals.cli.run_server bigscience/bloom-petals \
--block_indices $1 \
--torch_dtype bfloat16 --load_in_8bit \
--attn_cache_size $2 2>&1 | tee log_`date '+%F_%H:%M:%S'`
python -m petals.cli.run_server bigscience/bloom-petals "$@" 2>&1 | tee log_`date '+%F_%H:%M:%S'`
done

@@ -55,8 +55,10 @@ def main():
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--attn_cache_size', type=str, default=None,
help='The size of GPU memory allocated for storing past attention keys/values between inference'
' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
help='The size of GPU memory allocated for storing past attention keys/values between inference steps. '
'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. '
'Default: 0.5GiB * num_blocks * hidden_size / 14336. '
'The latter is the hidden size of the bigscience/bloom-petals model.')
parser.add_argument('--alloc_timeout', type=float, default=60,
help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
'before rejecting the request')
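The size strings above distinguish decimal and binary units (1KB = 10^3 bytes, 1KiB = 2^10 bytes). Below is a self-contained sketch of such a parser, shown purely for illustration and not necessarily the parsing code Petals uses:

```python
import re

# Decimal (KB/MB/GB) vs. binary (KiB/MiB/GiB) multipliers; note that 1KB != 1KiB
UNITS = {
    "": 1,
    "KB": 10**3, "MB": 10**6, "GB": 10**9, "TB": 10**12,
    "KIB": 2**10, "MIB": 2**20, "GIB": 2**30, "TIB": 2**40,
}

def parse_size(size: str) -> int:
    """Parse strings like '500MB', '1.2GB', '0.5GiB', or '1073741824' into bytes."""
    match = re.fullmatch(r"\s*([\d.]+)\s*([A-Za-z]*)\s*", size)
    if match is None or match.group(2).upper() not in UNITS:
        raise ValueError(f"Could not parse size: {size!r}")
    return int(float(match.group(1)) * UNITS[match.group(2).upper()])

assert parse_size("1KB") == 1_000 and parse_size("1KiB") == 1_024
assert parse_size("1073741824") == 2**30
```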
@@ -76,11 +78,11 @@ def main():
parser.add_argument('--expiration', type=float, required=False, default=None,
help='DHT entries will expire after this many seconds')
parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
help='Timeout for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
help='Timeout (in seconds) for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
help='Timeout for the whole inference session')
parser.add_argument('--step_timeout', type=float, required=False, default=5 * 60,
help="Timeout for waiting the next step's inputs inside an inference session")
help='Timeout (in seconds) for the whole inference session')
parser.add_argument('--step_timeout', type=float, required=False, default=60,
help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")
group = parser.add_mutually_exclusive_group()
group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
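To make the timeout semantics concrete, the sketch below shows how a per-step timeout is typically enforced while waiting for a client's next inputs. It is an illustrative asyncio pattern, not the actual Petals handler code:

```python
import asyncio

async def next_step_inputs(queue: asyncio.Queue, step_timeout: float):
    """Wait for the next inference step's inputs, giving up after step_timeout seconds."""
    try:
        return await asyncio.wait_for(queue.get(), timeout=step_timeout)
    except asyncio.TimeoutError:
        raise TimeoutError(f"No inputs for the next step arrived within {step_timeout} seconds")
```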
@@ -106,7 +108,8 @@ def main():
help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
parser.add_argument('--load_in_8bit', type=bool, default=None,
help="Convert the loaded model into mixed-8bit quantized model. Default: True if GPU is available")
# fmt:on
args = vars(parser.parse_args())
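For reference, argparse's bare type=bool treats any non-empty string as truthy, so a tri-state None/True/False flag is often expressed with an explicit converter. The sketch below is illustrative only and not part of this commit:

```python
import argparse

def str_to_bool(value: str) -> bool:
    # bool("False") is True, so parse the string explicitly
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean, got {value!r}")

parser = argparse.ArgumentParser()
# default=None means "decide automatically" (e.g., enable 8-bit only on GPU)
parser.add_argument("--load_in_8bit", type=str_to_bool, nargs="?", const=True, default=None)
```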

@@ -64,14 +64,14 @@ class Server:
expiration: Optional[float] = None,
request_timeout: float = 3 * 60,
session_timeout: float = 30 * 60,
step_timeout: float = 5 * 60,
step_timeout: float = 60,
prefetch_batches: int = 1,
sender_threads: int = 1,
balance_quality: float = 0.75,
mean_balance_check_period: float = 60,
mean_block_selection_delay: float = 0.5,
use_auth_token: Optional[str] = None,
load_in_8bit: bool = False,
load_in_8bit: Optional[bool] = None,
**kwargs,
):
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -81,12 +81,10 @@ class Server:
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.inference_max_length = inference_max_length
self.cache_dir = cache_dir
self.attn_cache_size = attn_cache_size
self.compression = compression
self.stats_report_interval, self.update_period = stats_report_interval, update_period
self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
self.use_auth_token = use_auth_token
self.load_in_8bit = load_in_8bit
if custom_module_path is not None:
add_custom_models_from_file(custom_module_path)
@@ -114,15 +112,16 @@ class Server:
else:
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
self.device = device
self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
if isinstance(torch_dtype, str):
torch_dtype = DTYPE_MAP[torch_dtype]
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
self.torch_dtype = torch_dtype
if load_in_8bit is None:
load_in_8bit = device.type == "cuda"
if load_in_8bit:
logger.info("Model weights will be loaded in 8-bit format")
self.load_in_8bit = load_in_8bit
self.block_config = BloomConfig.from_pretrained(
converted_model_name_or_path,
@@ -131,13 +130,6 @@ class Server:
)
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
throughput = get_host_throughput(
self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
)
self.throughput = throughput
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
if block_indices is not None:
try:
@@ -147,7 +139,28 @@ class Server:
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
raise
block_indices = range(first_block_index, last_block_index)
num_blocks = len(block_indices)
self.strict_block_indices, self.num_blocks = block_indices, num_blocks
gib = 1024**3
if attn_cache_size is None:
# Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
if isinstance(torch_dtype, str):
torch_dtype = DTYPE_MAP[torch_dtype]
assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
self.torch_dtype = torch_dtype
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
throughput = get_host_throughput(
self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
)
self.throughput = throughput
self.balance_quality = balance_quality
self.mean_balance_check_period = mean_balance_check_period
self.mean_block_selection_delay = mean_block_selection_delay
@@ -213,7 +226,6 @@ class Server:
def _choose_blocks(self) -> List[int]:
if self.strict_block_indices is not None:
return self.strict_block_indices
assert self.num_blocks is not None
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
# this delay decreases the probability of a race condition while choosing the best blocks to serve.

@@ -6,6 +6,7 @@ import tempfile
import time
from hashlib import sha256
from pathlib import Path
from typing import Union
import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -26,7 +27,7 @@ DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")
def get_host_throughput(
config: BloomConfig,
device: torch.device,
torch_dtype: torch.dtype,
dtype: Union[str, torch.dtype],
*,
load_in_8bit: bool,
force_eval: bool = False,
@@ -42,7 +43,7 @@ get_host_throughput(
cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
cache_key += f"_device_{_get_device_name(device).replace(' ', '_')}"
cache_key += f"_dtype_{_get_dtype_name(torch_dtype, load_in_8bit)}"
cache_key += f"_dtype_{_get_dtype_name(dtype, load_in_8bit)}"
cache = {}
try:
@@ -55,7 +56,7 @@ get_host_throughput(
cache = {}
if cache_key not in cache:
cache[cache_key] = measure_throughput_info(config, device, torch_dtype, load_in_8bit=load_in_8bit)
cache[cache_key] = measure_throughput_info(config, device, dtype, load_in_8bit=load_in_8bit)
try:
os.makedirs(cache_path.parent, exist_ok=True)
@@ -70,7 +71,7 @@ get_host_throughput(
def measure_throughput_info(
config: BloomConfig,
device: torch.device,
dtype: torch.dtype,
dtype: Union[str, torch.dtype],
*,
load_in_8bit: bool,
) -> float:
@@ -106,7 +107,7 @@ def measure_network_rps(config: BloomConfig) -> float:
def measure_compute_rps(
config: BloomConfig,
device: torch.device,
dtype: torch.dtype,
dtype: Union[str, torch.dtype],
*,
load_in_8bit: bool,
n_tokens: int = 16,
@@ -114,7 +115,10 @@ measure_compute_rps(
layer_index: int = 0,
) -> float:
with torch.inference_mode():
block = BloomBlock(config, layer_index).to(dtype)
block = BloomBlock(config, layer_index)
if dtype != "auto":
block = block.to(dtype)
input_dtype = block.input_layernorm.weight.dtype
if load_in_8bit:
block = replace_8bit_linear(block)
block = block.to(device)
@@ -122,8 +126,8 @@ measure_compute_rps(
cache = None
elapsed = 0
for step in range(n_steps + 1):
dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)
dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=input_dtype)
alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=input_dtype)
start_time = time.perf_counter()
_, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
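The dtype changes above let the throughput benchmark accept the same "auto" value as the server: with "auto", the block keeps the dtypes stored in the checkpoint, and benchmark inputs are created with whatever dtype the weights actually use. A generic sketch of that idea (hypothetical helper, not the Petals API):

```python
import torch

def resolve_input_dtype(block: torch.nn.Module, dtype) -> torch.dtype:
    # With dtype == "auto", keep whatever dtypes the checkpoint provides;
    # otherwise cast the block and use the same dtype for benchmark inputs.
    if dtype != "auto":
        block = block.to(dtype)
    return next(block.parameters()).dtype

layer = torch.nn.Linear(4, 4)
assert resolve_input_dtype(layer, "auto") == torch.float32
assert resolve_input_dtype(layer, torch.bfloat16) == torch.bfloat16
```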
