diff --git a/src/petals/__init__.py b/src/petals/__init__.py
index b72776c..53bdd51 100644
--- a/src/petals/__init__.py
+++ b/src/petals/__init__.py
@@ -11,7 +11,7 @@
 from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.2.0.dev3"
+__version__ = "1.2.0.dev4"
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index 8820dd2..abd3faf 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -25,6 +25,8 @@ def main():
                        help="path or name of a pretrained model, converted with cli/convert_model.py")
     group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
 
+    parser.add_argument("--public_name", type=str, default=None, help="Public name to be reported in the leaderboard")
+
     group = parser.add_mutually_exclusive_group(required=False)
     group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
     group.add_argument("--use_auth_token", action="store_true", dest="token",
@@ -59,16 +61,22 @@ def main():
     parser.add_argument('--num_handlers', type=int, default=8, required=False,
                         help='server will use this many processes to handle incoming requests')
-    parser.add_argument('--min_batch_size', type=int, default=1,
-                        help='Minimum required batch size for all operations (in total tokens)')
-    parser.add_argument('--max_batch_size', type=int, default=2048,
-                        help='The total number of tokens in the same batch will not exceed this value')
     parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
                         help='Pre-form this many subsequent batches while GPU is processing the current one')
     parser.add_argument('--sender_threads', type=int, default=1, required=False,
                         help='Use this many threads to pass results/exceptions from Runtime to Pools')
-    parser.add_argument('--inference_max_length', type=int, default=2048,
-                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
+
+    parser.add_argument('--inference_max_length', type=int, default=None,
+                        help='Maximum total sequence length permitted per inference. '
+                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+    parser.add_argument('--min_batch_size', type=int, default=1,
+                        help='Minimum required batch size for all operations (in total tokens)')
+    parser.add_argument('--max_batch_size', type=int, default=None,
+                        help='The total number of tokens in the same batch will not exceed this value. '
+                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+    parser.add_argument('--attn_cache_tokens', type=int, default=None,
+                        help='The number of past attention key/value pairs that will be stored between inference steps. '
+                             'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
     parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
@@ -86,9 +94,6 @@ def main():
     parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--attn_cache_tokens', type=int, default=8192,
-                        help='The number of past attention key/value pairs that will be stored between inference steps. '
-                             'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).')
     parser.add_argument('--alloc_timeout', type=float, default=5,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py
index e3a3e03..38d706f 100644
--- a/src/petals/data_structures.py
+++ b/src/petals/data_structures.py
@@ -27,12 +27,14 @@ class ServerInfo:
     state: ServerState
     throughput: RPS
 
+    public_name: Optional[str] = None
+    version: Optional[str] = None
+
     network_rps: Optional[RPS] = None
     forward_rps: Optional[RPS] = None
     inference_rps: Optional[RPS] = None
 
     adapters: Sequence[str] = ()
-    version: Optional[str] = None
     torch_dtype: Optional[str] = None
     quant_type: Optional[str] = None
     using_relay: Optional[bool] = None
diff --git a/src/petals/models/bloom/config.py b/src/petals/models/bloom/config.py
index 23521fc..494c187 100644
--- a/src/petals/models/bloom/config.py
+++ b/src/petals/models/bloom/config.py
@@ -18,6 +18,8 @@ class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LM
     attn_class = BloomAttention
     block_prefix = "h"
 
+    num_key_value_groups = 1
+
     @classmethod
     def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
diff --git a/src/petals/models/llama/block.py b/src/petals/models/llama/block.py
index 2f07188..55f659a 100644
--- a/src/petals/models/llama/block.py
+++ b/src/petals/models/llama/block.py
@@ -73,7 +73,9 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
     ) -> Tuple[torch.Tensor]:
         key_states, value_states = key_value
         key_states = key_states.permute(0, 2, 1)
-        key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
+        key_states = key_states.view(
+            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
         value_states = value_states.view(*key_states.shape)
         return (key_states, value_states)
 
@@ -81,7 +83,9 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
         self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
     ) -> Tuple[torch.Tensor]:
         key_states, value_states = key_value
-        value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
+        value_states = value_states.view(
+            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
         key_states = key_states.view(*value_states.shape)
         key_states = key_states.permute(0, 2, 1)
         return (key_states, value_states)
diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py
index b21fa9a..241525a 100644
--- a/src/petals/models/llama/config.py
+++ b/src/petals/models/llama/config.py
@@ -18,13 +18,17 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
     attn_class = LlamaAttention
     block_prefix = "model.layers"
 
+    @property
+    def num_key_value_groups(self):
+        return self.num_attention_heads // self.num_key_value_heads
+
     @classmethod
     def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
     ):
         logger.info(
-            "LLaMA is available solely for non-commercial research purposes. "
-            "Make sure you follow the terms of use: https://bit.ly/llama-license"
+            "Make sure you follow LLaMA's terms of use: "
+            "https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
         )
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
@@ -34,4 +38,8 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
             if not dht_prefix.endswith("-hf"):
                 dht_prefix += "-hf"
             logger.info(f"Using DHT prefix: {dht_prefix}")
-        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+
+        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+        config = result[0] if isinstance(result, tuple) else result
+        config.pretraining_tp = 1  # This may give less accurate results but it doesn't matter if we use quantization
+        return result
diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 4220546..d61470a 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -81,6 +81,7 @@ class TransformerBackend(ModuleBackend):
         head_dim = self.config.hidden_size // self.config.num_attention_heads
         cache_tensors = []
         for device, num_heads in zip(self.module.devices, self.shard_num_heads):
+            num_heads //= self.config.num_key_value_groups
             keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
             values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
             cache_tensors.extend((keys, values))
@@ -123,8 +124,10 @@ class TransformerBackend(ModuleBackend):
         """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
         key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
         for i in range(len(key_cache)):
-            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
-            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]  # [batch * num_heads, kv_length, head_dim]
+            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]
+            # shape: [batch * num_kv_heads, head_dim, kv_length]
+            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]
+            # shape: [batch * num_kv_heads, kv_length, head_dim]
         layer_past = tuple(chain(*zip(key_cache, value_cache)))
         return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past
 
@@ -132,7 +135,7 @@ class TransformerBackend(ModuleBackend):
         self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
     ):
         """Writes new key/value tensors back into cache, works in-place"""
-        _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
+        _batch_size_times_num_kv_heads, head_dim, new_length = new_kvs[0].shape
         for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
             new_key = new_key.view(*cache_key.shape[:3], new_length)
             cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
diff --git a/src/petals/server/from_pretrained.py b/src/petals/server/from_pretrained.py
index 950746e..2a2560b 100644
--- a/src/petals/server/from_pretrained.py
+++ b/src/petals/server/from_pretrained.py
@@ -23,6 +23,7 @@
 from petals.constants import DTYPE_MAP
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.hf_auth import always_needs_auth
 
 logger = get_logger(__name__)
@@ -86,6 +87,9 @@ def _load_state_dict_from_repo(
     cache_dir: str,
     max_disk_space: Optional[int] = None,
 ) -> StateDict:
+    if always_needs_auth(model_name) and token is None:
+        token = True
+
     index_file = get_file_from_repo(
         model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
     )
diff --git a/src/petals/server/reachability.py b/src/petals/server/reachability.py
index 03e01fc..ff8dd14 100644
--- a/src/petals/server/reachability.py
+++ b/src/petals/server/reachability.py
@@ -145,8 +145,7 @@ class ReachabilityProtocol(ServicerBase):
             async with protocol.serve(common_p2p):
                 await protocol._stop.wait()
         except Exception as e:
-            logger.warning(f"Reachability service failed: {repr(e)}")
-            logger.debug("See detailed traceback below:", exc_info=True)
+            logger.debug("Reachability service failed:", exc_info=True)
 
             if not ready.done():
                 ready.set_exception(e)
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index ccc5292..947dbd8 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -50,18 +50,19 @@ class Server:
         initial_peers: List[str],
         dht_prefix: Optional[str],
         converted_model_name_or_path: str,
+        public_name: Optional[str] = None,
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: int = 8,
+        inference_max_length: Optional[int] = None,
         min_batch_size: int = 1,
-        max_batch_size: int = 2048,
-        inference_max_length: int = 2048,
+        max_batch_size: Optional[int] = None,
+        attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        attn_cache_tokens: int = 8192,
         alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
@@ -93,8 +94,6 @@ class Server:
         self.converted_model_name_or_path = converted_model_name_or_path
         self.num_handlers = num_handlers
-        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
-        self.inference_max_length = inference_max_length
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
@@ -177,8 +176,19 @@ class Server:
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
+        is_multiquery_attn = self.block_config.num_key_value_groups > 1
+        if max_batch_size is None:
+            max_batch_size = 8192 if is_multiquery_attn else 2048
+        if inference_max_length is None:
+            inference_max_length = 8192 if is_multiquery_attn else 2048
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.inference_max_length = inference_max_length
+
         # For attention cache in GPU or RAM
+        if attn_cache_tokens is None:
+            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
+        cache_values_per_block //= self.block_config.num_key_value_groups
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
 
         # For disk cache
@@ -222,8 +232,9 @@ class Server:
         throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(
             state=ServerState.JOINING,
-            adapters=tuple(adapters),
+            public_name=public_name,
             version=petals.__version__,
+            adapters=tuple(adapters),
             torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
             using_relay=self.dht.client_mode,
@@ -642,7 +653,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.dht = dht
         self.server_info = server_info
         self.memory_cache = memory_cache
 
+        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token //= block_config.num_key_value_groups
+
         self.update_period = update_period
         self.expiration = expiration
         self.trigger = threading.Event()
diff --git a/src/petals/utils/auto_config.py b/src/petals/utils/auto_config.py
index f587051..13c7298 100644
--- a/src/petals/utils/auto_config.py
+++ b/src/petals/utils/auto_config.py
@@ -1,8 +1,12 @@
+import os
+import re
 from dataclasses import dataclass
-from typing import Optional, Type
+from typing import Optional, Type, Union
 
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 
+from petals.utils.hf_auth import always_needs_auth
+
 
 @dataclass
 class _ModelClasses:
@@ -26,8 +30,11 @@ class _AutoDistributedBase:
     _mapping_field = None  # Should be defined in child classes
 
     @classmethod
-    def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
-        config = AutoConfig.from_pretrained(*args, **kwargs)
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
+        if always_needs_auth(model_name_or_path) and "token" not in kwargs and "use_auth_token" not in kwargs:
+            kwargs["token"] = True
+
+        config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
         if config.model_type not in _CLASS_MAPPING:
             raise ValueError(f"Petals does not support model type {config.model_type}")
 
@@ -35,7 +42,7 @@ class _AutoDistributedBase:
         if proper_cls is None:
             raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
 
-        return proper_cls.from_pretrained(*args, **kwargs)
+        return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
 
 
 class AutoDistributedConfig(_AutoDistributedBase):
diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
index f8a4637..94d3e29 100644
--- a/src/petals/utils/convert_block.py
+++ b/src/petals/utils/convert_block.py
@@ -2,6 +2,7 @@
 Tools for converting transformer blocks, applying quantization and/or tensor parallelism
 """
 import re
+from enum import Enum
 from typing import Optional, Sequence
 
 import tensor_parallel as tp
@@ -11,12 +12,16 @@
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
-from petals.utils.misc import QuantType
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
+class QuantType(Enum):
+    NONE = 0
+    INT8 = 1  # 8-bit as in the LLM.int8() paper
+    NF4 = 2  # 4-bit as in the QLoRA paper
+
+
 def convert_block(
     block: nn.Module,
     block_index: int,
diff --git a/src/petals/utils/hf_auth.py b/src/petals/utils/hf_auth.py
new file mode 100644
index 0000000..6446b89
--- /dev/null
+++ b/src/petals/utils/hf_auth.py
@@ -0,0 +1,7 @@
+import os
+from typing import Union
+
+
+def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:
+    loading_from_repo = model_name is not None and not os.path.isdir(model_name)
+    return loading_from_repo and model_name.startswith("meta-llama/Llama-2-")
diff --git a/src/petals/utils/misc.py b/src/petals/utils/misc.py
index 99b246c..2f67202 100644
--- a/src/petals/utils/misc.py
+++ b/src/petals/utils/misc.py
@@ -1,14 +1,5 @@
-from enum import Enum
-
 import torch
 
-
-class QuantType(Enum):
-    NONE = 0
-    INT8 = 1  # 8-bit as in the LLM.int8() paper
-    NF4 = 2  # 4-bit as in the QLoRA paper
-
-
 DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index de48cd2..da25623 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -17,8 +17,8 @@
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
 
 from petals.server.block_utils import resolve_block_dtype
+from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
-from petals.utils.misc import QuantType
 
 logger = get_logger(__name__)
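
Reviewer note: a quick sanity check of the new cache-sizing defaults. The sketch below is not part of the patch; it only mirrors the arithmetic added in src/petals/server/server.py (cache_values_per_block = 2 * hidden_size * attn_cache_tokens, divided by num_key_value_groups, times the dtype width). The concrete numbers (hidden_size=8192 and 64 query heads sharing 8 key/value heads, roughly a Llama-2-70B block, with a 16-bit cache dtype) are illustrative assumptions, not values read from the code.

def cache_bytes_per_block(hidden_size: int, num_key_value_groups: int,
                          attn_cache_tokens: int, dtype_bits: int = 16) -> int:
    # One key and one value vector of size hidden_size is cached per token and per block.
    cache_values_per_block = 2 * hidden_size * attn_cache_tokens
    # Grouped/multi-query attention stores only num_key_value_heads projections,
    # so the cache shrinks by num_attention_heads // num_key_value_heads.
    cache_values_per_block //= num_key_value_groups
    return cache_values_per_block * dtype_bits // 8


# Regular multi-head attention block (num_key_value_groups == 1) at the old 8192-token default:
print(cache_bytes_per_block(8192, 1, 8192) / 2**20)   # 256.0 (MiB)
# Llama-2-70B-style block (num_key_value_groups == 8) at the new 32768-token default:
print(cache_bytes_per_block(8192, 8, 32768) / 2**20)  # 128.0 (MiB)

Under these assumptions a grouped-query block holds four times as many cached tokens in half the memory, which is why the defaults for attn_cache_tokens, max_batch_size, and inference_max_length now scale up whenever block_config.num_key_value_groups > 1.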