Make server use smart defaults (#115)

Summary:

```python
parser.add_argument('--attn_cache_size', type=str, default=None,
                    help='The size of GPU memory allocated for storing past attention keys/values between inference steps. '
                         'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. '
                         'Default: 0.5GiB * num_blocks * hidden_size / 14336. '
                         'The latter is the hidden size of the bigscience/bloom-petals model.')

parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
                    help='Timeout (in seconds) for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
                    help='Timeout (in seconds) for the whole inference session')
parser.add_argument('--step_timeout', type=float, required=False, default=60,
                    help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")

parser.add_argument('--load_in_8bit', type=bool, default=None,
                    help="Convert the loaded model into mixed-8bit quantized model. Default: True if GPU is available")
```
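The `--attn_cache_size` value is a human-readable size string, and the sizes use decimal SI prefixes (so 1KB = 1000 bytes, not 1024). As a rough illustration of that convention, here is a minimal, hypothetical parser sketch; `parse_size` below is not the actual CLI helper, just an assumption of how such a string could be turned into bytes:

```python
import re

def parse_size(size: str) -> int:
    """Convert strings like '500MB', '1.2GB', or '1073741824' to bytes (hypothetical helper)."""
    units = {"": 1, "KB": 10**3, "MB": 10**6, "GB": 10**9, "TB": 10**12}
    match = re.fullmatch(r"\s*([\d.]+)\s*([A-Za-z]*)\s*", size)
    if match is None:
        raise ValueError(f"Cannot parse size: {size}")
    value, unit = match.groups()
    return int(float(value) * units[unit.upper()])

assert parse_size("1073741824") == 1024**3  # raw byte counts pass through unchanged
assert parse_size("500MB") == 500_000_000   # decimal prefix: 1MB == 10**6 bytes, not 2**20
```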

Co-authored-by: justheuristic <justheuristic@gmail.com>
Committed by Alexander Borzunov via GitHub: 643a054170 (parent 9e11f73242).

```diff
@@ -5,8 +5,5 @@ export HIVEMIND_COLORS=true
 while true; do
     pkill -f p2p
     pkill -f run_server
-    python -m petals.cli.run_server bigscience/bloom-petals \
-        --block_indices $1 \
-        --torch_dtype bfloat16 --load_in_8bit \
-        --attn_cache_size $2 2>&1 | tee log_`date '+%F_%H:%M:%S'`
+    python -m petals.cli.run_server bigscience/bloom-petals "$@" 2>&1 | tee log_`date '+%F_%H:%M:%S'`
 done
```

```diff
@@ -55,8 +55,10 @@ def main():
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
     parser.add_argument('--attn_cache_size', type=str, default=None,
-                        help='The size of GPU memory allocated for storing past attention keys/values between inference'
-                             ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
+                        help='The size of GPU memory allocated for storing past attention keys/values between inference steps. '
+                             'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. '
+                             'Default: 0.5GiB * num_blocks * hidden_size / 14336. '
+                             'The latter is the hidden size of the bigscience/bloom-petals model.')
     parser.add_argument('--alloc_timeout', type=float, default=60,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
@@ -76,11 +78,11 @@ def main():
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
     parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
-                        help='Timeout for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
+                        help='Timeout (in seconds) for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
     parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
-                        help='Timeout for the whole inference session')
-    parser.add_argument('--step_timeout', type=float, required=False, default=5 * 60,
-                        help="Timeout for waiting the next step's inputs inside an inference session")
+                        help='Timeout (in seconds) for the whole inference session')
+    parser.add_argument('--step_timeout', type=float, required=False, default=60,
+                        help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")
 
     group = parser.add_mutually_exclusive_group()
     group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
@@ -106,7 +108,8 @@ def main():
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-    parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
+    parser.add_argument('--load_in_8bit', type=bool, default=None,
+                        help="Convert the loaded model into mixed-8bit quantized model. Default: True if GPU is available")
     # fmt:on
     args = vars(parser.parse_args())
```

```diff
@@ -64,14 +64,14 @@ class Server:
         expiration: Optional[float] = None,
         request_timeout: float = 3 * 60,
         session_timeout: float = 30 * 60,
-        step_timeout: float = 5 * 60,
+        step_timeout: float = 60,
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         balance_quality: float = 0.75,
         mean_balance_check_period: float = 60,
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
-        load_in_8bit: bool = False,
+        load_in_8bit: Optional[bool] = None,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -81,12 +81,10 @@ class Server:
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.cache_dir = cache_dir
-        self.attn_cache_size = attn_cache_size
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
         self.use_auth_token = use_auth_token
-        self.load_in_8bit = load_in_8bit
 
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
@@ -114,15 +112,16 @@ class Server:
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
-        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = torch.device(device)
         self.device = device
 
-        self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
-
-        if isinstance(torch_dtype, str):
-            torch_dtype = DTYPE_MAP[torch_dtype]
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        self.torch_dtype = torch_dtype
+        if load_in_8bit is None:
+            load_in_8bit = device.type == "cuda"
+        if load_in_8bit:
+            logger.info("Model weights will be loaded in 8-bit format")
+        self.load_in_8bit = load_in_8bit
 
         self.block_config = BloomConfig.from_pretrained(
             converted_model_name_or_path,
```
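The hunk above is where the smart defaults are resolved: the device falls back to CUDA when available, and 8-bit loading defaults to on only for GPU devices. A minimal standalone sketch of that resolution logic (the `resolve_defaults` helper is illustrative, not part of the codebase):

```python
import torch

def resolve_defaults(device=None, load_in_8bit=None):
    """Mirror the default resolution above: pick CUDA if available, enable 8-bit only on GPU."""
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    if load_in_8bit is None:
        load_in_8bit = device.type == "cuda"
    return device, load_in_8bit

print(resolve_defaults())              # e.g. (device(type='cuda'), True) on a GPU machine
print(resolve_defaults(device="cpu"))  # (device(type='cpu'), False) -- 8-bit stays off on CPU
```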
```diff
@@ -131,13 +130,6 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
-        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
-        if throughput in ["auto", "eval"]:
-            throughput = get_host_throughput(
-                self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
-            )
-        self.throughput = throughput
-
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         if block_indices is not None:
             try:
@@ -147,7 +139,28 @@ class Server:
                 logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
                 raise
             block_indices = range(first_block_index, last_block_index)
+            num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
+
+        gib = 1024**3
+        if attn_cache_size is None:
+            # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
+            attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
+        logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
+        self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+
+        if isinstance(torch_dtype, str):
+            torch_dtype = DTYPE_MAP[torch_dtype]
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        self.torch_dtype = torch_dtype
+
+        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
+        if throughput in ["auto", "eval"]:
+            throughput = get_host_throughput(
+                self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
+            )
+        self.throughput = throughput
+
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
```
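To make the new cache default concrete, here is a quick back-of-the-envelope check of the formula above; the block counts are arbitrary examples, and 14336 is the hidden size of bigscience/bloom-petals, so for that model the formula reduces to 0.5 GiB per served block:

```python
gib = 1024**3
hidden_size = 14336  # bigscience/bloom-petals; other models scale proportionally
for num_blocks in (1, 10, 70):
    attn_cache_size = 0.5 * gib * num_blocks * hidden_size / 14336
    print(f"{num_blocks:>2} blocks -> {attn_cache_size / gib:.1f} GiB of attention cache")
# 1 blocks -> 0.5 GiB, 10 blocks -> 5.0 GiB, 70 blocks -> 35.0 GiB
```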
```diff
@@ -213,7 +226,6 @@ class Server:
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
             return self.strict_block_indices
-        assert self.num_blocks is not None
 
         # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
         # this delay decreases the probability of a race condition while choosing the best blocks to serve.
```

```diff
@@ -6,6 +6,7 @@ import tempfile
 import time
 from hashlib import sha256
 from pathlib import Path
+from typing import Union
 
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -26,7 +27,7 @@ DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")
 def get_host_throughput(
     config: BloomConfig,
     device: torch.device,
-    torch_dtype: torch.dtype,
+    dtype: Union[str, torch.dtype],
     *,
     load_in_8bit: bool,
     force_eval: bool = False,
@@ -42,7 +43,7 @@ def get_host_throughput(
     cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
     cache_key += f"_device_{_get_device_name(device).replace(' ', '_')}"
-    cache_key += f"_dtype_{_get_dtype_name(torch_dtype, load_in_8bit)}"
+    cache_key += f"_dtype_{_get_dtype_name(dtype, load_in_8bit)}"
 
     cache = {}
     try:
@@ -55,7 +56,7 @@ def get_host_throughput(
             cache = {}
 
     if cache_key not in cache:
-        cache[cache_key] = measure_throughput_info(config, device, torch_dtype, load_in_8bit=load_in_8bit)
+        cache[cache_key] = measure_throughput_info(config, device, dtype, load_in_8bit=load_in_8bit)
         try:
             os.makedirs(cache_path.parent, exist_ok=True)
@@ -70,7 +71,7 @@ def get_host_throughput(
 def measure_throughput_info(
     config: BloomConfig,
     device: torch.device,
-    dtype: torch.dtype,
+    dtype: Union[str, torch.dtype],
     *,
     load_in_8bit: bool,
 ) -> float:
@@ -106,7 +107,7 @@ def measure_network_rps(config: BloomConfig) -> float:
 def measure_compute_rps(
     config: BloomConfig,
     device: torch.device,
-    dtype: torch.dtype,
+    dtype: Union[str, torch.dtype],
     *,
     load_in_8bit: bool,
     n_tokens: int = 16,
@@ -114,7 +115,10 @@ def measure_compute_rps(
     layer_index: int = 0,
 ) -> float:
     with torch.inference_mode():
-        block = BloomBlock(config, layer_index).to(dtype)
+        block = BloomBlock(config, layer_index)
+        if dtype != "auto":
+            block = block.to(dtype)
+        input_dtype = block.input_layernorm.weight.dtype
         if load_in_8bit:
             block = replace_8bit_linear(block)
         block = block.to(device)
@@ -122,8 +126,8 @@ def measure_compute_rps(
         cache = None
         elapsed = 0
         for step in range(n_steps + 1):
-            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
-            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)
+            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=input_dtype)
+            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=input_dtype)
 
             start_time = time.perf_counter()
             _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
```
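The point of the last two hunks is that with `--torch_dtype auto` the benchmark block keeps whatever dtype the pre-trained weights have, so the dummy inputs must match the weights' dtype rather than the `dtype` argument. A minimal sketch of that behaviour, with a plain `torch.nn.Linear` standing in for `BloomBlock` (illustrative only, not the benchmark code itself):

```python
import torch

def pick_input_dtype(module: torch.nn.Module, dtype) -> torch.dtype:
    # Mirrors the logic above: "auto" keeps the module's existing dtype, anything else casts first
    if dtype != "auto":
        module = module.to(dtype)
    return next(module.parameters()).dtype

layer = torch.nn.Linear(4, 4)                   # stand-in for BloomBlock
print(pick_input_dtype(layer, "auto"))          # torch.float32 -- the weights' original dtype
print(pick_input_dtype(layer, torch.bfloat16))  # torch.bfloat16 -- inputs follow the requested cast
```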
