Improve default arguments for clients and servers (#530)

This PR updates multiple default arguments in clients and servers:

1. **The client defaults to `torch_dtype=torch.float32` instead of `torch_dtype="auto"`.**

    The old default was to load weights in the dtype they were saved in (usually bfloat16/float16), which caused issues when the client ran on CPU (the default unless you call `.cuda()`). Specifically, bfloat16 is slow on most CPUs (unless the CPU supports AVX512), and float16 can't be run natively on CPU and leads to an exception. This default was a legacy of the earliest Petals versions designed to run BLOOM: its embeddings were so large that they didn't fit into RAM in float32 (e.g., in Colab). Newer models don't have this issue.

    In contrast, the new default gives good speed on all CPUs and is consistent with PyTorch and HF Transformers. Also, the client now shows the "bfloat16 on a non-AVX512 CPU" warning in all cases (previously, this warning was shown only if the machine had enough RAM to fit the float32 weights, which could hide the real reason for slow inference).

    **Note:** This change is backward-incompatible, so we have to increase at least the minor package version (2.2.0 -> 2.3.0.dev0).
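
    For illustration, here is a minimal usage sketch (not code from this PR; the model name is just an example) showing that CPU clients now get float32 automatically, while GPU users can still opt into half precision explicitly:

    ```python
    import torch
    from petals import AutoDistributedModelForCausalLM

    # On CPU, the new default (torch.float32) is applied automatically:
    model = AutoDistributedModelForCausalLM.from_pretrained("petals-team/StableBeluga2")

    # On GPU, half precision can still be requested explicitly
    # (torch_dtype="auto" restores the old behavior of using the saved dtype):
    model = AutoDistributedModelForCausalLM.from_pretrained(
        "petals-team/StableBeluga2", torch_dtype=torch.bfloat16
    ).cuda()
    ```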

2. **The server uses 2x smaller `--attn_cache_tokens`.**

    The old default led to loading 39 (out of 80) or 78 (out of 80) blocks for popular models on some GPU types, which visibly slowed down inference due to an extra network hop. It also reserved too much cache, so inference slowed down well before the cache was actually used up.

    The new default leads to more efficient block layouts and makes the inference routing algorithm choose alternative paths through other servers when a particular server already has enough active inference sessions (i.e., its cache is full).
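
    As a rough illustration (a sketch, not code from this PR; the Llama-2-70B-style numbers are assumed for the example), the per-block cache budget scales linearly with `--attn_cache_tokens`, following the same formula the server uses in the diff below:

    ```python
    def cache_bytes_per_block(hidden_size: int, num_key_value_groups: int,
                              attn_cache_tokens: int, dtype_size: int = 2) -> int:
        # 2x for keys and values; grouped/multi-query attention shares KV heads across groups
        cache_values_per_block = 2 * hidden_size * attn_cache_tokens
        cache_values_per_block //= num_key_value_groups
        return cache_values_per_block * dtype_size  # dtype_size=2 assumes bfloat16/float16

    # Example: a Llama-2-70B-style block (hidden_size=8192, 64 heads / 8 KV heads = 8 groups)
    old = cache_bytes_per_block(8192, 8, 32768)  # previous default for multi-query attention
    new = cache_bytes_per_block(8192, 8, 16384)  # new default is 2x smaller
    print(old // 2**20, new // 2**20)  # 128 MiB vs. 64 MiB per block
    ```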

3. **The client's max number of retries can be limited by the `PETALS_MAX_RETRIES` env var.**

    This is used to limit `ClientConfig.max_retries` in tests, so that we see tracebacks instead of retrying indefinitely when errors occur.
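
    For example (a sketch; the module path is assumed from the diff shown below), the cap applies to any `ClientConfig` created after the variable is set:

    ```python
    import os

    os.environ["PETALS_MAX_RETRIES"] = "10"  # must be set before petals is imported,
                                             # since the default is read at import time

    from petals.client.config import ClientConfig  # assumed module path

    config = ClientConfig()
    assert config.max_retries == 10  # without the env var, this stays None (retry indefinitely)
    ```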

```diff
@@ -102,6 +102,9 @@ jobs:
 export no_proxy=*
 export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
+# Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely
+export PETALS_MAX_RETRIES=10
 pytest tests --durations=0 --durations-min=1.0 -v
 # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
```

```diff
@@ -17,7 +17,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs

-__version__ = "2.2.0"
+__version__ = "2.3.0.dev0"

 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
```

```diff
@@ -70,17 +70,17 @@ def main():
     parser.add_argument('--inference_max_length', type=int, default=None,
                         help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. '
-                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
     parser.add_argument('--min_batch_size', type=int, default=1,
                         help='Minimum required batch size for all operations (in total tokens)')
     parser.add_argument('--max_batch_size', type=int, default=None,
                         help='The total number of tokens in the same batch will not exceed this value. '
-                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
     parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
                         help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
     parser.add_argument('--attn_cache_tokens', type=int, default=None,
                         help='The number of past attention key/value pairs that will be stored between inference steps. '
-                             'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others')
     parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
```

```diff
@@ -1,10 +1,14 @@
 import dataclasses
+import os
 from typing import Optional, Sequence, Union

 from hivemind import PeerID

 from petals.constants import PUBLIC_INITIAL_PEERS

+_max_retries = os.getenv("PETALS_MAX_RETRIES")
+DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None
+

 @dataclasses.dataclass
 class ClientConfig:
@@ -21,7 +25,7 @@ class ClientConfig:
     request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
     update_period: float = 60  # refresh DHT information once in this many seconds
-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    max_retries: Optional[int] = DEFAULT_MAX_RETRIES  # max number of retries before an exception (default: inf)
     min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
     max_backoff: float = 60  # limit maximal sleep time between retries to this value
     ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
```

```diff
@@ -6,7 +6,6 @@ import tempfile
 from contextvars import ContextVar
 from typing import List, Optional, Tuple, Union

-import torch
 from hivemind.utils.logging import get_logger
 from transformers import BloomPreTrainedModel, modeling_utils
@@ -22,21 +21,14 @@ class FromPretrainedMixin:
         model_name_or_path: Union[str, os.PathLike, None],
         *args,
         low_cpu_mem_usage: Optional[bool] = None,
-        torch_dtype: Optional[Union[str, torch.dtype]] = None,
         **kwargs,
     ):
         model_name_or_path = get_compatible_model_repo(model_name_or_path)
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
-        if torch_dtype is None:
-            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
-            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
-            torch_dtype = "auto"

         with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
-            return super().from_pretrained(
-                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
-            )
+            return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)

     from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
         "low_cpu_mem_usage(`bool`, *optional*)",
```

```diff
@@ -1,8 +1,7 @@
 import dataclasses
 import platform
-from typing import Optional, Union
+from typing import Union

 import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -68,11 +67,10 @@ class LMHead(nn.Module):
         assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"

         if not self._bf16_warning_shown:
-            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
-                logger.warning(
-                    "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
-                    "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
-                )
+            logger.warning(
+                "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
+                "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
+            )
             self._bf16_warning_shown = True

         hidden_states = hidden_states.float()
```

```diff
@@ -203,7 +203,7 @@ class Server:
         # For attention cache in GPU or RAM
         if attn_cache_tokens is None:
-            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
+            attn_cache_tokens = 16384 if is_multiquery_attn else 4096
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
         self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
```
