From 47d50e1e2938f8a0174caf670b25dea5345c6830 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Mon, 23 Oct 2023 05:26:40 +0600 Subject: [PATCH] Improve default arguments for clients and servers (#530) This PR updates multiple default arguments in clients and servers: 1. **The client defaults to `torch_dtype=torch.float32` instead of `torch_dtype="auto"`.** The old default was to load weights in the dtype they are saved in (usually bfloat16/float16), which caused issues when the client was run on CPU (the default unless you call `.cuda()`). Specifically, bfloat16 is slow on most CPUs (unless a CPU supports AVX512) and float16 can't be run natively and leads to an exception. This default was a legacy of the earliest Petals versions designed to run BLOOM - its embeddings were so big that they didn't fit into RAM in float32 (e.g., in Colab). The newer models don't have this issue. In contrast, the new default leads to good speed on all CPUs and is consistent with PyTorch and HF Transformers. Also, the client now shows "bfloat16 on non-AVX512 CPU" in all cases (previously this warning was shown only if the machine has enough RAM to fit float32 weights, which could hide the crucial reason of inference being slow). **Note:** This change is backward-incompatible, so we have to increase at least the minor package version (2.2.0 -> 2.3.0.dev0). 2. **The server uses 2x smaller `--attn_cache_tokens`.** The old default led to loading 39 (out of 80) or 78 (out of 80) blocks for popular models on some GPU types, which visibly slowed down inference due to an excess network hop. It was also leaving too much cache, so that inference slowed down much before the cache is used. The new default leads to more efficient block layouts and makes the inference routing algorithm choose alternative paths through other servers when a particular server already has enough active inference sessions (= its cache is full). 3. **The client's max number of retries can be limited by the `PETALS_MAX_RETRIES` env var.** This is to limit `ClientConfig.max_retries` in tests, so we see tracebacks instead of retrying indefinitely in case of errors. --- .github/workflows/run-tests.yaml | 3 +++ src/petals/__init__.py | 2 +- src/petals/cli/run_server.py | 6 +++--- src/petals/client/config.py | 6 +++++- src/petals/client/from_pretrained.py | 10 +--------- src/petals/client/lm_head.py | 12 +++++------- src/petals/server/server.py | 2 +- 7 files changed, 19 insertions(+), 22 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index 05cebdd..b9dcc01 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -102,6 +102,9 @@ jobs: export no_proxy=* export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES + # Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely + export PETALS_MAX_RETRIES=10 + pytest tests --durations=0 --durations-min=1.0 -v # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers) diff --git a/src/petals/__init__.py b/src/petals/__init__.py index f513f65..8671fc2 100644 --- a/src/petals/__init__.py +++ b/src/petals/__init__.py @@ -17,7 +17,7 @@ from petals.models import * from petals.utils import * from petals.utils.logging import initialize_logs as _initialize_logs -__version__ = "2.2.0" +__version__ = "2.3.0.dev0" if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"): diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 94f5c2e..5208438 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -70,17 +70,17 @@ def main(): parser.add_argument('--inference_max_length', type=int, default=None, help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. ' - 'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)') + 'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others') parser.add_argument('--min_batch_size', type=int, default=1, help='Minimum required batch size for all operations (in total tokens)') parser.add_argument('--max_batch_size', type=int, default=None, help='The total number of tokens in the same batch will not exceed this value. ' - 'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)') + 'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others') parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024, help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks') parser.add_argument('--attn_cache_tokens', type=int, default=None, help='The number of past attention key/value pairs that will be stored between inference steps. ' - 'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)') + 'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others') parser.add_argument('--cache_dir', type=str, default=None, help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.') diff --git a/src/petals/client/config.py b/src/petals/client/config.py index e255024..a2f8f42 100644 --- a/src/petals/client/config.py +++ b/src/petals/client/config.py @@ -1,10 +1,14 @@ import dataclasses +import os from typing import Optional, Sequence, Union from hivemind import PeerID from petals.constants import PUBLIC_INITIAL_PEERS +_max_retries = os.getenv("PETALS_MAX_RETRIES") +DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None + @dataclasses.dataclass class ClientConfig: @@ -21,7 +25,7 @@ class ClientConfig: request_timeout: float = 3 * 60 # timeout for forward/backward/inference requests update_period: float = 60 # refresh DHT information once in this many seconds - max_retries: Optional[int] = None # max number retries before the client raises an exception (default: inf) + max_retries: Optional[int] = DEFAULT_MAX_RETRIES # max number of retries before an exception (default: inf) min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1) max_backoff: float = 60 # limit maximal sleep time between retries to this value ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds diff --git a/src/petals/client/from_pretrained.py b/src/petals/client/from_pretrained.py index f2c88d2..4b9d8e5 100644 --- a/src/petals/client/from_pretrained.py +++ b/src/petals/client/from_pretrained.py @@ -6,7 +6,6 @@ import tempfile from contextvars import ContextVar from typing import List, Optional, Tuple, Union -import torch from hivemind.utils.logging import get_logger from transformers import BloomPreTrainedModel, modeling_utils @@ -22,21 +21,14 @@ class FromPretrainedMixin: model_name_or_path: Union[str, os.PathLike, None], *args, low_cpu_mem_usage: Optional[bool] = None, - torch_dtype: Optional[Union[str, torch.dtype]] = None, **kwargs, ): model_name_or_path = get_compatible_model_repo(model_name_or_path) if low_cpu_mem_usage is None: low_cpu_mem_usage = True - if torch_dtype is None: - # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast, - # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights. - torch_dtype = "auto" with ignore_keys(cls._keys_to_ignore_on_load_unexpected): - return super().from_pretrained( - model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs - ) + return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs) from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace( "low_cpu_mem_usage(`bool`, *optional*)", diff --git a/src/petals/client/lm_head.py b/src/petals/client/lm_head.py index cbea89d..bc0e293 100644 --- a/src/petals/client/lm_head.py +++ b/src/petals/client/lm_head.py @@ -1,8 +1,7 @@ import dataclasses import platform -from typing import Optional, Union +from typing import Union -import psutil import torch import torch.nn.functional as F import torch.utils.checkpoint @@ -68,11 +67,10 @@ class LMHead(nn.Module): assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive" if not self._bf16_warning_shown: - if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total: - logger.warning( - "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. " - "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)" - ) + logger.warning( + "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. " + "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)" + ) self._bf16_warning_shown = True hidden_states = hidden_states.float() diff --git a/src/petals/server/server.py b/src/petals/server/server.py index 82388aa..45884e3 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -203,7 +203,7 @@ class Server: # For attention cache in GPU or RAM if attn_cache_tokens is None: - attn_cache_tokens = 32768 if is_multiquery_attn else 8192 + attn_cache_tokens = 16384 if is_multiquery_attn else 4096 cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens cache_values_per_block //= self.block_config.num_key_value_groups self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)