Support Llama 2 (#379)

Alexander Borzunov, 11 months ago, committed by GitHub
commit 057a2fb5de (parent 3218534745)

@@ -11,7 +11,7 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.2.0.dev3"
+__version__ = "1.2.0.dev4"
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

@@ -25,6 +25,8 @@ def main():
                        help="path or name of a pretrained model, converted with cli/convert_model.py")
     group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
 
+    parser.add_argument("--public_name", type=str, default=None, help="Public name to be reported in the leaderboard")
+
     group = parser.add_mutually_exclusive_group(required=False)
     group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
     group.add_argument("--use_auth_token", action="store_true", dest="token",
@@ -59,16 +61,22 @@ def main():
     parser.add_argument('--num_handlers', type=int, default=8, required=False,
                         help='server will use this many processes to handle incoming requests')
-    parser.add_argument('--min_batch_size', type=int, default=1,
-                        help='Minimum required batch size for all operations (in total tokens)')
-    parser.add_argument('--max_batch_size', type=int, default=2048,
-                        help='The total number of tokens in the same batch will not exceed this value')
     parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
                         help='Pre-form this many subsequent batches while GPU is processing the current one')
     parser.add_argument('--sender_threads', type=int, default=1, required=False,
                         help='Use this many threads to pass results/exceptions from Runtime to Pools')
-    parser.add_argument('--inference_max_length', type=int, default=2048,
-                        help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
+
+    parser.add_argument('--inference_max_length', type=int, default=None,
+                        help='Maximum total sequence length permitted per inference. '
+                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+    parser.add_argument('--min_batch_size', type=int, default=1,
+                        help='Minimum required batch size for all operations (in total tokens)')
+    parser.add_argument('--max_batch_size', type=int, default=None,
+                        help='The total number of tokens in the same batch will not exceed this value. '
+                             'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
+    parser.add_argument('--attn_cache_tokens', type=int, default=None,
+                        help='The number of past attention key/value pairs that will be stored between inference steps. '
+                             'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
+
     parser.add_argument('--cache_dir', type=str, default=None,
                         help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
@@ -86,9 +94,6 @@ def main():
     parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--attn_cache_tokens', type=int, default=8192,
-                        help='The number of past attention key/value pairs that will be stored between inference steps. '
-                             'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).')
 
     parser.add_argument('--alloc_timeout', type=float, default=5,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
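
The capacity-related flags above now default to None so that a model-dependent value can be chosen later, after the config is known. A minimal, self-contained sketch of that pattern (plain argparse, not Petals' actual CLI wiring):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--public_name", type=str, default=None)
    parser.add_argument("--inference_max_length", type=int, default=None)
    parser.add_argument("--max_batch_size", type=int, default=None)
    parser.add_argument("--attn_cache_tokens", type=int, default=None)

    args = parser.parse_args([])           # no flags passed on the command line
    assert args.max_batch_size is None     # the server later substitutes 2048 or 8192,
                                           # depending on whether the model uses multi-query attention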

@@ -27,12 +27,14 @@ class ServerInfo:
     state: ServerState
     throughput: RPS
 
+    public_name: Optional[str] = None
+    version: Optional[str] = None
+
     network_rps: Optional[RPS] = None
     forward_rps: Optional[RPS] = None
     inference_rps: Optional[RPS] = None
 
     adapters: Sequence[str] = ()
-    version: Optional[str] = None
     torch_dtype: Optional[str] = None
     quant_type: Optional[str] = None
     using_relay: Optional[bool] = None

@@ -18,6 +18,8 @@ class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LM
     attn_class = BloomAttention
     block_prefix = "h"
 
+    num_key_value_groups = 1
+
     @classmethod
     def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs

@@ -73,7 +73,9 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
     ) -> Tuple[torch.Tensor]:
         key_states, value_states = key_value
         key_states = key_states.permute(0, 2, 1)
-        key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
+        key_states = key_states.view(
+            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
         value_states = value_states.view(*key_states.shape)
         return (key_states, value_states)
 
@@ -81,7 +83,9 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
         self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
     ) -> Tuple[torch.Tensor]:
         key_states, value_states = key_value
-        value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
+        value_states = value_states.view(
+            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
         key_states = key_states.view(*value_states.shape)
         key_states = key_states.permute(0, 2, 1)
         return (key_states, value_states)

@@ -18,13 +18,17 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
     attn_class = LlamaAttention
     block_prefix = "model.layers"
 
+    @property
+    def num_key_value_groups(self):
+        return self.num_attention_heads // self.num_key_value_heads
+
     @classmethod
     def from_pretrained(
         cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
     ):
         logger.info(
-            "LLaMA is available solely for non-commercial research purposes. "
-            "Make sure you follow the terms of use: https://bit.ly/llama-license"
+            "Make sure you follow LLaMA's terms of use: "
+            "https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
         )
 
         loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
@@ -34,4 +38,8 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
         if not dht_prefix.endswith("-hf"):
             dht_prefix += "-hf"
         logger.info(f"Using DHT prefix: {dht_prefix}")
-        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+
+        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+        config = result[0] if isinstance(result, tuple) else result
+        config.pretraining_tp = 1  # This may give less accurate results but it doesn't matter if we use quantization
+        return result
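
For concreteness, the new property evaluates as follows for typical checkpoints (head counts are the publicly documented values; the standalone function below is only an illustration, not part of Petals):

    def num_key_value_groups(num_attention_heads: int, num_key_value_heads: int) -> int:
        # Mirrors the property above: how many query heads share one key/value head
        return num_attention_heads // num_key_value_heads

    print(num_key_value_groups(32, 32))  # LLaMA-7B / Llama-2-7B: regular multi-head attention -> 1
    print(num_key_value_groups(64, 8))   # Llama-2-70B: grouped-query attention -> 8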

@@ -81,6 +81,7 @@ class TransformerBackend(ModuleBackend):
         head_dim = self.config.hidden_size // self.config.num_attention_heads
         cache_tensors = []
         for device, num_heads in zip(self.module.devices, self.shard_num_heads):
+            num_heads //= self.config.num_key_value_groups
             keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
             values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
             cache_tensors.extend((keys, values))
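
To see what the extra division changes, here is a back-of-the-envelope sketch of the descriptor shapes for a single-device shard of a Llama-2-70B-like block (the config numbers are assumptions for illustration, not read from Petals):

    hidden_size, num_attention_heads, num_key_value_groups = 8192, 64, 8
    head_dim = hidden_size // num_attention_heads      # 128
    num_heads = num_attention_heads                    # this shard holds all attention heads
    num_heads //= num_key_value_groups                 # only 8 key/value heads need caching

    batch_size, max_length = 1, 8192
    keys_shape = (batch_size, num_heads, head_dim, max_length)    # (1, 8, 128, 8192)
    values_shape = (batch_size, num_heads, max_length, head_dim)  # (1, 8, 8192, 128)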
@@ -123,8 +124,10 @@ class TransformerBackend(ModuleBackend):
         """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
         key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
         for i in range(len(key_cache)):
-            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
-            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]  # [batch * num_heads, kv_length, head_dim]
+            key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]
+            # shape: [batch * num_kv_heads, head_dim, kv_length]
+            value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]
+            # shape: [batch * num_kv_heads, kv_length, head_dim]
         layer_past = tuple(chain(*zip(key_cache, value_cache)))
         return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past
@@ -132,7 +135,7 @@ class TransformerBackend(ModuleBackend):
         self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
     ):
         """Writes new key/value tensors back into cache, works in-place"""
-        _batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
+        _batch_size_times_num_kv_heads, head_dim, new_length = new_kvs[0].shape
         for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
             new_key = new_key.view(*cache_key.shape[:3], new_length)
             cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]

@@ -23,6 +23,7 @@ from petals.constants import DTYPE_MAP
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.hf_auth import always_needs_auth
 
 logger = get_logger(__name__)
 
@@ -86,6 +87,9 @@ def _load_state_dict_from_repo(
     cache_dir: str,
     max_disk_space: Optional[int] = None,
 ) -> StateDict:
+    if always_needs_auth(model_name) and token is None:
+        token = True
+
     index_file = get_file_from_repo(
         model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
     )

@@ -145,8 +145,7 @@ class ReachabilityProtocol(ServicerBase):
                 async with protocol.serve(common_p2p):
                     await protocol._stop.wait()
             except Exception as e:
-                logger.warning(f"Reachability service failed: {repr(e)}")
-                logger.debug("See detailed traceback below:", exc_info=True)
+                logger.debug("Reachability service failed:", exc_info=True)
 
                 if not ready.done():
                     ready.set_exception(e)

@@ -50,18 +50,19 @@ class Server:
         initial_peers: List[str],
         dht_prefix: Optional[str],
         converted_model_name_or_path: str,
+        public_name: Optional[str] = None,
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: int = 8,
+        inference_max_length: Optional[int] = None,
         min_batch_size: int = 1,
-        max_batch_size: int = 2048,
-        inference_max_length: int = 2048,
+        max_batch_size: Optional[int] = None,
+        attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        attn_cache_tokens: int = 8192,
         alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
@@ -93,8 +94,6 @@ class Server:
         self.converted_model_name_or_path = converted_model_name_or_path
         self.num_handlers = num_handlers
-        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
-        self.inference_max_length = inference_max_length
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
@@ -177,8 +176,19 @@ class Server:
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
+        is_multiquery_attn = self.block_config.num_key_value_groups > 1
+        if max_batch_size is None:
+            max_batch_size = 8192 if is_multiquery_attn else 2048
+        if inference_max_length is None:
+            inference_max_length = 8192 if is_multiquery_attn else 2048
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.inference_max_length = inference_max_length
+
         # For attention cache in GPU or RAM
+        if attn_cache_tokens is None:
+            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
+        cache_values_per_block //= self.block_config.num_key_value_groups
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
 
         # For disk cache
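
Plugging in numbers makes the effect of the extra division by num_key_value_groups visible. A worked example under assumed values (Llama-2-70B-like config, 16-bit cache entries, the multi-query default for attn_cache_tokens):

    hidden_size = 8192
    attn_cache_tokens = 32768          # default picked above for multi-query attention models
    num_key_value_groups = 8
    bits_per_value = 16                # e.g., bfloat16

    cache_values_per_block = 2 * hidden_size * attn_cache_tokens // num_key_value_groups
    cache_bytes_per_block = cache_values_per_block * bits_per_value // 8
    print(cache_bytes_per_block // 1024**2)  # 128 MiB per block, vs. 1024 MiB without the division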
@@ -222,8 +232,9 @@ class Server:
         throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(
             state=ServerState.JOINING,
-            adapters=tuple(adapters),
+            public_name=public_name,
             version=petals.__version__,
+            adapters=tuple(adapters),
             torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
             using_relay=self.dht.client_mode,
@@ -642,7 +653,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.dht = dht
         self.server_info = server_info
         self.memory_cache = memory_cache
+
         self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token //= block_config.num_key_value_groups
+
         self.update_period = update_period
         self.expiration = expiration
         self.trigger = threading.Event()

@@ -1,8 +1,12 @@
+import os
+import re
 from dataclasses import dataclass
-from typing import Optional, Type
+from typing import Optional, Type, Union
 
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 
+from petals.utils.hf_auth import always_needs_auth
+
 
 @dataclass
 class _ModelClasses:
@@ -26,8 +30,11 @@ class _AutoDistributedBase:
     _mapping_field = None  # Should be defined in child classes
 
     @classmethod
-    def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
-        config = AutoConfig.from_pretrained(*args, **kwargs)
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
+        if always_needs_auth(model_name_or_path) and "token" not in kwargs and "use_auth_token" not in kwargs:
+            kwargs["token"] = True
+
+        config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
         if config.model_type not in _CLASS_MAPPING:
             raise ValueError(f"Petals does not support model type {config.model_type}")
 
@@ -35,7 +42,7 @@ class _AutoDistributedBase:
         if proper_cls is None:
             raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
 
-        return proper_cls.from_pretrained(*args, **kwargs)
+        return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
 
 
 class AutoDistributedConfig(_AutoDistributedBase):

@@ -2,6 +2,7 @@
 Tools for converting transformer blocks, applying quantization and/or tensor parallelism
 """
 import re
+from enum import Enum
 from typing import Optional, Sequence
 
 import tensor_parallel as tp
@@ -11,12 +12,16 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
-from petals.utils.misc import QuantType
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
+
+class QuantType(Enum):
+    NONE = 0
+    INT8 = 1  # 8-bit as in the LLM.int8() paper
+    NF4 = 2  # 4-bit as in the QLoRA paper
+
+
 def convert_block(
     block: nn.Module,
     block_index: int,

@@ -0,0 +1,7 @@
+import os
+from typing import Union
+
+
+def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:
+    loading_from_repo = model_name is not None and not os.path.isdir(model_name)
+    return loading_from_repo and model_name.startswith("meta-llama/Llama-2-")
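
A quick illustration of how this new helper behaves (repo names are examples; the first assert assumes no local directory with that name exists, so the argument is treated as a Hub repo):

    from petals.utils.hf_auth import always_needs_auth

    assert always_needs_auth("meta-llama/Llama-2-70b-chat-hf")   # gated Llama 2 repo -> requires a token
    assert not always_needs_auth("bigscience/bloom-560m")        # public repo
    assert not always_needs_auth(None)                           # no model name given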

@@ -1,14 +1,5 @@
-from enum import Enum
-
 import torch
 
-
-class QuantType(Enum):
-    NONE = 0
-    INT8 = 1  # 8-bit as in the LLM.int8() paper
-    NF4 = 2  # 4-bit as in the QLoRA paper
-
-
 DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters

@@ -17,8 +17,8 @@ from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
 
 from petals.server.block_utils import resolve_block_dtype
+from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
-from petals.utils.misc import QuantType
 
 logger = get_logger(__name__)
