diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
index 05cebdd..b9dcc01 100644
--- a/.github/workflows/run-tests.yaml
+++ b/.github/workflows/run-tests.yaml
@@ -102,6 +102,9 @@ jobs:
           export no_proxy=*
           export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
 
+          # Limit default ClientConfig.max_retries to see tracebacks instead of retrying indefinitely
+          export PETALS_MAX_RETRIES=10
+
           pytest tests --durations=0 --durations-min=1.0 -v
 
           # [Step 3] Check if benchmarks work (their results here are meaningless since it's a tiny swarm of CPU servers)
diff --git a/src/petals/__init__.py b/src/petals/__init__.py
index f513f65..8671fc2 100644
--- a/src/petals/__init__.py
+++ b/src/petals/__init__.py
@@ -17,7 +17,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "2.2.0"
+__version__ = "2.3.0.dev0"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index 94f5c2e..5208438 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -70,17 +70,17 @@ def main():
 
     parser.add_argument('--inference_max_length', type=int, default=None,
                         help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. '
-                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
     parser.add_argument('--min_batch_size', type=int, default=1,
                         help='Minimum required batch size for all operations (in total tokens)')
     parser.add_argument('--max_batch_size', type=int, default=None,
                         help='The total number of tokens in the same batch will not exceed this value. '
-                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
     parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
                         help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
     parser.add_argument('--attn_cache_tokens', type=int, default=None,
                         help='The number of past attention key/value pairs that will be stored between inference steps. '
-                             'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
+                             'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others')
     parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
diff --git a/src/petals/client/config.py b/src/petals/client/config.py
index e255024..a2f8f42 100644
--- a/src/petals/client/config.py
+++ b/src/petals/client/config.py
@@ -1,10 +1,14 @@
 import dataclasses
+import os
 from typing import Optional, Sequence, Union
 
 from hivemind import PeerID
 
 from petals.constants import PUBLIC_INITIAL_PEERS
 
+_max_retries = os.getenv("PETALS_MAX_RETRIES")
+DEFAULT_MAX_RETRIES = int(_max_retries) if isinstance(_max_retries, str) else None
+
 
 @dataclasses.dataclass
 class ClientConfig:
@@ -21,7 +25,7 @@ class ClientConfig:
     request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
     update_period: float = 60  # refresh DHT information once in this many seconds
 
-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    max_retries: Optional[int] = DEFAULT_MAX_RETRIES  # max number of retries before an exception (default: inf)
     min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
     max_backoff: float = 60  # limit maximal sleep time between retries to this value
     ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
diff --git a/src/petals/client/from_pretrained.py b/src/petals/client/from_pretrained.py
index f2c88d2..4b9d8e5 100644
--- a/src/petals/client/from_pretrained.py
+++ b/src/petals/client/from_pretrained.py
@@ -6,7 +6,6 @@ import tempfile
 from contextvars import ContextVar
 from typing import List, Optional, Tuple, Union
 
-import torch
 from hivemind.utils.logging import get_logger
 from transformers import BloomPreTrainedModel, modeling_utils
 
@@ -22,21 +21,14 @@ class FromPretrainedMixin:
         model_name_or_path: Union[str, os.PathLike, None],
         *args,
         low_cpu_mem_usage: Optional[bool] = None,
-        torch_dtype: Optional[Union[str, torch.dtype]] = None,
         **kwargs,
     ):
         model_name_or_path = get_compatible_model_repo(model_name_or_path)
         if low_cpu_mem_usage is None:
             low_cpu_mem_usage = True
-        if torch_dtype is None:
-            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
-            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
-            torch_dtype = "auto"
 
         with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
-            return super().from_pretrained(
-                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
-            )
+            return super().from_pretrained(model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
 
     from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
         "low_cpu_mem_usage(`bool`, *optional*)",
diff --git a/src/petals/client/lm_head.py b/src/petals/client/lm_head.py
index cbea89d..bc0e293 100644
--- a/src/petals/client/lm_head.py
+++ b/src/petals/client/lm_head.py
@@ -1,8 +1,7 @@
 import dataclasses
 import platform
-from typing import Optional, Union
+from typing import Union
 
-import psutil
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -68,11 +67,10 @@ class LMHead(nn.Module):
         assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
 
         if not self._bf16_warning_shown:
-            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
-                logger.warning(
-                    "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
-                    "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
-                )
+            logger.warning(
+                "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
+                "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
+            )
             self._bf16_warning_shown = True
 
         hidden_states = hidden_states.float()
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 82388aa..45884e3 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -203,7 +203,7 @@ class Server:
 
         # For attention cache in GPU or RAM
         if attn_cache_tokens is None:
-            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
+            attn_cache_tokens = 16384 if is_multiquery_attn else 4096
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
         self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
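Below is a minimal usage sketch (not part of the patch above) of how the new PETALS_MAX_RETRIES variable feeds ClientConfig.max_retries, based only on the logic added in src/petals/client/config.py. It assumes ClientConfig can be constructed with its default field values and that the variable is set before petals.client.config is first imported, since DEFAULT_MAX_RETRIES is evaluated at import time.

# Sketch, not part of the diff: PETALS_MAX_RETRIES -> ClientConfig.max_retries
import os

os.environ["PETALS_MAX_RETRIES"] = "10"  # must be set before petals.client.config is imported

from petals.client.config import ClientConfig  # module touched by the diff above

config = ClientConfig()
print(config.max_retries)  # 10; if the variable is unset, this stays None (retry indefinitely)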